Coverage for geometric_kernels / lab_extras / numpy / extras.py: 91%
102 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 15:49 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 15:49 +0000
1from typing import TypeAlias
3import lab as B
4import numpy as np
5from beartype.typing import Any
6from lab import dispatch
7from scipy.sparse import spmatrix
# Shorthand for "a `B.Number` or a `B.NPNumeric`"; used to annotate the
# arguments of the dispatched functions defined in this file.
_Numeric: TypeAlias = B.Number | B.NPNumeric
@dispatch
def take_along_axis(a: _Numeric, index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Pick the entries of `a` located at the positions `index` along `axis`.

    :param a:
        Array to gather values from.
    :param index:
        Positions to gather, forwarded to `np.take_along_axis`.
    :param axis:
        Axis along which to gather. Defaults to 0.

    :return:
        The result of `np.take_along_axis(a, index, axis=axis)`.
    """
    gathered = np.take_along_axis(a, index, axis=axis)
    return gathered
@dispatch
def from_numpy(_: B.NPNumeric, b: list | B.NPNumeric | B.Number):  # type: ignore
    """
    Convert `b` to a numpy array.

    :param _:
        Reference value; it is not read in the body and only takes part
        in the dispatch on its type.
    :param b:
        The list, array or number to convert.

    :return:
        `np.array(b)`.
    """
    converted = np.array(b)
    return converted
@dispatch
def trapz(y: _Numeric, x: _Numeric, dx: _Numeric = 1.0, axis: int = -1):  # type: ignore
    """
    Integrate along the given axis using the composite trapezoidal rule.

    :param y:
        Values to integrate.
    :param x:
        Sample points corresponding to the values in `y`.
    :param dx:
        Spacing between sample points, forwarded to numpy. Defaults to 1.0.
    :param axis:
        Axis along which to integrate. Defaults to -1.

    :return:
        The value of the integral as computed by numpy.
    """
    # NumPy 2.0 renamed `trapz` to `trapezoid` and deprecated the old name.
    # Prefer the new name when it exists and fall back to `np.trapz` on
    # older NumPy versions; both share the same signature.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return integrate(y, x=x, dx=dx, axis=axis)
@dispatch
def norm(x: _Numeric, ord: Any | None = None, axis: int | None = None):  # type: ignore
    """
    Compute a matrix or vector norm of `x`.

    :param x:
        Input array.
    :param ord:
        Order of the norm, forwarded to `np.linalg.norm`. Defaults to None.
    :param axis:
        Axis along which to compute the norm, forwarded to
        `np.linalg.norm`. Defaults to None.

    :return:
        The result of `np.linalg.norm(x, ord=ord, axis=axis)`.
    """
    result = np.linalg.norm(x, ord=ord, axis=axis)
    return result
@dispatch
def logspace(start: _Numeric, stop: _Numeric, num: int = 50, base: _Numeric = 50.0):  # type: ignore
    """
    Return numbers spaced evenly on a log scale.

    :param start:
        Exponent of the first value, forwarded to `np.logspace`.
    :param stop:
        Exponent of the last value, forwarded to `np.logspace`.
    :param num:
        Number of samples to generate. Defaults to 50.
    :param base:
        Base of the log scale. Defaults to 50.0.

    :return:
        The result of `np.logspace(start, stop, num, base=base)`.
    """
    # NOTE(review): the default `base` here is 50.0 whereas `np.logspace`
    # itself defaults to 10.0 - confirm that this is intended.
    samples = np.logspace(start, stop, num=num, base=base)
    return samples
@dispatch
def degree(a: _Numeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`: a diagonal
    matrix whose main diagonal holds the column sums of `a`.

    :param a:
        The adjacency matrix.

    :return:
        `np.diag` applied to the sums of `a` over axis 0.
    """
    return np.diag(a.sum(axis=0))  # type: ignore
@dispatch
def dtype_double(reference: B.NPRandomState):  # type: ignore
    """
    Return the `double` dtype of the numpy backend.

    :param reference:
        A numpy random state; it is not read in the body and only takes
        part in the dispatch on its type.

    :return:
        `np.float64`.
    """
    double_dtype = np.float64
    return double_dtype
@dispatch
def float_like(reference: B.NPNumeric):
    """
    Return the dtype of `reference` when it is a floating point dtype,
    and `np.float64` otherwise.

    :param reference:
        Array whose `dtype` attribute is inspected.
    """
    dtype = reference.dtype
    return dtype if np.issubdtype(dtype, np.floating) else np.float64
@dispatch
def dtype_integer(reference: B.NPRandomState):  # type: ignore
    """
    Return the `int` dtype of the numpy backend.

    :param reference:
        A numpy random state; it is not read in the body and only takes
        part in the dispatch on its type.

    :return:
        `np.int32`.
    """
    integer_dtype = np.int32
    return integer_dtype
@dispatch
def int_like(reference: B.NPNumeric):
    """
    Return the dtype of `reference` when it is an integer dtype,
    and `np.int32` otherwise.

    :param reference:
        Array whose `dtype` attribute is inspected.
    """
    dtype = reference.dtype
    return dtype if np.issubdtype(dtype, np.integer) else np.int32
@dispatch
def get_random_state(key: B.NPRandomState):
    """
    Extract the internal state of a random generator.

    :param key:
        The random generator of type `B.NPRandomState`.

    :return:
        Whatever `key.get_state()` returns.
    """
    state = key.get_state()
    return state
@dispatch
def restore_random_state(key: B.NPRandomState, state):
    """
    Create a fresh `np.random.RandomState` and load `state` into it.

    :param key:
        The random generator of type `B.NPRandomState`. It is not read in
        the body; a new generator is created instead of modifying it.
    :param state:
        The new random state of the random generator.

    :return:
        The newly created generator with its state set to `state`.
    """
    restored = np.random.RandomState()
    restored.set_state(state)
    return restored
@dispatch
def create_complex(real: _Numeric, imag: _Numeric):
    """
    Combine a real and an imaginary part into a complex value.

    :param real:
        float, real part of the complex number.
    :param imag:
        float, imaginary part of the complex number.

    :return:
        `real + 1j * imag`.
    """
    return real + 1j * imag
@dispatch
def complex_like(reference: B.NPNumeric):
    """
    Return a `complex` dtype based on the reference.

    :param reference:
        Array whose `dtype` attribute is promoted together with
        `np.complex64`.

    :return:
        The result of `B.promote_dtypes(np.complex64, reference.dtype)`.
    """
    reference_dtype = reference.dtype
    return B.promote_dtypes(np.complex64, reference_dtype)
@dispatch
def is_complex(reference: B.NPNumeric):
    """
    Return True if `reference` is of a `complex` dtype.

    :param reference:
        Array whose dtype, obtained through `B.dtype`, is inspected.

    :return:
        True when the dtype is a complex floating point dtype, False
        otherwise.
    """
    # Test against the abstract `np.complexfloating` type rather than a
    # hard-coded list of widths, so that every complex dtype (not only
    # complex64 / complex128) is recognised. This mirrors the
    # `np.issubdtype` checks used by the other `*_like` helpers.
    return np.issubdtype(B.dtype(reference), np.complexfloating)
@dispatch
def cumsum(a: _Numeric, axis=None):
    """
    Compute the cumulative sum of `a`, optionally along `axis`.

    :param a:
        Input array.
    :param axis:
        Axis along which to accumulate, forwarded to `np.cumsum`.
        Defaults to None.

    :return:
        The result of `np.cumsum(a, axis=axis)`.
    """
    running_total = np.cumsum(a, axis=axis)
    return running_total
@dispatch
def qr(x: _Numeric, mode="reduced"):
    """
    Compute a QR decomposition of the matrix `x`.

    :param x:
        The matrix to decompose.
    :param mode:
        Decomposition mode, forwarded to `np.linalg.qr`. Defaults to
        "reduced".

    :return:
        A tuple of the two factors unpacked from the result of
        `np.linalg.qr(x, mode=mode)`.
    """
    factors = np.linalg.qr(x, mode=mode)
    # Unpack explicitly so that a plain two-element tuple is returned.
    q_factor, r_factor = factors
    return q_factor, r_factor
@dispatch
def slogdet(x: _Numeric):
    """
    Compute the sign and the log-determinant of the matrix `x`.

    :param x:
        The input matrix.

    :return:
        A tuple of the two values unpacked from the result of
        `np.linalg.slogdet(x)`.
    """
    outcome = np.linalg.slogdet(x)
    # Unpack explicitly so that a plain two-element tuple is returned.
    det_sign, log_abs_det = outcome
    return det_sign, log_abs_det
@dispatch
def eigvalsh(x: _Numeric):
    """
    Compute the eigenvalues of a Hermitian or real symmetric matrix `x`.

    :param x:
        The input matrix.

    :return:
        The result of `np.linalg.eigvalsh` called with `UPLO="U"`.
    """
    eigenvalues = np.linalg.eigvalsh(x, UPLO="U")
    return eigenvalues
@dispatch
def reciprocal_no_nan(x: B.NPNumeric):
    """
    Compute the element-wise reciprocal 1/x, with 0 placed wherever x = 0.

    :param x:
        The input array.

    :return:
        An array holding 1/x for non-zero entries and 0 for zero entries.
    """
    zero_mask = np.equal(x, 0.0)
    # Substitute ones for the zeros before inverting, so that the
    # reciprocal is never evaluated at zero.
    inverted = np.reciprocal(np.where(zero_mask, 1.0, x))
    return np.where(zero_mask, 0.0, inverted)
@dispatch  # type: ignore[no-redef]
def reciprocal_no_nan(x: spmatrix):
    """
    Compute the element-wise reciprocal 1/x of a sparse matrix, with 0
    placed wherever x = 0.

    :param x:
        The input sparse matrix.

    :return:
        A matrix produced by `x._with_data(..., copy=True)` from the
        reciprocals of a copy of `x._deduped_data()`.
    """
    # NOTE(review): `_deduped_data` and `_with_data` are underscore-prefixed
    # methods of the matrix; presumably only explicitly stored entries
    # are processed - confirm against the scipy version in use.
    stored_values = x._deduped_data().copy()
    inverted_values = reciprocal_no_nan(stored_values)
    return x._with_data(inverted_values, copy=True)
@dispatch
def complex_conj(x: B.NPNumeric):
    """
    Compute the element-wise complex conjugate of `x`.

    :param x:
        The input array.

    :return:
        The result of numpy's `conjugate` applied to `x`.
    """
    # `np.conj` is numpy's alias of `np.conjugate`.
    return np.conj(x)
@dispatch
def logical_xor(x1: B.NPNumeric, x2: B.NPNumeric):
    """
    Compute the element-wise logical XOR of two arrays.

    :param x1:
        The first input array.
    :param x2:
        The second input array.

    :return:
        The result of `np.logical_xor(x1, x2)`.
    """
    exclusive_or = np.logical_xor(x1, x2)
    return exclusive_or
@dispatch
def count_nonzero(x: B.NPNumeric, axis=None):
    """
    Count the non-zero elements of an array.

    :param x:
        The input array.
    :param axis:
        Axis along which to count, forwarded to `np.count_nonzero`.
        Defaults to None.

    :return:
        The result of `np.count_nonzero(x, axis=axis)`.
    """
    nonzero_total = np.count_nonzero(x, axis=axis)
    return nonzero_total
@dispatch
def dtype_bool(reference: B.NPRandomState):  # type: ignore
    """
    Return the `bool` dtype of the numpy backend.

    :param reference:
        A numpy random state; it is not read in the body and only takes
        part in the dispatch on its type.

    :return:
        `np.bool_`.
    """
    boolean_dtype = np.bool_
    return boolean_dtype
@dispatch
def bool_like(reference: B.NPNumeric):
    """
    Return the dtype of `reference` when it is a boolean dtype,
    and `np.bool_` otherwise.

    :param reference:
        Array whose `dtype` attribute is inspected.
    """
    dtype = reference.dtype
    return dtype if np.issubdtype(dtype, np.bool_) else np.bool_