Coverage for geometric_kernels/lab_extras/extras.py: 99%

97 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import lab as B 

2from beartype.typing import List 

3from lab import dispatch 

4from lab.util import abstract 

5from plum import Union 

6from scipy.sparse import spmatrix 

7 

8 

9@dispatch 

10@abstract() 

11def take_along_axis(a: B.Numeric, index: B.Numeric, axis: int = 0): 

12 """ 

13 Gathers elements of `a` along `axis` at `index` locations. 

14 

15 :param a: 

16 Array of any backend, as in `numpy.take_along_axis`. 

17 :param index: 

18 Array of any backend, as in `numpy.take_along_axis`. 

19 :param axis: 

20 As in `numpy.take_along_axis`. 

21 """ 

22 

23 

24@dispatch 

25@abstract() 

26def from_numpy(_: B.Numeric, b: Union[List, B.Numeric]): 

27 """ 

28 Converts the array `b` to a tensor of the same backend as `_`. 

29 

30 :param _: 

31 Array of any backend used to determine the backend. 

32 :param b: 

33 Array of any backend or list to be converted to the backend of _. 

34 """ 

35 

36 

37@dispatch 

38@abstract() 

39def trapz(y: B.Numeric, x: B.Numeric, dx: B.Numeric = 1.0, axis: int = -1): 

40 """ 

41 Integrate along the given axis using the trapezoidal rule. 

42 

43 :param y: 

44 Array of any backend, as in `numpy.trapz`. 

45 :param x: 

46 Array of any backend, as in `numpy.trapz`. 

47 :param dx: 

48 Array of any backend, as in `numpy.trapz`. 

49 :param axis: 

50 As in `numpy.trapz`. 

51 """ 

52 

53 

54@dispatch 

55@abstract() 

56def logspace(start: B.Numeric, stop: B.Numeric, num: int = 50): 

57 """ 

58 Return numbers spaced evenly on a log scale. 

59 

60 :param start: 

61 Array of any backend, as in `numpy.logspace`. 

62 :param stop: 

63 Array of any backend, as in `numpy.logspace`. 

64 :param num: 

65 As in `numpy.logspace`. 

66 """ 

67 

68 

69def cosh(x: B.Numeric) -> B.Numeric: 

70 r""" 

71 Compute hyperbolic cosine using the formula 

72 

73 .. math:: \textrm{cosh}(x) = \frac{\exp(x) + \exp(-x)}{2}. 

74 

75 :param x: 

76 Array of any backend. 

77 """ 

78 return 0.5 * (B.exp(x) + B.exp(-x)) 

79 

80 

81def sinh(x: B.Numeric) -> B.Numeric: 

82 r""" 

83 Compute hyperbolic sine using the formula 

84 

85 .. math:: \textrm{sinh}(x) = \frac{\exp(x) - \exp(-x)}{2}. 

86 

87 :param x: 

88 Array of any backend. 

89 """ 

90 return 0.5 * (B.exp(x) - B.exp(-x)) 

91 

92 

93@dispatch 

94@abstract() 

95def degree(a): 

96 """ 

97 Given an adjacency matrix `a`, return a diagonal matrix 

98 with the col-sums of `a` as main diagonal - this is the 

99 degree matrix representing the number of nodes each node 

100 is connected to. 

101 

102 :param a: 

103 Array of any backend or `scipy.sparse` array. 

104 """ 

105 

106 

107@dispatch 

108@abstract() 

109def eigenpairs(L, k: int): 

110 """ 

111 Obtain the eigenpairs that correspond to the `k` lowest eigenvalues 

112 of a symmetric positive semi-definite matrix `L`. 

113 

114 :param a: 

115 Array of any backend or `scipy.sparse` array. 

116 :param k: 

117 The number of eigenpairs to compute. 

118 """ 

119 

120 

121@dispatch 

122@abstract() 

123def set_value(a, index: int, value: float): 

124 """ 

125 Set a[index] = value. 

126 This operation is not done in place and a new array is returned. 

127 

128 :param a: 

129 Array of any backend or `scipy.sparse` array. 

130 :param index: 

131 The index. 

132 :param value: 

133 The value to set at the given index. 

134 """ 

135 

136 

137@dispatch 

138@abstract() 

139def dtype_double(reference: B.RandomState): 

140 """ 

141 Return `double` dtype of a backend based on the reference. 

142 

143 :param reference: 

144 A random state to infer the backend from. 

145 """ 

146 

147 

148@dispatch 

149@abstract() 

150def float_like(reference: B.Numeric): 

151 """ 

152 Return the type of the reference if it is a floating point type. 

153 Otherwise return `double` dtype of a backend based on the reference. 

154 

155 :param reference: 

156 Array of any backend. 

157 """ 

158 

159 

160@dispatch 

161@abstract() 

162def dtype_integer(reference: B.RandomState): 

163 """ 

164 Return `int` dtype of a backend based on the reference. 

165 

166 :param reference: 

167 A random state to infer the backend from. 

168 """ 

169 

170 

171@dispatch 

172@abstract() 

173def int_like(reference: B.Numeric): 

