Coverage for geometric_kernels / lab_extras / numpy / extras.py: 91%

102 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 15:49 +0000

1from typing import TypeAlias 

2 

3import lab as B 

4import numpy as np 

5from beartype.typing import Any 

6from lab import dispatch 

7from scipy.sparse import spmatrix 

8 

# Argument type accepted by the generic numeric helpers below: either a plain
# number (`B.Number`) or a NumPy numeric (`B.NPNumeric`).
_Numeric: TypeAlias = B.Number | B.NPNumeric

10 

11 

@dispatch
def take_along_axis(a: _Numeric, index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Pick entries of `a` at the positions listed in `index` along `axis`.

    Thin wrapper around :func:`numpy.take_along_axis`.

    :param a:
        Array to gather from.
    :param index:
        Indices selecting entries of `a` along `axis`.
    :param axis:
        Axis along which to gather. Defaults to 0.

    :return:
        The gathered entries.
    """
    gathered = np.take_along_axis(a, index, axis=axis)
    return gathered

18 

19 

@dispatch
def from_numpy(_: B.NPNumeric, b: list | B.NPNumeric | B.Number):  # type: ignore
    """
    Convert `b` to an array of the same backend as the first argument.

    For the NumPy backend the first argument only selects this
    implementation through dispatch; its value is never used.

    :param _:
        Reference NumPy array that selects the backend.
    :param b:
        A list, NumPy array or number to convert.

    :return:
        The result of ``np.array(b)``.
    """
    converted = np.array(b)
    return converted

26 

27 

@dispatch
def trapz(y: _Numeric, x: _Numeric, dx: _Numeric = 1.0, axis: int = -1):  # type: ignore
    """
    Integrate along the given axis using the composite trapezoidal rule.

    :param y:
        Values of the integrand.
    :param x:
        Sample points corresponding to `y`.
    :param dx:
        Sample spacing, forwarded to NumPy. Per NumPy's documentation it is
        only used when `x` is None, which this signature does not allow, so
        in practice it has no effect.
    :param axis:
        Axis along which to integrate. Defaults to the last axis.

    :return:
        The approximate integral of `y` over `x` along `axis`.
    """
    # NumPy 2.0 renamed `np.trapz` to `np.trapezoid` and deprecated the old
    # name (it warns on every call). Prefer the new name and fall back to
    # `np.trapz` on NumPy 1.x, where `trapezoid` does not exist.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    return integrate(y, x, dx, axis)

34 

35 

@dispatch
def norm(x: _Numeric, ord: Any | None = None, axis: int | None = None):  # type: ignore
    """
    Compute a matrix or vector norm of `x` with :func:`numpy.linalg.norm`.

    :param x:
        Input array.
    :param ord:
        Order of the norm; None selects NumPy's default.
    :param axis:
        Axis along which to take the norm; None lets NumPy decide.

    :return:
        The norm (a scalar or an array, depending on `axis`).
    """
    result = np.linalg.norm(x, ord=ord, axis=axis)
    return result

42 

43 

@dispatch
def logspace(start: _Numeric, stop: _Numeric, num: int = 50, base: _Numeric = 50.0):  # type: ignore
    """
    Produce `num` numbers spaced evenly on a log scale.

    The values run from ``base ** start`` to ``base ** stop``, as computed by
    :func:`numpy.logspace`.

    :param start:
        Exponent of the first value.
    :param stop:
        Exponent of the last value.
    :param num:
        How many values to produce. Defaults to 50.
    :param base:
        Base of the log space. NOTE: defaults to 50.0 here, not NumPy's 10.0.

    :return:
        Array of `num` log-spaced values.
    """
    return np.logspace(start, stop, num=num, base=base)

50 

51 

@dispatch
def degree(a: _Numeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The column sums of `a` are placed on the main diagonal. For a 0/1
    adjacency matrix each diagonal entry counts the nodes connected to that
    node.

    :param a:
        Adjacency matrix (presumably square; this is not checked here).

    :return:
        Diagonal matrix holding the column sums of `a`.
    """
    return np.diag(a.sum(axis=0))  # type: ignore

62 

63 

@dispatch
def dtype_double(reference: B.NPRandomState):  # type: ignore
    """
    Give the double-precision float dtype of the NumPy backend.

    `reference` only routes the call to this implementation via dispatch;
    its value is not inspected.

    :return:
        ``np.float64``.
    """
    double = np.float64
    return double

70 

71 

@dispatch
def float_like(reference: B.NPNumeric):
    """
    Choose a floating-point dtype that matches `reference`.

    :param reference:
        NumPy array whose dtype is inspected.

    :return:
        ``reference.dtype`` if it is already a floating dtype, otherwise
        ``np.float64``.
    """
    ref_dtype = reference.dtype
    return ref_dtype if np.issubdtype(ref_dtype, np.floating) else np.float64

83 

84 

@dispatch
def dtype_integer(reference: B.NPRandomState):  # type: ignore
    """
    Give the default integer dtype of the NumPy backend.

    `reference` only routes the call to this implementation via dispatch;
    its value is not inspected.

    :return:
        ``np.int32``.
    """
    integer = np.int32
    return integer

91 

92 

@dispatch
def int_like(reference: B.NPNumeric):
    """
    Choose an integer dtype that matches `reference`.

    :param reference:
        NumPy array whose dtype is inspected.

    :return:
        ``reference.dtype`` if it is already an integer dtype, otherwise
        ``np.int32``.
    """
    ref_dtype = reference.dtype
    return ref_dtype if np.issubdtype(ref_dtype, np.integer) else np.int32

100 

101 

@dispatch
def get_random_state(key: B.NPRandomState):
    """
    Take a snapshot of a random generator's internal state.

    :param key:
        The random generator of type `B.NPRandomState`.

    :return:
        The value of ``key.get_state()``, suitable for `restore_random_state`.
    """
    snapshot = key.get_state()
    return snapshot

111 

112 

@dispatch
def restore_random_state(key: B.NPRandomState, state):
    """
    Create a random generator whose state is `state`.

    A brand-new generator is returned. `key` is neither read nor modified;
    it only selects this NumPy implementation through dispatch.

    :param key:
        The random generator of type `B.NPRandomState`.
    :param state:
        The state to load, e.g. one produced by `get_random_state`.

    :return:
        A new ``np.random.RandomState`` set to `state`.
    """
    restored = np.random.RandomState()
    restored.set_state(state)
    return restored

127 

128 

@dispatch
def create_complex(real: _Numeric, imag: _Numeric):
    """
    Assemble a complex value from its real and imaginary parts.

    The result is computed as ``real + 1j * imag``, so NumPy arrays are
    combined element-wise.

    :param real:
        Real part (a number or a NumPy array).
    :param imag:
        Imaginary part (a number or a NumPy array).

    :return:
        The complex number or array ``real + 1j * imag``.
    """
    imaginary_part = 1j * imag
    return real + imaginary_part

141 

142 

@dispatch
def complex_like(reference: B.NPNumeric):
    """
    Give a complex dtype of the NumPy backend compatible with `reference`.

    ``np.complex64`` is promoted together with ``reference.dtype`` through
    ``B.promote_dtypes``. Presumably this means wider references (e.g.
    float64) yield a wider complex dtype; confirm against lab's promotion
    rules.

    :param reference:
        NumPy array whose dtype takes part in the promotion.

    :return:
        The promoted complex dtype.
    """
    ref_dtype = reference.dtype
    return B.promote_dtypes(np.complex64, ref_dtype)

149 

150 

@dispatch
def is_complex(reference: B.NPNumeric):
    """
    Return True if `reference` has a complex dtype.

    Every complex floating-point precision counts: complex64, complex128,
    and the extended-precision ``clongdouble`` on platforms that provide it.

    :param reference:
        NumPy array whose dtype is inspected.

    :return:
        True for any complex dtype, False otherwise.
    """
    # `np.complexfloating` is the abstract parent of every complex dtype.
    # Comparing only against complex64/complex128 would miss extended
    # precision complex arrays. This also matches the `np.issubdtype` style
    # used by float_like / int_like / bool_like.
    return np.issubdtype(B.dtype(reference), np.complexfloating)

157 

158 

@dispatch
def cumsum(a: _Numeric, axis=None):
    """
    Running (cumulative) sum of `a`, via :func:`numpy.cumsum`.

    :param a:
        Input array.
    :param axis:
        Axis to accumulate along. With None, NumPy flattens the input first.

    :return:
        Array of cumulative sums.
    """
    running_total = np.cumsum(a, axis=axis)
    return running_total

165 

166 

@dispatch
def qr(x: _Numeric, mode="reduced"):
    """
    QR-factorize the matrix `x` with :func:`numpy.linalg.qr`.

    :param x:
        Matrix to factorize.
    :param mode:
        Factorization mode forwarded to NumPy. The result is unpacked into two
        factors, so presumably only modes yielding (Q, R) are intended.

    :return:
        Tuple ``(Q, R)``.
    """
    q_factor, r_factor = np.linalg.qr(x, mode=mode)
    return q_factor, r_factor

174 

175 

@dispatch
def slogdet(x: _Numeric):
    """
    Sign and natural log of the absolute determinant of the matrix `x`.

    :param x:
        Matrix whose determinant is evaluated.

    :return:
        Tuple ``(sign, logdet)`` as computed by :func:`numpy.linalg.slogdet`.
    """
    det_sign, log_abs_det = np.linalg.slogdet(x)
    return det_sign, log_abs_det

183 

184 

@dispatch
def eigvalsh(x: _Numeric):
    """
    Eigenvalues of the Hermitian (or real symmetric) matrix `x`.

    :param x:
        Hermitian / real symmetric matrix.

    :return:
        The eigenvalues, as returned by :func:`numpy.linalg.eigvalsh`.
    """
    # UPLO="U": only the upper triangle of `x` is used by NumPy.
    eigenvalues = np.linalg.eigvalsh(x, UPLO="U")
    return eigenvalues

191 

192 

@dispatch
def reciprocal_no_nan(x: B.NPNumeric):
    """
    Element-wise reciprocal ``1 / x`` that maps zero entries to zero.

    :param x:
        Input array.

    :return:
        Array with ``1 / x`` where ``x != 0`` and 0 where ``x == 0``.
    """
    zero_mask = np.equal(x, 0.0)
    # Replace zeros by ones before inverting so no division by zero happens;
    # those slots are overwritten with 0 in the final `where`.
    denominators = np.where(zero_mask, 1.0, x)
    inverted = np.reciprocal(denominators)
    return np.where(zero_mask, 0.0, inverted)

201 

202 

@dispatch  # type: ignore[no-redef]
def reciprocal_no_nan(x: spmatrix):
    """
    Element-wise reciprocal ``1 / x`` of a sparse matrix, mapping zeros to 0.

    Only the stored entries are inverted, via the dense overload. The
    sparsity structure is kept, because implicit zeros stay zero.

    :param x:
        SciPy sparse matrix.

    :return:
        Sparse matrix with the same structure and inverted stored values.
    """
    # `_deduped_data` / `_with_data` are private SciPy helpers on data-backed
    # sparse formats (CSR/CSC/COO-style); other formats may not support them.
    stored_values = x._deduped_data().copy()
    inverted_values = reciprocal_no_nan(stored_values)
    return x._with_data(inverted_values, copy=True)

209 

210 

@dispatch
def complex_conj(x: B.NPNumeric):
    """
    Complex conjugate of `x`, computed element-wise.

    :param x:
        Input array.

    :return:
        The conjugated array.
    """
    # `np.conj` is NumPy's alias of the `np.conjugate` ufunc.
    return np.conj(x)

217 

218 

@dispatch
def logical_xor(x1: B.NPNumeric, x2: B.NPNumeric):
    """
    Element-wise logical exclusive-or of `x1` and `x2`.

    :param x1:
        First operand.
    :param x2:
        Second operand.

    :return:
        Boolean array, True where exactly one operand is truthy.
    """
    exclusive = np.logical_xor(x1, x2)
    return exclusive

225 

226 

@dispatch
def count_nonzero(x: B.NPNumeric, axis=None):
    """
    Number of non-zero entries of `x`, via :func:`numpy.count_nonzero`.

    :param x:
        Input array.
    :param axis:
        Axis to count along; None counts over the whole array.

    :return:
        The count (an int, or an array of counts along `axis`).
    """
    tally = np.count_nonzero(x, axis=axis)
    return tally

233 

234 

@dispatch
def dtype_bool(reference: B.NPRandomState):  # type: ignore
    """
    Give the boolean dtype of the NumPy backend.

    `reference` only routes the call to this implementation via dispatch;
    its value is not inspected.

    :return:
        ``np.bool_``.
    """
    boolean = np.bool_
    return boolean

241 

242 

@dispatch
def bool_like(reference: B.NPNumeric):
    """
    Choose a boolean dtype that matches `reference`.

    :param reference:
        NumPy array whose dtype is inspected.

    :return:
        ``reference.dtype`` if it is already boolean, otherwise ``np.bool_``.
    """
    ref_dtype = reference.dtype
    return ref_dtype if np.issubdtype(ref_dtype, np.bool_) else np.bool_