Coverage for geometric_kernels/lab_extras/extras.py: 99%
97 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1import lab as B
2from beartype.typing import List
3from lab import dispatch
4from lab.util import abstract
5from plum import Union
6from scipy.sparse import spmatrix
@dispatch
@abstract()
def take_along_axis(a: B.Numeric, index: B.Numeric, axis: int = 0):
    """
    Pick values out of `a` along `axis` at the positions given by `index`.

    :param a:
        Source array of any backend, as in `numpy.take_along_axis`.
    :param index:
        Index array of any backend, as in `numpy.take_along_axis`.
    :param axis:
        Axis to gather along, as in `numpy.take_along_axis`. Defaults to 0.
    """
@dispatch
@abstract()
def from_numpy(_: B.Numeric, b: Union[List, B.Numeric]):
    """
    Convert `b` into an array of the backend that `_` belongs to.

    :param _:
        Array of any backend; it is only used to select the target backend.
    :param b:
        List, or array of any backend, to be converted to the backend of `_`.
    """
@dispatch
@abstract()
def trapz(y: B.Numeric, x: B.Numeric, dx: B.Numeric = 1.0, axis: int = -1):
    """
    Integrate `y` along `axis` with the trapezoidal rule.

    :param y:
        Values to integrate; array of any backend, as in `numpy.trapz`.
    :param x:
        Sample points corresponding to `y`; array of any backend, as in
        `numpy.trapz`.
    :param dx:
        Sample spacing, as in `numpy.trapz`. Defaults to 1.0.
    :param axis:
        Axis to integrate along, as in `numpy.trapz`. Defaults to -1.
    """
@dispatch
@abstract()
def logspace(start: B.Numeric, stop: B.Numeric, num: int = 50):
    """
    Produce `num` numbers spaced evenly on a logarithmic scale.

    :param start:
        Array of any backend, as in `numpy.logspace`.
    :param stop:
        Array of any backend, as in `numpy.logspace`.
    :param num:
        Number of samples to generate, as in `numpy.logspace`.
        Defaults to 50.
    """
def cosh(x: B.Numeric) -> B.Numeric:
    r"""
    Hyperbolic cosine of `x`, evaluated element-wise from exponentials:

    .. math:: \textrm{cosh}(x) = \frac{\exp(x) + \exp(-x)}{2}.

    Note that `B.exp(x)` may overflow for large :math:`|x|` before the
    halving is applied.

    :param x:
        Array of any backend.
    """
    exp_pos = B.exp(x)
    exp_neg = B.exp(-x)
    return 0.5 * (exp_pos + exp_neg)
def sinh(x: B.Numeric) -> B.Numeric:
    r"""
    Hyperbolic sine of `x`, evaluated element-wise from exponentials:

    .. math:: \textrm{sinh}(x) = \frac{\exp(x) - \exp(-x)}{2}.

    Note that for :math:`x` close to zero the subtraction of two nearly
    equal exponentials loses relative precision.

    :param x:
        Array of any backend.
    """
    exp_pos = B.exp(x)
    exp_neg = B.exp(-x)
    return 0.5 * (exp_pos - exp_neg)
@dispatch
@abstract()
def degree(a):
    """
    Build the degree matrix of the adjacency matrix `a`.

    The result is a diagonal matrix whose main diagonal holds the column
    sums of `a`, i.e. the number of nodes each node is connected to.

    :param a:
        Adjacency matrix; array of any backend or `scipy.sparse` array.
    """
@dispatch
@abstract()
def eigenpairs(L, k: int):
    """
    Obtain the eigenpairs that correspond to the `k` lowest eigenvalues
    of a symmetric positive semi-definite matrix `L`.

    :param L:
        Symmetric positive semi-definite matrix; array of any backend or
        `scipy.sparse` array.
    :param k:
        The number of eigenpairs to compute.
    """
@dispatch
@abstract()
def set_value(a, index: int, value: float):
    """
    Return a new array equal to `a` except that `a[index]` becomes `value`.

    The input `a` is not modified in place.

    :param a:
        Array of any backend or `scipy.sparse` array.
    :param index:
        Position of the entry to replace.
    :param value:
        Value to store at position `index`.
    """
@dispatch
@abstract()
def dtype_double(reference: B.RandomState):
    """
    Get the `double` dtype of the backend that `reference` belongs to.

    :param reference:
        Random state whose backend determines the returned dtype.
    """
@dispatch
@abstract()
def float_like(reference: B.Numeric):
    """
    Return the dtype of `reference` if it is a floating point dtype;
    otherwise return the `double` dtype of the backend of `reference`.

    :param reference:
        Array of any backend.
    """
@dispatch
@abstract()
def dtype_integer(reference: B.RandomState):
    """
    Get the `int` dtype of the backend that `reference` belongs to.

    :param reference:
        Random state whose backend determines the returned dtype.
    """
@dispatch
@abstract()
def int_like(reference: B.Numeric):
    """
    Return the dtype of `reference` if it is an integer dtype;
    otherwise return the `int32` dtype of the backend of `reference`.

    :param reference:
        Array of any backend.
    """
