Coverage for geometric_kernels/lab_extras/numpy/extras.py: 92%

102 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import lab as B 

2import numpy as np 

3from beartype.typing import Any, List, Optional 

4from lab import dispatch 

5from plum import Union 

6from scipy.sparse import spmatrix 

7 

# Type alias used in the dispatch signatures below: accepts either a plain
# scalar number (`B.Number`) or a NumPy array (`B.NPNumeric`).
_Numeric = Union[B.Number, B.NPNumeric]

9 

10 

@dispatch
def take_along_axis(a: _Numeric, index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Select values from `a` along `axis` using the matching entries of `index`.

    NumPy implementation: thin wrapper around :func:`numpy.take_along_axis`.
    """
    gathered = np.take_along_axis(a, index, axis=axis)
    return gathered

17 

18 

@dispatch
def from_numpy(_: B.NPNumeric, b: Union[List, B.NPNumeric, B.Number]):  # type: ignore
    """
    Convert `b` to an array of the same backend as the first argument.

    For the NumPy backend the first argument only selects this
    implementation; `b` is converted with `np.array`.
    """
    converted = np.array(b)
    return converted

25 

26 

@dispatch
def trapz(y: _Numeric, x: _Numeric, dx: _Numeric = 1.0, axis: int = -1):  # type: ignore
    """
    Integrate along the given axis using the composite trapezoidal rule.

    :param y:
        Values of the integrand.
    :param x:
        Sample points corresponding to `y`.
    :param dx:
        Spacing between sample points; NumPy only uses it when `x` is None.
    :param axis:
        Axis along which to integrate.
    """
    # `np.trapz` was deprecated in NumPy 2.0 in favour of `np.trapezoid`;
    # prefer the new name and fall back to `np.trapz` on older NumPy.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz
    return trapezoid(y, x=x, dx=dx, axis=axis)

33 

34 

@dispatch
def norm(x: _Numeric, ord: Optional[Any] = None, axis: Optional[int] = None):  # type: ignore
    """
    Compute a matrix or vector norm of `x` via :func:`numpy.linalg.norm`.

    `ord` and `axis` are forwarded unchanged.
    """
    result = np.linalg.norm(x, axis=axis, ord=ord)
    return result

41 

42 

@dispatch
def logspace(start: _Numeric, stop: _Numeric, num: int = 50, base: _Numeric = 50.0):  # type: ignore
    """
    Return `num` values spaced evenly on a log scale, from ``base ** start``
    to ``base ** stop`` (see :func:`numpy.logspace`).

    NOTE(review): the default `base` here is 50.0, whereas NumPy's own
    default is 10.0. Confirm this is intentional (other backends may share
    the same default) before relying on it.
    """
    return np.logspace(start, stop, num=num, base=base)

49 

50 

@dispatch
def degree(a: _Numeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The result is a diagonal matrix whose main diagonal holds the column sums
    of `a`. For a symmetric 0/1 adjacency matrix this is the number of
    neighbours of each node.
    """
    return np.diag(a.sum(axis=0))  # type: ignore

61 

62 

@dispatch
def dtype_double(reference: B.NPRandomState):  # type: ignore
    """
    Double-precision float dtype for the NumPy backend.

    `reference` (a NumPy random state) only selects this implementation;
    the result is always `np.float64`.
    """
    double = np.float64
    return double

69 

70 

@dispatch
def float_like(reference: B.NPNumeric):
    """
    Pick a floating point dtype compatible with `reference`.

    If `reference` already has a floating dtype, that dtype is returned as-is;
    otherwise the NumPy double dtype `np.float64` is used.
    """
    dtype = reference.dtype
    return dtype if np.issubdtype(dtype, np.floating) else np.float64

82 

83 

@dispatch
def dtype_integer(reference: B.NPRandomState):  # type: ignore
    """
    Integer dtype for the NumPy backend.

    `reference` (a NumPy random state) only selects this implementation;
    the result is always `np.int32`.
    """
    integer = np.int32
    return integer

90 

91 

@dispatch
def int_like(reference: B.NPNumeric):
    """
    Pick an integer dtype compatible with `reference`.

    Returns the dtype of `reference` when it is already an integer type,
    and `np.int32` otherwise.
    """
    dtype = reference.dtype
    if not np.issubdtype(dtype, np.integer):
        return np.int32
    return dtype

99 

100 

@dispatch
def get_random_state(key: B.NPRandomState):
    """
    Capture the internal state of a NumPy random generator.

    :param key:
        The random generator of type `B.NPRandomState`.
    :return:
        The value of `key.get_state()`.
    """
    state = key.get_state()
    return state

110 

111 

@dispatch
def restore_random_state(key: B.NPRandomState, state):
    """
    Build a random generator whose internal state is `state`.

    A fresh `np.random.RandomState` is created and its state is set; `key`
    itself is not read or modified.

    :param key:
        The random generator of type `B.NPRandomState`.
    :param state:
        The new random state, in the form accepted by
        `np.random.RandomState.set_state`.
    :return:
        The newly created generator.
    """
    restored = np.random.RandomState()
    restored.set_state(state)
    return restored

126 

127 

@dispatch
def create_complex(real: _Numeric, imag: _Numeric):
    """
    Combine real and imaginary parts into a complex number (or array).

    :param real:
        Real part (scalar or NumPy array).
    :param imag:
        Imaginary part (scalar or NumPy array).
    :return:
        ``real + 1j * imag``.
    """
    imaginary_part = 1j * imag
    return real + imaginary_part

140 

141 

@dispatch
def complex_like(reference: B.NPNumeric):
    """
    Complex dtype matching the precision of `reference`.

    Promotes `np.complex64` together with the dtype of `reference` through
    `B.promote_dtypes`. This assumes lab's promotion follows NumPy rules, so
    higher-precision references should give a wider complex dtype.
    """
    reference_dtype = reference.dtype
    return B.promote_dtypes(np.complex64, reference_dtype)

148 

149 

@dispatch
def is_complex(reference: B.NPNumeric):
    """
    Return True if `reference` has a complex dtype.

    Any NumPy complex floating type is recognised: complex64, complex128 and,
    where the platform provides it, extended-precision `np.clongdouble`.
    The previous equality checks against only complex64/complex128 missed
    the latter.
    """
    return np.issubdtype(B.dtype(reference), np.complexfloating)

156 

157 

@dispatch
def cumsum(a: _Numeric, axis=None):
    """
    Cumulative sum of `a`.

    With `axis=None` NumPy sums over the flattened input; otherwise the
    running total is taken along `axis`.
    """
    running_total = np.cumsum(a, axis=axis)
    return running_total

164 

165 

@dispatch
def qr(x: _Numeric, mode="reduced"):
    """
    QR decomposition of the matrix `x` via :func:`numpy.linalg.qr`.

    :return:
        The pair ``(Q, R)`` as a plain tuple.

    NOTE(review): the NumPy result is unpacked into exactly two values, so
    modes for which NumPy returns a single array (e.g. ``mode="r"``) are not
    supported here.
    """
    decomposition = np.linalg.qr(x, mode=mode)
    q_factor, r_factor = decomposition
    return q_factor, r_factor

173 

174 

@dispatch
def slogdet(x: _Numeric):
    """
    Sign and natural log of the absolute determinant of the matrix `x`.

    :return:
        The pair ``(sign, logdet)`` as a plain tuple.
    """
    result = np.linalg.slogdet(x)
    sign_part, log_abs_det = result
    return sign_part, log_abs_det

182 

183 

@dispatch
def eigvalsh(x: _Numeric):
    """
    Eigenvalues of the Hermitian or real symmetric matrix `x`.

    Only the upper triangle of `x` is read (``UPLO="U"``, NumPy's default is
    ``"L"``). NumPy returns the eigenvalues in ascending order.
    """
    triangle = "U"
    return np.linalg.eigvalsh(x, UPLO=triangle)

190 

191 

@dispatch
def reciprocal_no_nan(x: B.NPNumeric):
    """
    Element-wise reciprocal ``1 / x`` that maps zero entries to 0 instead
    of infinity.
    """
    zero_mask = np.equal(x, 0.0)
    # Swap zeros for 1.0 before inverting so no division by zero happens;
    # those positions are overwritten with 0.0 afterwards.
    denominators = np.where(zero_mask, 1.0, x)
    inverted = np.reciprocal(denominators)
    return np.where(zero_mask, 0.0, inverted)

200 

201 

@dispatch  # type: ignore[no-redef]
def reciprocal_no_nan(x: spmatrix):
    """
    Element-wise reciprocal of the stored entries of the sparse matrix `x`.

    Stored zeros stay zero (via the dense `reciprocal_no_nan`); the sparsity
    structure of the result is taken from `x`.
    """
    # NOTE(review): `_deduped_data` and `_with_data` are private SciPy
    # helpers. For some formats `_deduped_data` may merge duplicate entries
    # of `x` in place; confirm that is acceptable for callers.
    stored_values = x._deduped_data().copy()
    inverted_values = reciprocal_no_nan(stored_values)
    return x._with_data(inverted_values, copy=True)

208 

209 

@dispatch
def complex_conj(x: B.NPNumeric):
    """
    Element-wise complex conjugate of `x` (see :func:`numpy.conjugate`).
    """
    conjugated = np.conjugate(x)
    return conjugated

216 

217 

@dispatch
def logical_xor(x1: B.NPNumeric, x2: B.NPNumeric):
    """
    Element-wise logical exclusive OR of `x1` and `x2`.
    """
    exclusive = np.logical_xor(x1, x2)
    return exclusive

224 

225 

@dispatch
def count_nonzero(x: B.NPNumeric, axis=None):
    """
    Number of non-zero entries in `x`, over the whole array when `axis` is
    None or along `axis` otherwise.
    """
    tally = np.count_nonzero(x, axis=axis)
    return tally

232 

233 

@dispatch
def dtype_bool(reference: B.NPRandomState):  # type: ignore
    """
    Boolean dtype for the NumPy backend.

    `reference` (a NumPy random state) only selects this implementation;
    the result is always `np.bool_`.
    """
    boolean = np.bool_
    return boolean

240 

241 

@dispatch
def bool_like(reference: B.NPNumeric):
    """
    Pick a boolean dtype compatible with `reference`.

    Returns the dtype of `reference` if it is already boolean, and
    `np.bool_` otherwise.
    """
    dtype = reference.dtype
    return dtype if np.issubdtype(dtype, np.bool_) else np.bool_