Coverage for geometric_kernels/lab_extras/numpy/extras.py: 92%
102 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1import lab as B
2import numpy as np
3from beartype.typing import Any, List, Optional
4from lab import dispatch
5from plum import Union
6from scipy.sparse import spmatrix
# Argument type accepted by most NumPy overloads below: either a plain
# number (`B.Number`) or one of lab's NumPy numeric types (`B.NPNumeric`).
_Numeric = Union[B.Number, B.NPNumeric]
@dispatch
def take_along_axis(a: _Numeric, index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Pick values out of `a` along `axis` at the positions given by `index`.

    :param a:
        Source array.
    :param index:
        Integer indices, broadcast-compatible with `a` as required by
        `np.take_along_axis`.
    :param axis:
        Axis along which to gather. Defaults to 0.

    :return:
        The gathered values.
    """
    gathered = np.take_along_axis(a, index, axis=axis)
    return gathered
@dispatch
def from_numpy(_: B.NPNumeric, b: Union[List, B.NPNumeric, B.Number]):  # type: ignore
    """
    Convert `b` into an array of the same backend as the first argument.

    For the NumPy backend the first argument only selects this overload;
    its value is never read.

    :param _:
        Reference array used for dispatch.
    :param b:
        A list, NumPy array or number to convert.

    :return:
        `np.array(b)`.
    """
    converted = np.array(b)
    return converted
@dispatch
def trapz(y: _Numeric, x: _Numeric, dx: _Numeric = 1.0, axis: int = -1):  # type: ignore
    """
    Integrate along the given axis using the composite trapezoidal rule.

    :param y:
        Values of the integrand.
    :param x:
        Sample points corresponding to `y`.
    :param dx:
        Spacing between sample points; NumPy only uses it when `x` is None.
    :param axis:
        Axis along which to integrate. Defaults to the last axis.

    :return:
        The approximated definite integral.
    """
    # `np.trapz` is deprecated since NumPy 2.0 in favour of `np.trapezoid`.
    # Prefer the new name and fall back to the old one on older NumPy
    # versions that do not provide `trapezoid`.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz
    return trapezoid(y, x=x, dx=dx, axis=axis)
@dispatch
def norm(x: _Numeric, ord: Optional[Any] = None, axis: Optional[int] = None):  # type: ignore
    """
    Compute a matrix or vector norm via `np.linalg.norm`.

    :param x:
        Input array.
    :param ord:
        Order of the norm, forwarded unchanged to NumPy.
    :param axis:
        Axis along which to compute the norm, or None.

    :return:
        The norm value(s).
    """
    result = np.linalg.norm(x, ord, axis)
    return result
@dispatch
def logspace(start: _Numeric, stop: _Numeric, num: int = 50, base: _Numeric = 50.0):  # type: ignore
    """
    Return `num` numbers spaced evenly on a log scale, from
    `base ** start` to `base ** stop`.

    :param start:
        Exponent of the first value.
    :param stop:
        Exponent of the last value.
    :param num:
        Number of samples. Defaults to 50.
    :param base:
        Base of the logarithm. Defaults to 50.0 here.
        NOTE(review): NumPy's own default base is 10.0; confirm that 50.0
        is intentional and matches the other backends.

    :return:
        Array of the sampled values.
    """
    samples = np.logspace(start, stop, num=num, base=base)
    return samples
@dispatch
def degree(a: _Numeric):  # type: ignore
    """
    Build the degree matrix of an adjacency matrix `a`.

    The result is a diagonal matrix whose main diagonal holds the column
    sums of `a`. For an unweighted graph this is the number of neighbours
    of each node.

    :param a:
        Adjacency matrix.

    :return:
        Diagonal degree matrix.
    """
    column_sums = a.sum(axis=0)  # type: ignore
    degree_matrix = np.diag(column_sums)
    return degree_matrix
@dispatch
def dtype_double(reference: B.NPRandomState):  # type: ignore
    """
    Give the NumPy backend's double-precision float dtype.

    :param reference:
        NumPy random state; used only to select this overload.

    :return:
        `np.float64`.
    """
    double_dtype = np.float64
    return double_dtype
@dispatch
def float_like(reference: B.NPNumeric):
    """
    Pick a floating point dtype that matches `reference`.

    :param reference:
        NumPy array whose dtype is inspected.

    :return:
        `reference.dtype` when it is a floating point dtype, otherwise
        `np.float64`.
    """
    ref_dtype = reference.dtype
    return ref_dtype if np.issubdtype(ref_dtype, np.floating) else np.float64
@dispatch
def dtype_integer(reference: B.NPRandomState):  # type: ignore
    """
    Give the NumPy backend's default integer dtype.

    :param reference:
        NumPy random state; used only to select this overload.

    :return:
        `np.int32`.
    """
    integer_dtype = np.int32
    return integer_dtype
@dispatch
def int_like(reference: B.NPNumeric):
    """
    Pick an integer dtype that matches `reference`.

    :param reference:
        NumPy array whose dtype is inspected.

    :return:
        `reference.dtype` when it is an integer dtype, otherwise `np.int32`.
    """
    ref_dtype = reference.dtype
    if not np.issubdtype(ref_dtype, np.integer):
        return np.int32
    return ref_dtype
@dispatch
def get_random_state(key: B.NPRandomState):
    """
    Snapshot the internal state of a NumPy random generator.

    :param key:
        The random generator of type `B.NPRandomState`.

    :return:
        Whatever `key.get_state()` returns.
    """
    snapshot = key.get_state()
    return snapshot
@dispatch
def restore_random_state(key: B.NPRandomState, state):
    """
    Build a new random generator whose state is `state`.

    `key` itself is left untouched; a fresh `np.random.RandomState` is
    created and its state is overwritten with `state`.

    :param key:
        The random generator of type `B.NPRandomState` (selects this
        overload; not modified).
    :param state:
        The random state to install, as produced by `get_state()`.

    :return:
        The new generator.
    """
    restored = np.random.RandomState()
    restored.set_state(state)
    return restored
@dispatch
def create_complex(real: _Numeric, imag: _Numeric):
    """
    Combine real and imaginary parts into a complex number or array.

    :param real:
        Real part.
    :param imag:
        Imaginary part.

    :return:
        `real + 1j * imag`.
    """
    imaginary_component = 1j * imag
    return real + imaginary_component
@dispatch
def complex_like(reference: B.NPNumeric):
    """
    Pick a complex dtype compatible with `reference`.

    :param reference:
        NumPy array whose dtype is promoted together with `np.complex64`.

    :return:
        The dtype produced by `B.promote_dtypes(np.complex64, reference.dtype)`.
    """
    ref_dtype = reference.dtype
    return B.promote_dtypes(np.complex64, ref_dtype)
@dispatch
def is_complex(reference: B.NPNumeric):
    """
    Return True if `reference` has a complex floating point dtype.

    Any NumPy complex dtype counts, including complex64, complex128 and
    extended-precision complex types where the platform provides them.

    :param reference:
        NumPy array to inspect.

    :return:
        A Python bool.
    """
    # Test against the abstract `complexfloating` type so that no complex
    # width is missed; this mirrors the `issubdtype` checks in `float_like`,
    # `int_like` and `bool_like`.
    return np.issubdtype(B.dtype(reference), np.complexfloating)
@dispatch
def cumsum(a: _Numeric, axis=None):
    """
    Compute the cumulative sum of `a`.

    :param a:
        Input number or array.
    :param axis:
        Axis along which to accumulate; None flattens the input first.

    :return:
        The running sums.
    """
    running_total = np.cumsum(a, axis)
    return running_total
@dispatch
def qr(x: _Numeric, mode="reduced"):
    """
    Compute the QR decomposition of the matrix `x`.

    :param x:
        Matrix to factorise.
    :param mode:
        Factorisation mode forwarded to `np.linalg.qr`. Defaults to
        "reduced".

    :return:
        A tuple `(Q, R)`.
    """
    decomposition = np.linalg.qr(x, mode=mode)
    q_factor, r_factor = decomposition
    return q_factor, r_factor
@dispatch
def slogdet(x: _Numeric):
    """
    Compute the sign and natural log of the absolute determinant of `x`.

    :param x:
        Square matrix (or stack of matrices).

    :return:
        A tuple `(sign, logdet)`.
    """
    result = np.linalg.slogdet(x)
    determinant_sign, log_abs_det = result
    return determinant_sign, log_abs_det
@dispatch
def eigvalsh(x: _Numeric):
    """
    Compute the eigenvalues of a Hermitian or real symmetric matrix `x`.

    Only the upper triangle of `x` is used (`UPLO="U"`).

    :param x:
        Hermitian or real symmetric matrix.

    :return:
        The eigenvalues, as returned by `np.linalg.eigvalsh`.
    """
    eigenvalues = np.linalg.eigvalsh(x, "U")
    return eigenvalues
@dispatch
def reciprocal_no_nan(x: B.NPNumeric):
    """
    Take the element-wise reciprocal `1 / x`, mapping zeros to zero.

    :param x:
        NumPy array.

    :return:
        Array with `1 / x` where `x != 0` and `0` where `x == 0`.
    """
    zero_mask = np.equal(x, 0.0)
    # Substitute 1.0 at zero positions so np.reciprocal never divides by zero.
    denominator = np.where(zero_mask, 1.0, x)
    inverted = np.reciprocal(denominator)
    return np.where(zero_mask, 0.0, inverted)
@dispatch  # type: ignore[no-redef]
def reciprocal_no_nan(x: spmatrix):
    """
    Take the element-wise reciprocal of the stored entries of a sparse
    matrix, mapping zeros to zero.

    :param x:
        SciPy sparse matrix.

    :return:
        A new sparse matrix with the same sparsity structure whose stored
        data are the reciprocals of `x`'s stored data.
    """
    # NOTE(review): `_deduped_data` / `_with_data` are private SciPy helpers;
    # `_deduped_data` presumably merges duplicate entries (possibly in place
    # on `x`) before returning the data array. Confirm on SciPy upgrades.
    stored_values = x._deduped_data().copy()
    inverted_values = reciprocal_no_nan(stored_values)
    return x._with_data(inverted_values, copy=True)
@dispatch
def complex_conj(x: B.NPNumeric):
    """
    Compute the element-wise complex conjugate of `x`.

    :param x:
        NumPy array.

    :return:
        The conjugated array.
    """
    # `np.conj` is NumPy's alias for `np.conjugate`.
    conjugated = np.conj(x)
    return conjugated
@dispatch
def logical_xor(x1: B.NPNumeric, x2: B.NPNumeric):
    """
    Compute the element-wise logical XOR of two arrays.

    :param x1:
        First NumPy array.
    :param x2:
        Second NumPy array.

    :return:
        Boolean array, True where exactly one input is truthy.
    """
    exclusive = np.logical_xor(x1, x2)
    return exclusive
@dispatch
def count_nonzero(x: B.NPNumeric, axis=None):
    """
    Count the non-zero elements of `x`.

    :param x:
        NumPy array.
    :param axis:
        Axis (or axes) to count along; None counts over the whole array.

    :return:
        The number of non-zero elements, per `np.count_nonzero`.
    """
    tally = np.count_nonzero(x, axis)
    return tally
@dispatch
def dtype_bool(reference: B.NPRandomState):  # type: ignore
    """
    Give the NumPy backend's boolean dtype.

    :param reference:
        NumPy random state; used only to select this overload.

    :return:
        `np.bool_`.
    """
    boolean_dtype = np.bool_
    return boolean_dtype
@dispatch
def bool_like(reference: B.NPNumeric):
    """
    Pick a boolean dtype that matches `reference`.

    :param reference:
        NumPy array whose dtype is inspected.

    :return:
        `reference.dtype` when it is a boolean dtype, otherwise `np.bool_`.
    """
    ref_dtype = reference.dtype
    return ref_dtype if np.issubdtype(ref_dtype, np.bool_) else np.bool_