Coverage for geometric_kernels / lab_extras / jax / extras.py: 92%

103 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 15:49 +0000

1from typing import TypeAlias 

2 

3import jax.numpy as jnp 

4import lab as B 

5from lab import dispatch 

6from plum import convert 

7 

# Annotation alias for arguments that may be either a lab scalar number
# (`B.Number`) or a JAX numeric object (`B.JAXNumeric`).
_Numeric: TypeAlias = B.Number | B.JAXNumeric

9 

10 

@dispatch
def take_along_axis(a: _Numeric | B.Numeric, index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Gather the entries of `a` located at positions `index` along `axis`.

    If `a` is not already a JAX array it is first converted with `jnp.array`,
    so the gather is always performed by `jnp.take_along_axis`.

    :param a: Source values.
    :param index: Positions to pick, in the form accepted by
        `jnp.take_along_axis`.
    :param axis: Dimension to gather along. Defaults to 0.
    :return: The gathered values as a JAX array.
    """
    source = a if isinstance(a, jnp.ndarray) else jnp.array(a)
    return jnp.take_along_axis(source, index, axis=axis)

19 

20 

@dispatch
def from_numpy(_: B.JAXNumeric, b: list | B.NPNumeric | B.Number | B.JAXNumeric):  # type: ignore
    """
    Convert `b` into a JAX array.

    The first argument only selects this JAX implementation through
    dispatch; its value is never read.

    :param b: List, NumPy array, scalar or JAX array to convert.
    :return: `jnp.array(b)`.
    """
    converted = jnp.array(b)
    return converted

27 

28 

@dispatch
def trapz(y: B.JAXNumeric, x: _Numeric, dx: _Numeric = 1.0, axis: int = -1):  # type: ignore
    """
    Integrate `y` along `axis` using the trapezoidal rule.

    All arguments are forwarded to `jnp.trapezoid`. `x` gives the sample
    points. `dx` is passed through unchanged; under `jnp.trapezoid`'s
    semantics it only matters when no sample points are supplied.
    """
    return jnp.trapezoid(y, x=x, dx=dx, axis=axis)

35 

36 

@dispatch
def logspace(start: B.JAXNumeric, stop: _Numeric, num: int = 50):  # type: ignore
    """
    Return `num` values spaced evenly on a log scale between `start` and
    `stop`, using the default base/endpoint of `jnp.logspace`.
    """
    return jnp.logspace(start, stop, num=num)

43 

44 

@dispatch
def degree(a: B.JAXNumeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The result is a diagonal matrix whose main diagonal holds the column
    sums of `a`, i.e. for each node the number of nodes it is connected to.
    """
    return jnp.diag(jnp.sum(a, axis=0))

55 

56 

@dispatch
def eigenpairs(L: B.JAXNumeric, k: int):
    """
    Return the `k` lowest eigenvalues of the symmetric positive
    semi-definite matrix `L`, together with their eigenvectors (as columns).
    """
    eigenvalues, eigenvectors = jnp.linalg.eigh(L)
    # `eigh` returns eigenvalues in ascending order, so the first k are the lowest.
    lowest = slice(None, k)
    return eigenvalues[lowest], eigenvectors[:, lowest]

65 

66 

@dispatch
def set_value(a: B.JAXNumeric, index: int, value: float):
    """
    Return a copy of `a` in which `a[index]` is replaced by `value`.

    JAX arrays are immutable, so the functional `.at[...].set(...)` update
    is used and `a` itself is left untouched.
    """
    return a.at[index].set(value)

75 

76 

@dispatch
def dtype_double(reference: B.JAXRandomState):  # type: ignore
    """
    Return the double-precision float dtype of the JAX backend.

    `reference` only selects the backend through dispatch and is not read.
    """
    double_dtype = jnp.float64
    return double_dtype

83 

84 

@dispatch
def float_like(reference: B.JAXNumeric):
    """
    Return a floating point dtype matching `reference`.

    If `reference` already has a floating dtype, that dtype is returned
    (converted to a JAX dtype). Otherwise the JAX `float64` dtype is returned.
    """
    ref_dtype = reference.dtype
    if not jnp.issubdtype(ref_dtype, jnp.floating):
        return jnp.float64
    # `.dtype` of a JAX array is a NumPy data type; convert it to a JAX one.
    return convert(ref_dtype, B.JAXDType)

98 

99 

@dispatch
def dtype_integer(reference: B.JAXRandomState):  # type: ignore
    """
    Return the integer dtype of the JAX backend (`int32`).

    `reference` only selects the backend through dispatch and is not read.
    """
    integer_dtype = jnp.int32
    return integer_dtype

106 

107 

@dispatch
def int_like(reference: B.JAXNumeric):
    """
    Return an integer dtype matching `reference`.

    If `reference` already has an integer dtype, that dtype is returned
    (converted to a JAX dtype). Otherwise the JAX `int32` dtype is returned.
    """
    ref_dtype = reference.dtype
    if not jnp.issubdtype(ref_dtype, jnp.integer):
        return jnp.int32
    # `.dtype` of a JAX array is a NumPy data type; convert it to a JAX one.
    return convert(ref_dtype, B.JAXDType)

117 

118 

@dispatch
def get_random_state(key: B.JAXRandomState):
    """
    Return the random state of a JAX random generator.

    This implementation treats the key itself as the state and returns it
    unchanged.

    :param key:
        The random generator of type `B.JAXRandomState`.
    """
    state = key
    return state

128 

129 

@dispatch
def restore_random_state(key: B.JAXRandomState, state):
    """
    Return a JAX random generator whose state is `state`.

    Because this implementation treats the key as the state, the restored
    generator is simply `state`; the incoming `key` is not read.

    :param key:
        The random generator of type `B.JAXRandomState`.
    :param state:
        The random state to restore.
    """
    restored = state
    return restored

142 

143 

@dispatch
def create_complex(real: _Numeric, imag: B.JAXNumeric):
    """
    Combine a real and an imaginary part into `real + 1j * imag` using JAX
    arithmetic.

    :param real:
        Real part (a scalar or JAX array).
    :param imag:
        Imaginary part (a JAX array).
    """
    imaginary_part = 1j * imag
    return real + imaginary_part

156 

157 

@dispatch
def complex_like(reference: B.JAXNumeric):
    """
    Return a complex dtype suited to `reference`: the promotion of
    `complex64` with the dtype of `reference`.
    """
    ref_dtype = reference.dtype
    return B.promote_dtypes(jnp.complex64, ref_dtype)

164 

165 

@dispatch
def is_complex(reference: B.JAXNumeric):
    """
    Return True when `reference` has a `complex64` or `complex128` dtype.
    """
    ref_dtype = B.dtype(reference)
    return ref_dtype == jnp.complex64 or ref_dtype == jnp.complex128

174 

175 

@dispatch
def cumsum(x: B.JAXNumeric, axis=None):
    """
    Return the cumulative sum of `x`, taken along `axis` when given
    (otherwise over the flattened array, as in `jnp.cumsum`).
    """
    running_total = jnp.cumsum(x, axis=axis)
    return running_total

182 

183 

@dispatch
def qr(x: B.JAXNumeric, mode="reduced"):
    """
    Return a QR decomposition of the matrix `x`.

    :param x: Matrix to decompose.
    :param mode: Decomposition mode forwarded to `jnp.linalg.qr`
        (defaults to "reduced").
    :return: The tuple `(Q, R)`. With `mode="r"` only `R` is computed by
        `jnp.linalg.qr`, so `R` alone is returned in that case.
    """
    result = jnp.linalg.qr(x, mode=mode)
    if mode == "r":
        # `jnp.linalg.qr` returns a single array here. Unpacking it into
        # (Q, R) would raise or silently split R's rows.
        return result
    # Unpack so callers get a plain tuple rather than JAX's result object.
    Q, R = result
    return Q, R

191 

192 

193@dispatch 

194def slogdet(x: B.JAXNumeric): 

195 """ 

196 Return the sign and log-determinant of a matrix x. 

197 """ 

198 sign, logdet = jnp.linalg.slogdet(x) 

199 return sign, logdet 

200 

201 

@dispatch
def eigvalsh(x: B.JAXNumeric):
    """
    Return the eigenvalues of the Hermitian (or real symmetric) matrix `x`.
    """
    eigenvalues = jnp.linalg.eigvalsh(x)
    return eigenvalues

208 

209 

@dispatch
def reciprocal_no_nan(x: B.JAXNumeric):
    """
    Return the element-wise reciprocal `1/x`, with 0 wherever `x == 0`.
    """
    is_zero = jnp.equal(x, 0.0)
    # Replace zeros by 1 before dividing so no inf/NaN is ever produced
    # (which would also poison gradients through `jnp.where`).
    denominator = jnp.where(is_zero, 1.0, x)
    inverse = jnp.reciprocal(denominator)
    return jnp.where(is_zero, 0.0, inverse)

218 

219 

@dispatch
def complex_conj(x: B.JAXNumeric):
    """
    Return the element-wise complex conjugate of `x`.
    """
    conjugate = jnp.conj(x)
    return conjugate

226 

227 

@dispatch
def logical_xor(x1: B.JAXNumeric, x2: B.JAXNumeric):
    """
    Return the element-wise logical XOR of `x1` and `x2`.
    """
    exclusive_or = jnp.logical_xor(x1, x2)
    return exclusive_or

234 

235 

@dispatch
def count_nonzero(x: B.JAXNumeric, axis=None):
    """
    Return the number of non-zero entries of `x`, optionally per `axis`.
    """
    nonzero_count = jnp.count_nonzero(x, axis=axis)
    return nonzero_count

242 

243 

@dispatch
def dtype_bool(reference: B.JAXRandomState):  # type: ignore
    """
    Return the boolean dtype of the JAX backend.

    `reference` only selects the backend through dispatch and is not read.
    """
    boolean_dtype = jnp.bool_
    return boolean_dtype

250 

251 

@dispatch
def bool_like(reference: B.JAXNumeric | B.JAXRandomState):
    """
    Return a boolean dtype matching `reference`.

    If `reference` already has a boolean dtype, that dtype is returned
    (converted to a JAX dtype). Otherwise the JAX `bool_` dtype is returned.

    The body reads `reference.dtype`, so the argument is an array. It is
    therefore dispatched on `B.JAXNumeric`, like `float_like` and `int_like`.
    `B.JAXRandomState` stays in the union so inputs accepted before still
    dispatch here.
    """
    ref_dtype = reference.dtype
    if not jnp.issubdtype(ref_dtype, jnp.bool_):
        return jnp.bool_
    # `.dtype` of a JAX array is a NumPy data type; convert it to a JAX one.
    return convert(ref_dtype, B.JAXDType)