@dispatch
@abstract()
def get_random_state(key: B.RandomState):
    """
    Extract the current state of the random generator `key`.

    :param key:
        Random generator to read the state from.
    """
@dispatch
@abstract()
def restore_random_state(key: B.RandomState, state):
    """
    Return the random generator `key` with its state set to `state`.

    :param key:
        Random generator whose state should be replaced.
    :param state:
        State to install in the returned random generator.
    """
@dispatch
@abstract()
def create_complex(real: B.Numeric, imag: B.Numeric):
    """
    Build a complex value from its real and imaginary parts.

    :param real:
        Real part; array of any backend.
    :param imag:
        Imaginary part; array of any backend.
    """
@dispatch
@abstract()
def complex_like(reference: B.Numeric):
    """
    Get the `complex` dtype of the backend that `reference` belongs to.

    :param reference:
        Array of any backend.
    """
@dispatch
@abstract()
def is_complex(reference: B.Numeric):
    """
    Check whether `reference` has a `complex` dtype.

    :param reference:
        Array of any backend.
    :return:
        True if `reference` is of `complex` dtype.
    """
@dispatch
@abstract()
def cumsum(a: B.Numeric, axis=None):
    """
    Compute the cumulative sum of `a`, optionally along `axis`.

    :param a:
        Array of any backend.
    :param axis:
        Axis to accumulate along, as in `numpy.cumsum`.
    """
@dispatch
@abstract()
def qr(x: B.Numeric, mode="reduced"):
    """
    Compute the QR decomposition of the matrix `x`.

    :param x:
        Matrix; array of any backend.
    :param mode:
        Decomposition mode, as in `numpy.linalg.qr`. Defaults to "reduced".
    """
@dispatch
@abstract()
def slogdet(x: B.Numeric):
    """
    Compute the sign and the log-determinant of the matrix `x`.

    :param x:
        Matrix; array of any backend.
    """
@dispatch
@abstract()
def eigvalsh(x: B.Numeric):
    """
    Compute the eigenvalues of the Hermitian (or real symmetric) matrix `x`.

    :param x:
        Matrix; array of any backend.
    """
@dispatch
@abstract()
def reciprocal_no_nan(x: Union[B.Numeric, spmatrix]):
    """
    Compute the element-wise reciprocal `1/x`, using 0 wherever `x` is 0.

    :param x:
        Array of any backend or `scipy.sparse.spmatrix`.
    """
@dispatch
@abstract()
def complex_conj(x: B.Numeric):
    """
    Compute the complex conjugate of `x`.

    :param x:
        Array of any backend.
    """
@dispatch
@abstract()
def logical_xor(x1: B.Bool, x2: B.Bool):
    """
    Compute the logical XOR of the two boolean arrays `x1` and `x2`.

    :param x1:
        Boolean array of any backend.
    :param x2:
        Boolean array of any backend.
    """
@dispatch
@abstract()
def count_nonzero(x: B.Numeric, axis=None):
    """
    Count non-zero elements in an array.

    :param x:
        Array of any backend and of any shape.
    :param axis:
        Axis (or axes) along which to count, as in `numpy.count_nonzero`.
        Defaults to `None`.
    """
@dispatch
@abstract()
def dtype_bool(reference: B.RandomState):
    """
    Get the `bool` dtype of the backend that `reference` belongs to.

    :param reference:
        Random state whose backend determines the returned dtype.
    """
@dispatch
@abstract()
def bool_like(reference: B.Numeric):
    """
    Return the dtype of `reference` if it is a boolean dtype;
    otherwise return the `bool` dtype of the backend of `reference`.

    :param reference:
        Array of any backend.
    """
def smart_cast(
    dtype: Union[B.Bool, B.Int, B.Float, B.Complex, B.Numeric], x: B.Numeric
):
    """
    Return `x` cast to the `dtype` abstract data type.

    :param dtype:
        An abstract DType of lab, one of `B.Bool`, `B.Int`, `B.Float`,
        `B.Complex`, `B.Numeric`.
    :param x:
        Array of any backend.

    :return:
        `x` cast with `B.cast` to the dtype chosen by `bool_like`,
        `int_like`, `float_like` or `complex_like` respectively. For
        `B.Numeric`, `x` is returned unchanged.

    :raises ValueError:
        If `dtype` is not one of the abstract DTypes listed above.
    """
    if dtype == B.Bool:
        return B.cast(bool_like(x), x)
    elif dtype == B.Int:
        return B.cast(int_like(x), x)
    elif dtype == B.Float:
        return B.cast(float_like(x), x)
    elif dtype == B.Complex:
        return B.cast(complex_like(x), x)
    elif dtype == B.Numeric:
        # `B.Numeric` is the catch-all abstract type, so no cast is needed.
        # Without this branch the function fell through and returned None.
        return x
    # Fail loudly instead of silently returning None for unsupported input.
    raise ValueError(
        f"smart_cast: unsupported abstract dtype {dtype!r}; expected one of "
        "B.Bool, B.Int, B.Float, B.Complex, B.Numeric."
    )