Coverage for geometric_kernels/lab_extras/jax/extras.py: 93%

103 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import jax.numpy as jnp 

2import lab as B 

3from beartype.typing import List 

4from lab import dispatch 

5from plum import Union, convert 

6 

# Value types accepted by the JAX overloads in this module: plain scalars
# (`B.Number`) or JAX arrays (`B.JAXNumeric`). Subscripting with an explicit
# tuple is equivalent to listing the members.
_Numeric = Union[(B.Number, B.JAXNumeric)]

8 

9 

@dispatch
def take_along_axis(a: Union[_Numeric, B.Numeric], index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Gather entries of `a` along dimension `axis` at the positions in `index`.

    Inputs that are not already `jnp.ndarray` instances are first converted
    with `jnp.array`, then `jnp.take_along_axis` does the gathering.
    """
    source = a if isinstance(a, jnp.ndarray) else jnp.array(a)
    return jnp.take_along_axis(source, index, axis=axis)

18 

19 

@dispatch
def from_numpy(_: B.JAXNumeric, b: Union[List, B.NPNumeric, B.Number, B.JAXNumeric]):  # type: ignore
    """
    Convert `b` into a JAX array.

    The first (unnamed) argument is only used to select this JAX overload
    during dispatch; its value is never read.
    """
    converted = jnp.array(b)
    return converted

26 

27 

@dispatch
def trapz(y: B.JAXNumeric, x: _Numeric, dx: _Numeric = 1.0, axis: int = -1):  # type: ignore
    """
    Integrate `y` along `axis` with the trapezoidal rule.

    Forwards `x`, `dx` and `axis` unchanged to `jnp.trapezoid`.
    """
    return jnp.trapezoid(y, x=x, dx=dx, axis=axis)

34 

35 

@dispatch
def logspace(start: B.JAXNumeric, stop: _Numeric, num: int = 50):  # type: ignore
    """
    Produce `num` values evenly spaced on a logarithmic scale between the
    exponents `start` and `stop`, using `jnp.logspace`.
    """
    return jnp.logspace(start, stop, num=num)

42 

43 

@dispatch
def degree(a: B.JAXNumeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The result is a diagonal matrix whose main diagonal holds the column sums
    of `a` (for a 0/1 adjacency matrix, the number of neighbours of each node).
    """
    column_sums = a.sum(axis=0)  # type: ignore
    degree_matrix = jnp.diag(column_sums)
    return degree_matrix

54 

55 

56@dispatch 

57def eigenpairs(L: B.JAXNumeric, k: int): 

58 """ 

59 Obtain the eigenpairs that correspond to the `k` lowest eigenvalues 

60 of a symmetric positive semi-definite matrix `L`. 

61 """ 

62 l, u = jnp.linalg.eigh(L) 

63 return l[:k], u[:, :k] 

64 

65 

@dispatch
def set_value(a: B.JAXNumeric, index: int, value: float):
    """
    Return a copy of `a` with position `index` replaced by `value`.

    JAX arrays are immutable, so the functional `.at[...].set(...)` update is
    used; the input array itself is left untouched.
    """
    return a.at[index].set(value)

74 

75 

@dispatch
def dtype_double(reference: B.JAXRandomState):  # type: ignore
    """
    Give the JAX double-precision dtype; `reference` only selects the backend.
    """
    double_dtype = jnp.float64
    return double_dtype

82 

83 

@dispatch
def float_like(reference: B.JAXNumeric):
    """
    Pick a floating-point dtype matching `reference`.

    If `reference` already has a floating dtype, that dtype is returned
    (converted to a JAX dtype); otherwise JAX's `float64` is returned.
    """
    ref_dtype = reference.dtype
    if not jnp.issubdtype(ref_dtype, jnp.floating):
        return jnp.float64
    # `.dtype` on a JAX array is a NumPy dtype; convert it to a JAX one.
    return convert(ref_dtype, B.JAXDType)

97 

98 

@dispatch
def dtype_integer(reference: B.JAXRandomState):  # type: ignore
    """
    Give the JAX integer dtype (`int32`); `reference` only selects the backend.
    """
    integer_dtype = jnp.int32
    return integer_dtype

105 

106 

@dispatch
def int_like(reference: B.JAXNumeric):
    """
    Pick an integer dtype matching `reference`.

    If `reference` already has an integer dtype, that dtype is returned
    (converted to a JAX dtype); otherwise JAX's `int32` is returned.
    """
    ref_dtype = reference.dtype
    if not jnp.issubdtype(ref_dtype, jnp.integer):
        return jnp.int32
    # `.dtype` on a JAX array is a NumPy dtype; convert it to a JAX one.
    return convert(ref_dtype, B.JAXDType)

116 

117 

@dispatch
def get_random_state(key: B.JAXRandomState):
    """
    Extract the random state of a JAX random generator.

    For JAX the key itself is the state, so it is handed back unchanged.

    :param key:
        A random generator of type `B.JAXRandomState`.
    """
    state = key
    return state

127 

128 

@dispatch
def restore_random_state(key: B.JAXRandomState, state):
    """
    Produce a JAX random generator carrying the random state `state`.

    JAX keys hold no state beyond their value, so `state` is returned as the
    new generator and the incoming `key` is ignored.

    :param key:
        The current random generator of type `B.JAXRandomState` (unused).
    :param state:
        The random state to restore; returned as the new generator.
    """
    restored = state
    return restored

141 

142 

@dispatch
def create_complex(real: _Numeric, imag: B.JAXNumeric):
    """
    Combine real and imaginary parts into a complex value, `real + 1j * imag`.

    :param real:
        Real part (a scalar or JAX array).
    :param imag:
        Imaginary part (a JAX array).
    """
    return real + 1j * imag

155 

156 

@dispatch
def complex_like(reference: B.JAXNumeric):
    """
    Return the complex dtype obtained by promoting `complex64` with the dtype
    of `reference`.
    """
    ref_dtype = reference.dtype
    return B.promote_dtypes(jnp.complex64, ref_dtype)

163 

164 

@dispatch
def is_complex(reference: B.JAXNumeric):
    """
    Tell whether `reference` has a complex dtype (`complex64` or `complex128`).
    """
    ref_dtype = B.dtype(reference)
    return any(ref_dtype == candidate for candidate in (jnp.complex64, jnp.complex128))

173 

174 

@dispatch
def cumsum(x: B.JAXNumeric, axis=None):
    """
    Running sum of `x`, along `axis` if given, otherwise over the flattened array.
    """
    running_total = jnp.cumsum(x, axis=axis)
    return running_total

181 

182 

@dispatch
def qr(x: B.JAXNumeric, mode="reduced"):
    """
    Factor the matrix `x` as `Q @ R` and return the pair `(Q, R)`.

    The result of `jnp.linalg.qr` is unpacked so that a plain tuple is
    returned.
    """
    orthogonal, upper = jnp.linalg.qr(x, mode=mode)
    return orthogonal, upper

190 

191 

@dispatch
def slogdet(x: B.JAXNumeric):
    """
    Return `(sign, log|det|)` of the matrix `x` as a plain tuple.
    """
    det_sign, log_abs_det = jnp.linalg.slogdet(x)
    return det_sign, log_abs_det

199 

200 

@dispatch
def eigvalsh(x: B.JAXNumeric):
    """
    Eigenvalues of the Hermitian (or real symmetric) matrix `x`.
    """
    eigenvalues = jnp.linalg.eigvalsh(x)
    return eigenvalues

207 

208 

@dispatch
def reciprocal_no_nan(x: B.JAXNumeric):
    """
    Element-wise `1 / x`, with the result set to 0 wherever `x == 0`.
    """
    zero_mask = jnp.equal(x, 0.0)
    # Replace zeros before dividing so the reciprocal never produces inf/nan,
    # then overwrite those positions with 0.
    denominator = jnp.where(zero_mask, 1.0, x)
    inverted = jnp.reciprocal(denominator)
    return jnp.where(zero_mask, 0.0, inverted)

217 

218 

@dispatch
def complex_conj(x: B.JAXNumeric):
    """
    Element-wise complex conjugate of `x`.
    """
    return jnp.conjugate(x)

225 

226 

@dispatch
def logical_xor(x1: B.JAXNumeric, x2: B.JAXNumeric):
    """
    Element-wise logical exclusive-or of `x1` and `x2`.
    """
    exclusive = jnp.logical_xor(x1, x2)
    return exclusive

233 

234 

@dispatch
def count_nonzero(x: B.JAXNumeric, axis=None):
    """
    Number of non-zero entries of `x`, along `axis` if given, otherwise in total.
    """
    nonzero_count = jnp.count_nonzero(x, axis=axis)
    return nonzero_count

241 

242 

@dispatch
def dtype_bool(reference: B.JAXRandomState):  # type: ignore
    """
    Give the JAX boolean dtype; `reference` only selects the backend.
    """
    boolean_dtype = jnp.bool_
    return boolean_dtype

249 

250 

@dispatch
def bool_like(reference: B.JAXNumeric):
    """
    Return the dtype of `reference` if it is boolean; otherwise return
    JAX's boolean dtype.

    Dispatches on `B.JAXNumeric` (like `float_like` and `int_like`) rather
    than `B.JAXRandomState`. The body reads `reference.dtype`, so the
    reference is expected to be an array.

    :param reference:
        A JAX array whose dtype is inspected.
    :return:
        The dtype of `reference` converted to a JAX dtype when it is boolean,
        otherwise `jnp.bool_`.
    """
    reference_dtype = reference.dtype
    if jnp.issubdtype(reference_dtype, jnp.bool_):
        # JAX `.dtype` returns a NumPy data type; convert it to a JAX one.
        return convert(reference_dtype, B.JAXDType)
    return jnp.bool_