Coverage for geometric_kernels / lab_extras / torch / extras.py: 92%

109 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 15:49 +0000

1from typing import TypeAlias 

2 

3import lab as B 

4import torch 

5from beartype.typing import Any 

6from lab import dispatch 

7 

# Argument type for values that may be either a lab scalar (`B.Number`)
# or a torch tensor (`B.TorchNumeric`).
_Numeric: TypeAlias = (B.Number | B.TorchNumeric)

9 

10 

@dispatch
def take_along_axis(a: _Numeric | B.Numeric, index: B.TorchNumeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Gather the entries of `a` located at `index` along dimension `axis`.

    A non-tensor `a` is first converted with `torch.tensor` and moved to the
    device of `index`. A tensor `a` is used as-is.

    :param a: values to gather from.
    :param index: torch tensor of positions to gather.
    :param axis: dimension along which to gather (default 0).
    :return: torch tensor of gathered values.
    """
    source = a if torch.is_tensor(a) else torch.tensor(a).to(index.device)  # type: ignore
    # torch requires int64 ("long") indices here.
    return torch.take_along_dim(source, index.long(), axis)

19 

20 

@dispatch
def from_numpy(
    a: B.TorchNumeric, b: list | B.Number | B.NPNumeric | B.TorchNumeric
):  # type: ignore
    """
    Converts the array `b` to a tensor of the same backend as `a`.

    Torch tensors are returned unchanged. Anything else (lists, scalars,
    NumPy arrays) becomes a new torch tensor placed on `a`'s device.

    :param a: reference torch tensor; only its device is used.
    :param b: data to convert.
    :return: `b` as a torch tensor.
    """
    if torch.is_tensor(b):
        return b
    # Lists and arrays are copied before conversion. torch.tensor refuses
    # NumPy arrays with negative strides, and a default `.copy()` of an
    # ndarray is C-contiguous. Python scalars (int, float, complex, bool)
    # have no `.copy()` method, so they are passed through directly instead
    # of raising AttributeError.
    data = b.copy() if hasattr(b, "copy") else b  # type: ignore
    return torch.tensor(data, device=a.device)

31 

32 

@dispatch
def trapz(y: B.TorchNumeric, x: _Numeric, dx: _Numeric = 1.0, axis: int = -1):  # type: ignore
    """
    Integrate `y` along dimension `axis` with the trapezoidal rule, using
    `x` as the sample points.

    `dx` is accepted for signature compatibility but never used. The
    spacing always comes from `x`, as in NumPy, where `dx` only matters
    when `x` is omitted.
    """
    integral = torch.trapz(y, x, dim=axis)
    return integral

39 

40 

@dispatch
def norm(x: _Numeric, ord: Any | None = None, axis: int | None = None):  # type: ignore
    """
    Matrix or vector norm of `x`, computed with `torch.linalg.norm`.

    :param x: a torch tensor or a plain number. Numbers are converted to a
        tensor first because `torch.linalg.norm` rejects non-tensor input.
        Non-float/non-complex results are cast to torch's default float
        dtype, since the norm is only defined for those dtypes.
    :param ord: order of the norm, forwarded to `torch.linalg.norm`.
    :param axis: dimension to reduce over; `None` uses torch's default.
    :return: the norm as a torch tensor.
    """
    if not torch.is_tensor(x):
        x = torch.as_tensor(x)
        if not (x.is_floating_point() or x.is_complex()):
            x = x.to(torch.get_default_dtype())
    return torch.linalg.norm(x, ord=ord, dim=axis)

47 

48 

@dispatch
def logspace(start: B.TorchNumeric, stop: B.TorchNumeric, num: int = 50, base: _Numeric = 10.0):  # type: ignore
    """
    Return numbers spaced evenly on a log scale, from `base ** start` to
    `base ** stop`.

    :param start: single-element tensor holding the first exponent.
    :param stop: single-element tensor holding the last exponent.
    :param num: number of samples.
    :param base: base of the log space; a number or a single-element tensor.
    :return: a 1-D torch tensor on the same device as `start`.
    """
    # torch.logspace takes `base` as a Python number, so unwrap tensors.
    if torch.is_tensor(base):
        base = base.item()
    # Allocate on the inputs' device. Without `device=`, torch uses its
    # default device (normally the CPU) even when the inputs live elsewhere.
    return torch.logspace(start.item(), stop.item(), num, base, device=start.device)

55 

56 

@dispatch
def degree(a: B.TorchNumeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The result is a diagonal matrix whose main diagonal holds the column
    sums of `a`. For a 0/1 adjacency matrix, that is the number of nodes
    each node is connected to.
    """
    return torch.diag(a.sum(dim=0))

67 

68 

@dispatch
def eigenpairs(L: B.TorchNumeric, k: int):
    """
    Return the `k` lowest eigenvalues of the symmetric positive
    semi-definite matrix `L`, together with the matching eigenvectors.

    `torch.linalg.eigh` returns eigenvalues in ascending order. Keeping the
    first `k` of them, and the first `k` eigenvector columns, selects the
    lowest ones.

    TODO(AR): Replace with torch.lobpcg after sparse matrices are supported
    by torch.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(L)
    return eigenvalues[:k], eigenvectors[:, :k]

80 

81 

@dispatch
def set_value(a: B.TorchNumeric, index: int, value: float):
    """
    Return a copy of `a` in which position `index` holds `value`.

    `a` itself is not modified; the assignment is applied to a clone.
    """
    updated = a.clone()
    updated[index] = value
    return updated

91 

92 

@dispatch
def dtype_double(reference: B.TorchRandomState):  # type: ignore
    """
    Return torch's double-precision dtype.

    `reference` only selects the backend through dispatch and is otherwise
    unused.
    """
    return torch.float64

99 

100 

@dispatch
def float_like(reference: B.TorchNumeric):
    """
    Return `reference`'s dtype when it is a floating point dtype.

    For any other dtype (integer, boolean, complex, ...), return torch's
    `double` dtype.
    """
    if not torch.is_floating_point(reference):
        return torch.double
    return B.dtype(reference)

111 

112 

@dispatch
def dtype_integer(reference: B.TorchRandomState):  # type: ignore
    """
    Return torch's default `int` dtype (32-bit).

    `reference` only selects the backend through dispatch and is otherwise
    unused.
    """
    return torch.int32

119 

120 

@dispatch
def int_like(reference: B.TorchNumeric):
    """
    Return `reference`'s dtype when it is one of the signed integer dtypes
    int8/int16/int32/int64.

    For any other dtype (floats, complex, bool, uint8, ...), return
    `torch.int32`.
    """
    signed_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)
    dtype = reference.dtype
    return dtype if dtype in signed_int_dtypes else torch.int32

128 

129 

@dispatch
def get_random_state(key: B.TorchRandomState):
    """
    Return the current state of the torch random generator `key`, as
    reported by `key.get_state()`.

    :param key: the random generator of type `B.TorchRandomState`.
    """
    state = key.get_state()
    return state

138 

139 

@dispatch
def restore_random_state(key: B.TorchRandomState, state):
    """
    Set the random state of a random generator. Return the new random
    generator with state `state`.

    `key` is not modified. It only determines the device of the new
    generator, so that a state captured from, for example, a CUDA generator
    is restored into a generator on that same device. A CPU generator
    cannot accept a CUDA generator's state.

    :param key:
        The random generator of type `B.TorchRandomState`.
    :param state:
        The new random state of the random generator.
    """
    gen = torch.Generator(device=key.device)
    gen.set_state(state)
    return gen

154 

155 

@dispatch
def create_complex(real: _Numeric, imag: B.TorchNumeric):
    """
    Build the complex value `real + 1j * imag` with PyTorch arithmetic.

    :param real:
        Real part: a number or a torch tensor.
    :param imag:
        Imaginary part: a torch tensor.
    :return: the resulting complex torch tensor.
    """
    imaginary_part = 1j * imag
    return real + imaginary_part

168 

169 

@dispatch
def complex_like(reference: B.TorchNumeric):
    """
    Return a complex dtype matching `reference`.

    The result is `complex64` (`cfloat`) promoted together with the
    reference's dtype via `B.promote_dtypes`.
    """
    ref_dtype = reference.dtype
    return B.promote_dtypes(torch.complex64, ref_dtype)

176 

177 

@dispatch
def is_complex(reference: B.TorchNumeric):
    """
    Return True if `reference` has a complex dtype.

    Delegates to `torch.is_complex`, which recognises every torch complex
    dtype (complex32, complex64/cfloat and complex128/cdouble), not only
    cfloat and cdouble.
    """
    return torch.is_complex(reference)

184 

185 

@dispatch
def cumsum(x: B.TorchNumeric, axis=None):
    """
    Return the cumulative sum of `x`, optionally along `axis`.

    With `axis=None` (the default), `x` is flattened first and the running
    sum is taken over all elements, following NumPy's convention.
    `torch.cumsum` requires an explicit dimension and rejects `dim=None`,
    so that case is handled here.
    """
    if axis is None:
        # reshape(-1) also works for non-contiguous tensors.
        return torch.cumsum(x.reshape(-1), dim=0)
    return torch.cumsum(x, dim=axis)

192 

193 

@dispatch
def qr(x: B.TorchNumeric, mode="reduced"):
    """
    QR-decompose the matrix `x` using `torch.linalg.qr`.

    :param mode: forwarded to `torch.linalg.qr` ("reduced" by default).
    :return: the plain tuple `(Q, R)`.
    """
    decomposition = torch.linalg.qr(x, mode=mode)
    return decomposition.Q, decomposition.R

201 

202 

@dispatch
def slogdet(x: B.TorchNumeric):
    """
    Return the sign and the log of the absolute determinant of the matrix
    `x`, as computed by `torch.slogdet`, as a plain `(sign, logdet)` tuple.
    """
    result = torch.slogdet(x)
    return result[0], result[1]

210 

211 

@dispatch
def eigvalsh(x: B.TorchNumeric):
    """
    Eigenvalues of the Hermitian (or real symmetric) matrix `x`, computed by
    `torch.linalg.eigvalsh` and returned in ascending order.
    """
    eigenvalues = torch.linalg.eigvalsh(x)
    return eigenvalues

218 

219 

@dispatch
def reciprocal_no_nan(x: B.TorchNumeric):
    """
    Element-wise reciprocal `1 / x`, with the result set to 0 wherever
    `x == 0`.
    """
    is_zero = x == 0.0
    # Replace zeros by 1 before dividing so 1/0 (inf) is never evaluated.
    # Even an unselected inf branch of torch.where yields NaN gradients.
    denominator = torch.where(is_zero, 1.0, x)
    return torch.where(is_zero, 0.0, torch.reciprocal(denominator))

227 

228 

@dispatch
def complex_conj(x: B.TorchNumeric):
    """
    Return the complex conjugate of `x` (via torch's `conj`).
    """
    return x.conj()

235 

236 

@dispatch
def logical_xor(x1: B.TorchNumeric, x2: B.TorchNumeric):
    """
    Element-wise logical XOR of the tensors `x1` and `x2`.
    """
    return x1.logical_xor(x2)

243 

244 

@dispatch
def count_nonzero(x: B.TorchNumeric, axis=None):
    """
    Count the non-zero entries of `x`, over all elements when `axis` is
    None, otherwise along `axis`.
    """
    return x.count_nonzero(dim=axis)

251 

252 

@dispatch
def dtype_bool(reference: B.TorchRandomState):  # type: ignore
    """
    Return torch's boolean dtype.

    `reference` only selects the backend through dispatch and is otherwise
    unused.
    """
    bool_dtype = torch.bool
    return bool_dtype

259 

260 

@dispatch
def bool_like(reference: B.TorchNumeric):
    """
    Return the boolean dtype of the torch backend for `reference`.

    Torch has a single boolean dtype, so the answer is always `torch.bool`.
    A boolean reference already has it, and any other reference maps to
    it. The previous if/else returned `torch.bool` on both branches.
    """
    return torch.bool