Coverage for geometric_kernels/lab_extras/extras.py: 99%
95 statements; report created by coverage.py v7.13.4 at 2026-02-22 15:49 +0000
1import lab as B
2from lab import dispatch
3from lab.util import abstract
4from scipy.sparse import spmatrix
@dispatch
@abstract()
def take_along_axis(a: B.Numeric, index: B.Numeric, axis: int = 0):
    """
    Pick elements of `a` located at `index` along the dimension `axis`.

    Backend-agnostic counterpart of `numpy.take_along_axis`.

    :param a:
        Source array of any backend, as in `numpy.take_along_axis`.
    :param index:
        Array of any backend holding the positions to gather, as in
        `numpy.take_along_axis`.
    :param axis:
        Dimension to gather along, as in `numpy.take_along_axis`.
    """
@dispatch
@abstract()
def from_numpy(_: B.Numeric, b: list | B.Numeric):
    """
    Turn `b` into an array belonging to the same backend as `_`.

    :param _:
        Array of any backend; it only serves to select the target backend.
    :param b:
        List, or array of any backend, that should be converted.
    """
@dispatch
@abstract()
def trapz(y: B.Numeric, x: B.Numeric, dx: B.Numeric = 1.0, axis: int = -1):
    """
    Integrate along the given axis using the trapezoidal rule.

    Mirrors `numpy.trapz` (renamed `numpy.trapezoid` in NumPy 2.0).

    :param y:
        Array of any backend holding the values to integrate, as in
        `numpy.trapz`.
    :param x:
        Array of any backend holding the sample points of `y`, as in
        `numpy.trapz`. Unlike in NumPy, this argument is required here.
    :param dx:
        Spacing between sample points, as in `numpy.trapz` (a scalar,
        1.0 by default). In NumPy it is only used when `x` is omitted.
        NOTE(review): since `x` is always passed here, confirm whether
        backend implementations use `dx` at all.
    :param axis:
        Axis along which to integrate, as in `numpy.trapz`.
    """
@dispatch
@abstract()
def logspace(start: B.Numeric, stop: B.Numeric, num: int = 50):
    """
    Produce numbers that are evenly spaced on a logarithmic scale.

    Backend-agnostic counterpart of `numpy.logspace`.

    :param start:
        Array of any backend, as in `numpy.logspace`.
    :param stop:
        Array of any backend, as in `numpy.logspace`.
    :param num:
        Number of samples to produce, as in `numpy.logspace`.
    """
def cosh(x: B.Numeric) -> B.Numeric:
    r"""
    Hyperbolic cosine, expressed through exponentials:

    .. math:: \textrm{cosh}(x) = \frac{\exp(x) + \exp(-x)}{2}.

    :param x:
        Array of any backend.
    """
    exp_pos = B.exp(x)
    exp_neg = B.exp(-x)
    return 0.5 * (exp_pos + exp_neg)
def sinh(x: B.Numeric) -> B.Numeric:
    r"""
    Hyperbolic sine, expressed through exponentials:

    .. math:: \textrm{sinh}(x) = \frac{\exp(x) - \exp(-x)}{2}.

    :param x:
        Array of any backend.
    """
    exp_pos = B.exp(x)
    exp_neg = B.exp(-x)
    return 0.5 * (exp_pos - exp_neg)
@dispatch
@abstract()
def degree(a):
    """
    Build the degree matrix of the adjacency matrix `a`.

    The result is a diagonal matrix whose main diagonal holds the column
    sums of `a`. For an unweighted (0/1) adjacency matrix, each diagonal
    entry is the number of nodes that the corresponding node is
    connected to.

    :param a:
        Array of any backend or `scipy.sparse` array.
    """
@dispatch
@abstract()
def eigenpairs(L, k: int):
    """
    Obtain the eigenpairs that correspond to the `k` lowest eigenvalues
    of a symmetric positive semi-definite matrix `L`.

    :param L:
        Array of any backend or `scipy.sparse` array.
    :param k:
        The number of eigenpairs to compute.
    """
@dispatch
@abstract()
def set_value(a, index: int, value: float):
    """
    Return a new array equal to `a` except that position `index` holds
    `value`; `a` itself is not modified in place.

    :param a:
        Array of any backend or `scipy.sparse` array.
    :param index:
        Position to overwrite.
    :param value:
        Value written at position `index`.
    """
@dispatch
@abstract()
def dtype_double(reference: B.RandomState):
    """
    Get the `double` dtype of the backend that `reference` belongs to.

    :param reference:
        Random state whose only role is to identify the backend.
    """
@dispatch
@abstract()
def float_like(reference: B.Numeric):
    """
    Get a floating point dtype that matches `reference`.

    When `reference` is already of a floating point type, its own type is
    returned; otherwise the `double` dtype of its backend is returned.

    :param reference:
        Array of any backend.
    """
