Coverage for geometric_kernels / lab_extras / extras.py: 99%

95 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 15:49 +0000

1import lab as B 

2from lab import dispatch 

3from lab.util import abstract 

4from scipy.sparse import spmatrix 

5 

6 

@dispatch
@abstract()
def take_along_axis(a: B.Numeric, index: B.Numeric, axis: int = 0):
    """
    Pick values out of `a` at the positions listed in `index` along `axis`,
    mirroring `numpy.take_along_axis`.

    :param a:
        Source array of any backend, as in `numpy.take_along_axis`.
    :param index:
        Index array of any backend, as in `numpy.take_along_axis`.
    :param axis:
        Axis along which values are gathered, as in
        `numpy.take_along_axis`. Defaults to 0.
    """

20 

21 

@dispatch
@abstract()
def from_numpy(_: B.Numeric, b: list | B.Numeric):
    """
    Turn `b` into an array living in the same backend as the reference `_`.

    :param _:
        Reference array of any backend; it is used to decide which backend
        the result belongs to.
    :param b:
        A list, or an array of any backend, to convert into the backend
        of `_`.
    """

33 

34 

@dispatch
@abstract()
def trapz(y: B.Numeric, x: B.Numeric, dx: B.Numeric = 1.0, axis: int = -1):
    """
    Integrate `y` along `axis` using the composite trapezoidal rule, as in
    `numpy.trapz` (renamed `numpy.trapezoid` in NumPy 2.0).

    :param y:
        Array of any backend holding the values to integrate,
        as in `numpy.trapz`.
    :param x:
        Array of any backend holding the sample points that correspond
        to `y`, as in `numpy.trapz`.
    :param dx:
        Spacing between samples, as in `numpy.trapz`. Defaults to the
        scalar 1.0.
        NOTE(review): NumPy only uses `dx` when `x` is None. Since `x` is
        required in this signature, backend implementations may ignore
        `dx` altogether — confirm.
    :param axis:
        Axis along which to integrate, as in `numpy.trapz`.
        Defaults to -1.
    """

50 

51 

@dispatch
@abstract()
def logspace(start: B.Numeric, stop: B.Numeric, num: int = 50):
    """
    Generate `num` numbers evenly spaced on a logarithmic scale, as in
    `numpy.logspace`.

    :param start:
        Array of any backend; the starting exponent, as in
        `numpy.logspace`.
    :param stop:
        Array of any backend; the final exponent, as in `numpy.logspace`.
    :param num:
        How many values to generate, as in `numpy.logspace`.
        Defaults to 50.
    """

65 

66 

def cosh(x: B.Numeric) -> B.Numeric:
    r"""
    Hyperbolic cosine, evaluated through exponentials as

    .. math:: \textrm{cosh}(x) = \frac{\exp(x) + \exp(-x)}{2}.

    Because both exponentials are formed explicitly, the result can
    overflow to infinity for large :math:`|x|` slightly before
    :math:`\textrm{cosh}(x)` itself would.

    :param x:
        Array of any backend.

    :return:
        Element-wise hyperbolic cosine of `x`.
    """
    exp_x = B.exp(x)
    exp_neg_x = B.exp(-x)
    return 0.5 * (exp_x + exp_neg_x)

77 

78 

def sinh(x: B.Numeric) -> B.Numeric:
    r"""
    Hyperbolic sine, evaluated through exponentials as

    .. math:: \textrm{sinh}(x) = \frac{\exp(x) - \exp(-x)}{2}.

    Subtracting two nearly equal exponentials loses relative precision
    when :math:`|x|` is close to zero.

    :param x:
        Array of any backend.

    :return:
        Element-wise hyperbolic sine of `x`.
    """
    exp_x = B.exp(x)
    exp_neg_x = B.exp(-x)
    return 0.5 * (exp_x - exp_neg_x)

89 

90 

@dispatch
@abstract()
def degree(a):
    """
    Build the degree matrix of the adjacency matrix `a`.

    The result is a diagonal matrix whose main diagonal holds the column
    sums of `a`. For an unweighted graph, these are the number of nodes
    each node is connected to.

    :param a:
        Adjacency matrix, given as an array of any backend or as a
        `scipy.sparse` array.
    """

103 

104 

@dispatch
@abstract()
def eigenpairs(L, k: int):
    """
    Obtain the eigenpairs that correspond to the `k` lowest eigenvalues
    of a symmetric positive semi-definite matrix `L`.

    :param L:
        Array of any backend or `scipy.sparse` array.
    :param k:
        The number of eigenpairs to compute.
    """

117 

118 

@dispatch
@abstract()
def set_value(a, index: int, value: float):
    """
    Produce a copy of `a` in which the entry at `index` equals `value`,
    i.e. the out-of-place analogue of `a[index] = value`.

    The update is not performed in place; a new array is returned.

    :param a:
        Array of any backend or `scipy.sparse` array.
    :param index:
        Position of the entry to overwrite.
    :param value:
        Value to store at `index`.
    """

133 

134 

@dispatch
@abstract()
def dtype_double(reference: B.RandomState):
    """
    Look up the `double` dtype of the backend that `reference` belongs to.

    :param reference:
        Random state whose backend determines the returned dtype.
    """

144 

145 

@dispatch
@abstract()
def float_like(reference: B.Numeric):
    """
    Pick a floating point dtype that matches `reference`.

    When the dtype of `reference` is already floating point it is returned
    unchanged; otherwise the `double` dtype of the reference's backend is
    returned.

    :param reference:
        Array of any backend.
    """