174 """ 

175 Return the type of the reference if it is integer type. 

176 Otherwise return `int32` dtype of a backend based on the reference. 

177 

178 :param reference: 

179 Array of any backend. 

180 """ 

181 

182 

183@dispatch 

184@abstract() 

185def get_random_state(key: B.RandomState): 

186 """ 

187 Return the random state of a random generator. 

188 

189 :param key: 

190 The random generator. 

191 """ 

192 

193 

194@dispatch 

195@abstract() 

196def restore_random_state(key: B.RandomState, state): 

197 """ 

198 Set the random state of a random generator. Return the new random 

199 generator with state `state`. 

200 

201 :param key: 

202 The random generator. 

203 :param state: 

204 The new random state of the random generator. 

205 """ 

206 

207 

208@dispatch 

209@abstract() 

210def create_complex(real: B.Numeric, imag: B.Numeric): 

211 """ 

212 Return a complex number with the given real and imaginary parts. 

213 

214 :param real: 

215 Array of any backend, real part of the complex number. 

216 :param imag: 

217 Array of any backend, imaginary part of the complex number. 

218 """ 

219 

220 

221@dispatch 

222@abstract() 

223def complex_like(reference: B.Numeric): 

224 """ 

225 Return `complex` dtype of a backend based on the reference. 

226 

227 :param reference: 

228 Array of any backend. 

229 """ 

230 

231 

232@dispatch 

233@abstract() 

234def is_complex(reference: B.Numeric): 

235 """ 

236 Return True if reference of `complex` dtype. 

237 

238 :param reference: 

239 Array of any backend. 

240 """ 

241 

242 

243@dispatch 

244@abstract() 

245def cumsum(a: B.Numeric, axis=None): 

246 """ 

247 Return cumulative sum (optionally along axis). 

248 

249 :param a: 

250 Array of any backend. 

251 :param axis: 

252 As in `numpy.cumsum`. 

253 """ 

254 

255 

256@dispatch 

257@abstract() 

258def qr(x: B.Numeric, mode="reduced"): 

259 """ 

260 Return a QR decomposition of a matrix x. 

261 

262 :param x: 

263 Array of any backend. 

264 :param mode: 

265 As in `numpy.linalg.qr`. 

266 """ 

267 

268 

269@dispatch 

270@abstract() 

271def slogdet(x: B.Numeric): 

272 """ 

273 Return the sign and log-determinant of a matrix x. 

274 

275 :param x: 

276 Array of any backend. 

277 """ 

278 

279 

280@dispatch 

281@abstract() 

282def eigvalsh(x: B.Numeric): 

283 """ 

284 Compute the eigenvalues of a Hermitian or real symmetric matrix x. 

285 

286 :param x: 

287 Array of any backend. 

288 """ 

289 

290 

291@dispatch 

292@abstract() 

293def reciprocal_no_nan(x: Union[B.Numeric, spmatrix]): 

294 """ 

295 Return element-wise reciprocal (1/x). Whenever x = 0 puts 1/x = 0. 

296 

297 :param x: 

298 Array of any backend or `scipy.sparse.spmatrix`. 

299 """ 

300 

301 

302@dispatch 

303@abstract() 

304def complex_conj(x: B.Numeric): 

305 """ 

306 Return complex conjugate. 

307 

308 :param x: 

309 Array of any backend. 

310 """ 

311 

312 

313@dispatch 

314@abstract() 

315def logical_xor(x1: B.Bool, x2: B.Bool): 

316 """ 

317 Return logical XOR of two arrays. 

318 

319 :param x1: 

320 Array of any backend. 

321 :param x2: 

322 Array of any backend. 

323 """ 

324 

325 

326@dispatch 

327@abstract() 

328def count_nonzero(x: B.Numeric, axis=None): 

329 """ 

330 Count non-zero elements in an array. 

331 

332 :param x: 

333 Array of any backend and of any shape. 

334 """ 

335 

336 

337@dispatch 

338@abstract() 

339def dtype_bool(reference: B.RandomState): 

340 """ 

341 Return `bool` dtype of a backend based on the reference. 

342 

343 :param reference: 

344 A random state to infer the backend from. 

345 """ 

346 

347 

348@dispatch 

349@abstract() 

350def bool_like(reference: B.Numeric): 

351 """ 

352 Return the type of the reference if it is of boolean type. 

353 Otherwise return `bool` dtype of a backend based on the reference. 

354 

355 :param reference: 

356 Array of any backend. 

357 """ 

358 

359 

360def smart_cast( 

361 dtype: Union[B.Bool, B.Int, B.Float, B.Complex, B.Numeric], x: B.Numeric 

362): 

363 """ 

364 Return `x` cast to the `dtype` abstract data type. 

365 

366 :param dtype: 

367 An abstract DType of lab, one of `B.Bool`, `B.Int`, `B.Float`, 

368 `B.Complex`, `B.Numeric`. 

369 :param x: 

370 Array of any backend. 

371 """ 

372 if dtype == B.Bool: 

373 return B.cast(bool_like(x), x) 

374 elif dtype == B.Int: 

375 return B.cast(int_like(x), x) 

376 elif dtype == B.Float: 

377 return B.cast(float_like(x), x) 

378 elif dtype == B.Complex: 

379 return B.cast(complex_like(x), x)