Coverage for geometric_kernels/lab_extras/jax/extras.py: 93%
103 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1import jax.numpy as jnp
2import lab as B
3from beartype.typing import List
4from lab import dispatch
5from plum import Union, convert
# Type alias accepted by several signatures below (e.g. `take_along_axis`,
# `trapz`, `logspace`, `create_complex`): either a lab `B.Number` (presumably
# plain Python/NumPy scalars -- confirm against lab) or a JAX array.
_Numeric = Union[B.Number, B.JAXNumeric]
@dispatch
def take_along_axis(a: Union[_Numeric, B.Numeric], index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Gather elements of `a` along `axis` at the positions given by `index`.

    :param a:
        Source values. Anything that is not already a JAX array (for example
        a NumPy array or a Python scalar) is first converted with `jnp.array`.
    :param index:
        Indices to pick along `axis`, forwarded to `jnp.take_along_axis`.
    :param axis:
        Axis along which to gather. Defaults to 0.

    :return:
        The gathered values, as returned by `jnp.take_along_axis`.
    """
    source = a if isinstance(a, jnp.ndarray) else jnp.array(a)
    return jnp.take_along_axis(source, index, axis=axis)
@dispatch
def from_numpy(_: B.JAXNumeric, b: Union[List, B.NPNumeric, B.Number, B.JAXNumeric]):  # type: ignore
    """
    Convert `b` to an array of the JAX backend.

    The first argument is used only to select this JAX implementation
    through dispatch; its value is ignored.

    :param b:
        A list, NumPy array, number or JAX array to convert.

    :return:
        `jnp.array(b)`.
    """
    converted = jnp.array(b)
    return converted
@dispatch
def trapz(y: B.JAXNumeric, x: _Numeric, dx: _Numeric = 1.0, axis: int = -1):  # type: ignore
    """
    Integrate `y` along `axis` with the trapezoidal rule.

    :param y:
        Values to integrate.
    :param x:
        Sample points corresponding to `y`.
    :param dx:
        Sample spacing, forwarded to `jnp.trapezoid` (which only consults it
        when `x` is None).
    :param axis:
        Axis to integrate along. Defaults to the last axis.

    :return:
        The result of `jnp.trapezoid`.
    """
    return jnp.trapezoid(y, x=x, dx=dx, axis=axis)
@dispatch
def logspace(start: B.JAXNumeric, stop: _Numeric, num: int = 50):  # type: ignore
    """
    Produce `num` values spaced evenly on a logarithmic scale.

    Thin wrapper around `jnp.logspace`; `start` and `stop` are exponents,
    with `jnp.logspace`'s default base.

    :param start:
        Exponent of the first value.
    :param stop:
        Exponent of the last value.
    :param num:
        Number of values to generate. Defaults to 50.
    """
    return jnp.logspace(start, stop, num=num)
@dispatch
def degree(a: B.JAXNumeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The result is a diagonal matrix whose i-th diagonal entry is the sum of
    column i of `a`. For a 0/1 adjacency matrix this is the number of nodes
    that node i is connected to.
    """
    column_sums = jnp.sum(a, axis=0)
    return jnp.diag(column_sums)
@dispatch
def eigenpairs(L: B.JAXNumeric, k: int):
    """
    Obtain the eigenpairs that correspond to the `k` lowest eigenvalues
    of a symmetric positive semi-definite matrix `L`.

    Stacks of matrices with shape (..., n, n) are supported as well; the
    selection is applied independently to every matrix in the stack.

    :param L:
        Symmetric (Hermitian) matrix of shape (n, n), or a stack of such
        matrices of shape (..., n, n).
    :param k:
        Number of eigenpairs to keep. If `k > n`, all `n` are returned.

    :return:
        A tuple `(eigenvalues, eigenvectors)` of shapes (..., k) and
        (..., n, k); the i-th column of `eigenvectors` corresponds to the
        i-th entry of `eigenvalues`.
    """
    eigenvalues, eigenvectors = jnp.linalg.eigh(L)
    # `eigh` returns eigenvalues in ascending order, so the lowest `k` come
    # first. Slice the trailing axes so batched inputs are handled correctly;
    # indexing leading axes would slice the batch dimension instead.
    return eigenvalues[..., :k], eigenvectors[..., :k]
@dispatch
def set_value(a: B.JAXNumeric, index: int, value: float):
    """
    Return a copy of `a` in which position `index` holds `value`.

    JAX arrays are immutable, so `a` itself is not modified; the update is
    expressed through the functional `.at[...].set(...)` API.
    """
    return a.at[index].set(value)
@dispatch
def dtype_double(reference: B.JAXRandomState):  # type: ignore
    """
    Give the double-precision floating point dtype of the JAX backend.

    :param reference:
        A JAX random state; only used to select this backend via dispatch.

    :return:
        `jnp.float64`.
    """
    double_dtype = jnp.float64
    return double_dtype
@dispatch
def float_like(reference: B.JAXNumeric):
    """
    Choose a floating point dtype based on `reference`.

    :param reference:
        A JAX array.

    :return:
        The dtype of `reference` (converted to a lab JAX dtype) when it is a
        floating point dtype; otherwise `jnp.float64`.
    """
    ref_dtype = reference.dtype
    if not jnp.issubdtype(ref_dtype, jnp.floating):
        return jnp.float64
    # A JAX array's `.dtype` is a NumPy dtype; convert it to the JAX dtype type.
    return convert(ref_dtype, B.JAXDType)
@dispatch
def dtype_integer(reference: B.JAXRandomState):  # type: ignore
    """
    Give the default integer dtype of the JAX backend.

    :param reference:
        A JAX random state; only used to select this backend via dispatch.

    :return:
        `jnp.int32`.
    """
    integer_dtype = jnp.int32
    return integer_dtype
@dispatch
def int_like(reference: B.JAXNumeric):
    """
    Choose an integer dtype based on `reference`.

    :param reference:
        A JAX array.

    :return:
        The dtype of `reference` (converted to a lab JAX dtype) when it is an
        integer dtype; otherwise `jnp.int32`.
    """
    ref_dtype = reference.dtype
    if not jnp.issubdtype(ref_dtype, jnp.integer):
        return jnp.int32
    # A JAX array's `.dtype` is a NumPy dtype; convert it to the JAX dtype type.
    return convert(ref_dtype, B.JAXDType)
@dispatch
def get_random_state(key: B.JAXRandomState):
    """
    Extract the random state from a JAX random generator.

    For JAX the key serves as its own state, so it is returned unchanged.

    :param key:
        The random generator of type `B.JAXRandomState`.

    :return:
        `key` itself.
    """
    state = key
    return state
@dispatch
def restore_random_state(key: B.JAXRandomState, state):
    """
    Produce a JAX random generator whose state is `state`.

    Because a JAX key is its own state, the incoming `key` is not used and
    `state` is returned as the new generator.

    :param key:
        The random generator of type `B.JAXRandomState` (ignored).
    :param state:
        The random state to restore.

    :return:
        `state`.
    """
    restored = state
    return restored
@dispatch
def create_complex(real: _Numeric, imag: B.JAXNumeric):
    """
    Combine a real and an imaginary part into a complex JAX value.

    :param real:
        Real part: a number or a JAX array.
    :param imag:
        Imaginary part: a JAX array.

    :return:
        `real + 1j * imag`.
    """
    return real + 1j * imag
@dispatch
def complex_like(reference: B.JAXNumeric):
    """
    Choose a complex dtype suitable for `reference`.

    :param reference:
        A JAX array.

    :return:
        The dtype obtained by promoting `jnp.complex64` with the dtype of
        `reference` via `B.promote_dtypes`.
    """
    ref_dtype = reference.dtype
    return B.promote_dtypes(jnp.complex64, ref_dtype)
@dispatch
def is_complex(reference: B.JAXNumeric):
    """
    Check whether `reference` has a complex dtype.

    :param reference:
        A JAX array.

    :return:
        True if the dtype of `reference` is `complex64` or `complex128`.
    """
    ref_dtype = B.dtype(reference)
    return ref_dtype == jnp.complex64 or ref_dtype == jnp.complex128
@dispatch
def cumsum(x: B.JAXNumeric, axis=None):
    """
    Compute the cumulative sum of `x`.

    :param x:
        Input JAX array.
    :param axis:
        Axis to accumulate along. With the default None, `jnp.cumsum`
        operates on the flattened array.
    """
    running_total = jnp.cumsum(x, axis=axis)
    return running_total
@dispatch
def qr(x: B.JAXNumeric, mode="reduced"):
    """
    Return a QR decomposition of a matrix x.

    :param x:
        Matrix (or stack of matrices) to decompose.
    :param mode:
        Decomposition mode, forwarded to `jnp.linalg.qr`. Defaults to
        "reduced".

    :return:
        A plain tuple `(Q, R)` for modes that produce two factors. For
        `mode="r"`, only the `R` array is returned, matching
        `jnp.linalg.qr`.
    """
    result = jnp.linalg.qr(x, mode=mode)
    if mode == "r":
        # In "r" mode JAX returns a single array (R). Unpacking it as a pair
        # would either raise or silently split R along its first axis.
        return result
    # Unpack to return a plain tuple rather than JAX's named result type.
    Q, R = result
    return Q, R
@dispatch
def slogdet(x: B.JAXNumeric):
    """
    Compute the sign and log-determinant of the matrix `x`.

    :return:
        A plain tuple `(sign, logabsdet)`, unpacked from the result of
        `jnp.linalg.slogdet`.
    """
    det_sign, log_abs_det = jnp.linalg.slogdet(x)
    return det_sign, log_abs_det
@dispatch
def eigvalsh(x: B.JAXNumeric):
    """
    Return the eigenvalues of the Hermitian (or real symmetric) matrix `x`.

    Thin wrapper around `jnp.linalg.eigvalsh`.
    """
    eigenvalues = jnp.linalg.eigvalsh(x)
    return eigenvalues
@dispatch
def reciprocal_no_nan(x: B.JAXNumeric):
    """
    Element-wise reciprocal 1/x, with 0 placed wherever x == 0.

    Zero entries are first replaced by 1 so that `jnp.reciprocal` never sees
    a zero, meaning no inf is produced even in the branch that is discarded.
    Those positions are then overwritten with 0.
    """
    zero_mask = jnp.equal(x, 0.0)
    denominator = jnp.where(zero_mask, 1.0, x)
    inverted = jnp.reciprocal(denominator)
    return jnp.where(zero_mask, 0.0, inverted)
@dispatch
def complex_conj(x: B.JAXNumeric):
    """
    Return the element-wise complex conjugate of `x` (via `jnp.conj`).
    """
    conjugate = jnp.conj(x)
    return conjugate
@dispatch
def logical_xor(x1: B.JAXNumeric, x2: B.JAXNumeric):
    """
    Compute the element-wise logical exclusive-or of `x1` and `x2`.
    """
    exclusive = jnp.logical_xor(x1, x2)
    return exclusive
@dispatch
def count_nonzero(x: B.JAXNumeric, axis=None):
    """
    Count the non-zero entries of `x`.

    :param axis:
        Axis (or axes) to count along; with None, the whole array is counted.
    """
    n_nonzero = jnp.count_nonzero(x, axis=axis)
    return n_nonzero
@dispatch
def dtype_bool(reference: B.JAXRandomState):  # type: ignore
    """
    Give the boolean dtype of the JAX backend.

    :param reference:
        A JAX random state; only used to select this backend via dispatch.

    :return:
        `jnp.bool_`.
    """
    boolean_dtype = jnp.bool_
    return boolean_dtype
@dispatch
def bool_like(reference: Union[B.JAXNumeric, B.JAXRandomState]):
    """
    Return the type of the reference if it is of boolean type.
    Otherwise return `bool` dtype of a backend based on the reference.

    :param reference:
        A JAX array. The accepted types are `B.JAXNumeric`, matching
        `float_like`/`int_like`, plus the previously accepted
        `B.JAXRandomState` for backward compatibility.

    :return:
        The dtype of `reference` (converted to a lab JAX dtype) when it is
        boolean; otherwise `jnp.bool_`.
    """
    reference_dtype = reference.dtype
    if not jnp.issubdtype(reference_dtype, jnp.bool_):
        return jnp.bool_
    # A JAX array's `.dtype` is a NumPy dtype; convert it to the JAX dtype type.
    return convert(reference_dtype, B.JAXDType)