@dispatch
@abstract()
def dtype_integer(reference: B.RandomState):
    """
    Get the `int` dtype of the backend that `reference` belongs to.

    :param reference:
        Random state whose only role is to identify the backend.
    """
@dispatch
@abstract()
def int_like(reference: B.Numeric):
    """
    Get an integer dtype that matches `reference`.

    When `reference` is already of an integer type, its own type is
    returned; otherwise the `int32` dtype of its backend is returned.

    :param reference:
        Array of any backend.
    """
@dispatch
@abstract()
def get_random_state(key: B.RandomState):
    """
    Read out the internal state of the random generator `key`.

    :param key:
        The random generator to inspect.
    """
@dispatch
@abstract()
def restore_random_state(key: B.RandomState, state):
    """
    Put the random generator `key` into the state `state` and return the
    resulting generator.

    :param key:
        The random generator to restore.
    :param state:
        State that the returned generator should have.
    """
@dispatch
@abstract()
def create_complex(real: B.Numeric, imag: B.Numeric):
    """
    Assemble a complex number from its real and imaginary parts.

    :param real:
        Array of any backend holding the real part.
    :param imag:
        Array of any backend holding the imaginary part.
    """
@dispatch
@abstract()
def complex_like(reference: B.Numeric):
    """
    Get the `complex` dtype of the backend that `reference` belongs to.

    :param reference:
        Array of any backend.
    """
@dispatch
@abstract()
def is_complex(reference: B.Numeric):
    """
    Return True if `reference` has a `complex` dtype.

    :param reference:
        Array of any backend.
    """
@dispatch
@abstract()
def cumsum(a: B.Numeric, axis=None):
    """
    Compute the running (cumulative) sum of `a`, optionally along `axis`.

    :param a:
        Array of any backend.
    :param axis:
        Axis to accumulate along, as in `numpy.cumsum`.
    """
@dispatch
@abstract()
def qr(x: B.Numeric, mode="reduced"):
    """
    Factorize the matrix `x` via a QR decomposition.

    :param x:
        Array of any backend.
    :param mode:
        Decomposition mode, as in `numpy.linalg.qr`; "reduced" by default.
    """
@dispatch
@abstract()
def slogdet(x: B.Numeric):
    """
    Compute the sign of the determinant of the matrix `x` together with
    its log-determinant.

    :param x:
        Array of any backend.
    """
@dispatch
@abstract()
def eigvalsh(x: B.Numeric):
    """
    Compute the eigenvalues of `x`, which must be Hermitian (or, in the
    real case, symmetric).

    :param x:
        Array of any backend.
    """
@dispatch
@abstract()
def reciprocal_no_nan(x: B.Numeric | spmatrix):
    """
    Compute the element-wise reciprocal 1/x, mapping entries where x = 0
    to 0 instead of to 1/0.

    :param x:
        Array of any backend or `scipy.sparse.spmatrix`.
    """
@dispatch
@abstract()
def complex_conj(x: B.Numeric):
    """
    Compute the complex conjugate of `x`.

    :param x:
        Array of any backend.
    """
@dispatch
@abstract()
def logical_xor(x1: B.Bool, x2: B.Bool):
    """
    Compute the logical exclusive OR (XOR) of `x1` and `x2`.

    :param x1:
        Boolean array of any backend.
    :param x2:
        Boolean array of any backend.
    """
@dispatch
@abstract()
def count_nonzero(x: B.Numeric, axis=None):
    """
    Count non-zero elements in an array.

    :param x:
        Array of any backend and of any shape.
    :param axis:
        As in `numpy.count_nonzero`: the axis (or axes) along which to
        count. The default, None, counts over the whole array.
    """
@dispatch
@abstract()
def dtype_bool(reference: B.RandomState):
    """
    Get the `bool` dtype of the backend that `reference` belongs to.

    :param reference:
        Random state whose only role is to identify the backend.
    """
@dispatch
@abstract()
def bool_like(reference: B.Numeric):
    """
    Get a boolean dtype that matches `reference`.

    When `reference` is already of a boolean type, its own type is
    returned; otherwise the `bool` dtype of its backend is returned.

    :param reference:
        Array of any backend.
    """
358def smart_cast(dtype: B.Bool | B.Int | B.Float | B.Complex | B.Numeric, x: B.Numeric):
359 """
360 Return `x` cast to the `dtype` abstract data type.
362 :param dtype:
363 An abstract DType of lab, one of `B.Bool`, `B.Int`, `B.Float`,
364 `B.Complex`, `B.Numeric`.
365 :param x:
366 Array of any backend.
367 """
368 if dtype == B.Bool:
369 return B.cast(bool_like(x), x)
370 elif dtype == B.Int:
371 return B.cast(int_like(x), x)
372 elif dtype == B.Float:
373 return B.cast(float_like(x), x)
374 elif dtype == B.Complex:
375 return B.cast(complex_like(x), x)