156 

157 

@dispatch
@abstract()
def dtype_integer(reference: B.RandomState):
    """
    Look up the `int` dtype of the backend that `reference` belongs to.

    :param reference:
        Random state whose backend determines the returned dtype.
    """

167 

168 

@dispatch
@abstract()
def int_like(reference: B.Numeric):
    """
    Pick an integer dtype that matches `reference`.

    When the dtype of `reference` is already an integer type it is
    returned unchanged; otherwise the `int32` dtype of the reference's
    backend is returned.

    :param reference:
        Array of any backend.
    """

179 

180 

@dispatch
@abstract()
def get_random_state(key: B.RandomState):
    """
    Extract the current state of the random generator `key`.

    :param key:
        Random generator to read the state from.
    """

190 

191 

@dispatch
@abstract()
def restore_random_state(key: B.RandomState, state):
    """
    Give the random generator `key` the state `state` and return the new
    random generator carrying that state.

    :param key:
        Random generator to restore.
    :param state:
        State to install. Presumably a value obtained from
        `get_random_state` — confirm against backend implementations.
    """

204 

205 

@dispatch
@abstract()
def create_complex(real: B.Numeric, imag: B.Numeric):
    """
    Assemble complex numbers from their real and imaginary parts.

    :param real:
        Array of any backend holding the real part.
    :param imag:
        Array of any backend holding the imaginary part.
    """

217 

218 

@dispatch
@abstract()
def complex_like(reference: B.Numeric):
    """
    Look up the `complex` dtype of the backend that `reference` belongs to.

    :param reference:
        Array of any backend.
    """

228 

229 

@dispatch
@abstract()
def is_complex(reference: B.Numeric):
    """
    Return True if `reference` is of a `complex` dtype, and False otherwise.

    :param reference:
        Array of any backend.
    """

239 

240 

@dispatch
@abstract()
def cumsum(a: B.Numeric, axis=None):
    """
    Compute the running (cumulative) sum of `a`, optionally along a single
    axis.

    :param a:
        Array of any backend.
    :param axis:
        As in `numpy.cumsum`. Defaults to None.
    """

252 

253 

@dispatch
@abstract()
def qr(x: B.Numeric, mode="reduced"):
    """
    Factorize the matrix `x` into its QR decomposition.

    :param x:
        Array of any backend.
    :param mode:
        As in `numpy.linalg.qr`. Defaults to "reduced".
    """

265 

266 

@dispatch
@abstract()
def slogdet(x: B.Numeric):
    """
    Compute the sign of the determinant of the matrix `x` together with
    its log-determinant.

    :param x:
        Array of any backend.
    """

276 

277 

@dispatch
@abstract()
def eigvalsh(x: B.Numeric):
    """
    Compute only the eigenvalues of `x`, a Hermitian or real symmetric
    matrix.

    :param x:
        Array of any backend.
    """

287 

288 

@dispatch
@abstract()
def reciprocal_no_nan(x: B.Numeric | spmatrix):
    """
    Take the element-wise reciprocal 1/x, using 0 as the result wherever
    x is 0.

    :param x:
        Array of any backend or `scipy.sparse.spmatrix`.
    """

298 

299 

@dispatch
@abstract()
def complex_conj(x: B.Numeric):
    """
    Compute the complex conjugate of `x`.

    :param x:
        Array of any backend.
    """

309 

310 

@dispatch
@abstract()
def logical_xor(x1: B.Bool, x2: B.Bool):
    """
    Combine two boolean arrays with logical exclusive-or (XOR).

    :param x1:
        Boolean array of any backend.
    :param x2:
        Boolean array of any backend.
    """

322 

323 

@dispatch
@abstract()
def count_nonzero(x: B.Numeric, axis=None):
    """
    Count non-zero elements in an array.

    :param x:
        Array of any backend and of any shape.
    :param axis:
        Axis along which to count, as in `numpy.count_nonzero`.
        Defaults to None, which presumably counts over the whole array
        (the NumPy convention) — confirm against backend implementations.
    """

333 

334 

@dispatch
@abstract()
def dtype_bool(reference: B.RandomState):
    """
    Look up the `bool` dtype of the backend that `reference` belongs to.

    :param reference:
        Random state whose backend determines the returned dtype.
    """

344 

345 

@dispatch
@abstract()
def bool_like(reference: B.Numeric):
    """
    Pick a boolean dtype that matches `reference`.

    When the dtype of `reference` is already boolean it is returned
    unchanged; otherwise the `bool` dtype of the reference's backend is
    returned.

    :param reference:
        Array of any backend.
    """

356 

357 

358def smart_cast(dtype: B.Bool | B.Int | B.Float | B.Complex | B.Numeric, x: B.Numeric): 

359 """ 

360 Return `x` cast to the `dtype` abstract data type. 

361 

362 :param dtype: 

363 An abstract DType of lab, one of `B.Bool`, `B.Int`, `B.Float`, 

364 `B.Complex`, `B.Numeric`. 

365 :param x: 

366 Array of any backend. 

367 """ 

368 if dtype == B.Bool: 

369 return B.cast(bool_like(x), x) 

370 elif dtype == B.Int: 

371 return B.cast(int_like(x), x) 

372 elif dtype == B.Float: 

373 return B.cast(float_like(x), x) 

374 elif dtype == B.Complex: 

375 return B.cast(complex_like(x), x)