Coverage for geometric_kernels / lab_extras / jax / extras.py: 92%
103 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 15:49 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 15:49 +0000
1from typing import TypeAlias
3import jax.numpy as jnp
4import lab as B
5from lab import dispatch
6from plum import convert
# Either a plain Python/lab number or a JAX array. Used in several dispatch
# signatures in this module so that scalars are accepted alongside JAX arrays.
_Numeric: TypeAlias = B.Number | B.JAXNumeric
@dispatch
def take_along_axis(a: _Numeric | B.Numeric, index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Gather elements of `a` along `axis` at the positions given by `index`.

    :param a:
        Source values. Anything that is not already a `jnp.ndarray` is
        converted with `jnp.array` before gathering.
    :param index:
        Indices to gather, forwarded to `jnp.take_along_axis`.
    :param axis:
        Axis along which to gather. Defaults to 0.
    :return:
        The gathered values as a JAX array.
    """
    source = a if isinstance(a, jnp.ndarray) else jnp.array(a)
    return jnp.take_along_axis(source, index, axis=axis)
@dispatch
def from_numpy(_: B.JAXNumeric, b: list | B.NPNumeric | B.Number | B.JAXNumeric):  # type: ignore
    """
    Convert `b` to a JAX array.

    The first argument only selects this JAX implementation through
    dispatch; its value is not used.

    :param b:
        A list, NumPy array, number or JAX array to convert.
    :return:
        `b` converted with `jnp.array`.
    """
    converted = jnp.array(b)
    return converted
@dispatch
def trapz(y: B.JAXNumeric, x: _Numeric, dx: _Numeric = 1.0, axis: int = -1):  # type: ignore
    """
    Integrate `y` along `axis` with the trapezoidal rule.

    :param y:
        Values to integrate.
    :param x:
        Sample points corresponding to `y`.
    :param dx:
        Sample spacing, forwarded to `jnp.trapezoid`. Defaults to 1.0.
    :param axis:
        Axis to integrate along. Defaults to the last axis.
    """
    return jnp.trapezoid(y, x=x, dx=dx, axis=axis)
@dispatch
def logspace(start: B.JAXNumeric, stop: _Numeric, num: int = 50):  # type: ignore
    """
    Return `num` numbers evenly spaced on a log scale between `start` and
    `stop` (interpreted as exponents by `jnp.logspace`).

    :param start:
        Starting exponent.
    :param stop:
        Final exponent.
    :param num:
        Number of samples to generate. Defaults to 50.
    """
    return jnp.logspace(start, stop, num=num)
@dispatch
def degree(a: B.JAXNumeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The result is a diagonal matrix whose main diagonal holds the column
    sums of `a`, i.e. how many nodes (or how much edge weight) each node is
    connected to.

    :param a:
        Adjacency matrix.
    :return:
        Diagonal degree matrix.
    """
    return jnp.diag(jnp.sum(a, axis=0))
@dispatch
def eigenpairs(L: B.JAXNumeric, k: int):
    """
    Return the `k` smallest eigenvalues of the symmetric positive
    semi-definite matrix `L` together with their eigenvectors.

    :param L:
        Symmetric matrix to decompose.
    :param k:
        Number of eigenpairs to keep.
    :return:
        Tuple `(eigenvalues, eigenvectors)`; eigenvectors are columns.
    """
    eigenvalues, eigenvectors = jnp.linalg.eigh(L)
    # eigh yields eigenvalues in ascending order, so the first k are the lowest.
    return eigenvalues[:k], eigenvectors[:, :k]
@dispatch
def set_value(a: B.JAXNumeric, index: int, value: float):
    """
    Return a copy of `a` with `a[index]` replaced by `value`.

    The input array is left untouched; a new array is returned.
    """
    # JAX arrays are immutable, so the functional `.at[...].set(...)` update
    # is used instead of item assignment.
    return a.at[index].set(value)
@dispatch
def dtype_double(reference: B.JAXRandomState):  # type: ignore
    """
    Return the JAX double-precision dtype.

    `reference` only selects the JAX backend through dispatch.
    """
    double_dtype = jnp.float64
    return double_dtype
@dispatch
def float_like(reference: B.JAXNumeric):
    """
    Return the dtype of `reference` if it is a floating point dtype;
    otherwise return the JAX double-precision dtype.

    :param reference:
        JAX array whose dtype is inspected.
    """
    dtype = reference.dtype
    if not jnp.issubdtype(dtype, jnp.floating):
        return jnp.float64
    # `.dtype` of a JAX array is a NumPy dtype; convert it to lab's JAX dtype.
    return convert(dtype, B.JAXDType)
@dispatch
def dtype_integer(reference: B.JAXRandomState):  # type: ignore
    """
    Return the default JAX integer dtype (`int32`).

    `reference` only selects the JAX backend through dispatch.
    """
    integer_dtype = jnp.int32
    return integer_dtype
@dispatch
def int_like(reference: B.JAXNumeric):
    """
    Return the dtype of `reference` if it is an integer dtype; otherwise
    return the default JAX integer dtype (`int32`).

    :param reference:
        JAX array whose dtype is inspected.
    """
    dtype = reference.dtype
    if not jnp.issubdtype(dtype, jnp.integer):
        return jnp.int32
    # `.dtype` of a JAX array is a NumPy dtype; convert it to lab's JAX dtype.
    return convert(dtype, B.JAXDType)
@dispatch
def get_random_state(key: B.JAXRandomState):
    """
    Return the random state of a JAX random generator.

    :param key:
        The random generator of type `B.JAXRandomState`.
    :return:
        `key` itself; for JAX the key object is used directly as the state.
    """
    state = key
    return state
@dispatch
def restore_random_state(key: B.JAXRandomState, state):
    """
    Return a JAX random generator whose state is `state`.

    :param key:
        The random generator of type `B.JAXRandomState`. Its value is not
        used; it only selects this implementation through dispatch.
    :param state:
        The random state to restore.
    :return:
        `state` itself, which serves as the new generator.
    """
    restored = state
    return restored
@dispatch
def create_complex(real: _Numeric, imag: B.JAXNumeric):
    """
    Build the complex value `real + 1j * imag` using JAX arithmetic.

    :param real:
        Real part (number or JAX array).
    :param imag:
        Imaginary part (JAX array).
    """
    imaginary_part = 1j * imag
    return real + imaginary_part
@dispatch
def complex_like(reference: B.JAXNumeric):
    """
    Return a complex dtype compatible with `reference`.

    The result is `complex64` promoted together with `reference`'s dtype.
    """
    reference_dtype = reference.dtype
    return B.promote_dtypes(jnp.complex64, reference_dtype)
@dispatch
def is_complex(reference: B.JAXNumeric):
    """
    Return True if `reference` has dtype `complex64` or `complex128`.
    """
    reference_dtype = B.dtype(reference)
    return reference_dtype == jnp.complex64 or reference_dtype == jnp.complex128
@dispatch
def cumsum(x: B.JAXNumeric, axis=None):
    """
    Return the cumulative sum of `x`.

    :param axis:
        Axis along which to accumulate; `None` accumulates over the
        flattened array (as `jnp.cumsum` does).
    """
    running_total = jnp.cumsum(x, axis=axis)
    return running_total
@dispatch
def qr(x: B.JAXNumeric, mode="reduced"):
    """
    Return a QR decomposition of a matrix `x`.

    :param x:
        The matrix to decompose.
    :param mode:
        Forwarded to `jnp.linalg.qr`. Defaults to "reduced".
    :return:
        A tuple `(Q, R)`. For `mode="r"`, `jnp.linalg.qr` computes only `R`,
        and that single array is returned.
    """
    result = jnp.linalg.qr(x, mode=mode)
    if mode == "r":
        # jnp.linalg.qr returns a single array in this mode, not a (Q, R)
        # pair; tuple-unpacking it would split R's rows (or raise).
        return result
    Q, R = result
    return Q, R
@dispatch
def slogdet(x: B.JAXNumeric):
    """
    Return `(sign, logdet)`: the sign and the natural log of the absolute
    value of the determinant of the matrix `x`.
    """
    return tuple(jnp.linalg.slogdet(x))
@dispatch
def eigvalsh(x: B.JAXNumeric):
    """
    Return the eigenvalues of the Hermitian (or real symmetric) matrix `x`.
    """
    eigenvalues = jnp.linalg.eigvalsh(x)
    return eigenvalues
@dispatch
def reciprocal_no_nan(x: B.JAXNumeric):
    """
    Return the element-wise reciprocal `1 / x`, using 0 wherever `x == 0`.
    """
    zero_mask = jnp.equal(x, 0.0)
    # Replace zeros by 1 before inverting so no inf/NaN is produced (this
    # also keeps gradients finite); those positions are overwritten with 0.
    inverted = jnp.reciprocal(jnp.where(zero_mask, 1.0, x))
    return jnp.where(zero_mask, 0.0, inverted)
@dispatch
def complex_conj(x: B.JAXNumeric):
    """
    Return the element-wise complex conjugate of `x`.
    """
    return jnp.conjugate(x)
@dispatch
def logical_xor(x1: B.JAXNumeric, x2: B.JAXNumeric):
    """
    Return the element-wise logical XOR of `x1` and `x2`.
    """
    xor_result = jnp.logical_xor(x1, x2)
    return xor_result
@dispatch
def count_nonzero(x: B.JAXNumeric, axis=None):
    """
    Count the non-zero elements of `x`.

    :param axis:
        Axis along which to count; `None` counts over the whole array.
    """
    nonzero_count = jnp.count_nonzero(x, axis=axis)
    return nonzero_count
@dispatch
def dtype_bool(reference: B.JAXRandomState):  # type: ignore
    """
    Return the JAX boolean dtype.

    `reference` only selects the JAX backend through dispatch.
    """
    boolean_dtype = jnp.bool_
    return boolean_dtype
@dispatch
def bool_like(reference: B.JAXNumeric):
    """
    Return the dtype of `reference` if it is boolean; otherwise return the
    JAX boolean dtype.

    Dispatches on `B.JAXNumeric` (like `float_like`/`int_like`) because the
    body reads `reference.dtype`, i.e. it expects a JAX array.

    :param reference:
        JAX array whose dtype is inspected.
    """
    reference_dtype = reference.dtype
    if jnp.issubdtype(reference_dtype, jnp.bool_):
        # JAX .dtype returns a NumPy data type; convert it to lab's JAX dtype.
        return convert(reference_dtype, B.JAXDType)
    return jnp.bool_