Coverage for geometric_kernels/lab_extras/torch/extras.py: 93%

109 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import lab as B 

2import torch 

3from beartype.typing import Any, List, Optional 

4from lab import dispatch 

5from plum import Union 

6 

# Type alias used in the signatures below for arguments that may be either a
# plain number (`B.Number`) or a torch tensor/number (`B.TorchNumeric`).
_Numeric = Union[B.Number, B.TorchNumeric]

8 

9 

@dispatch
def take_along_axis(a: Union[_Numeric, B.Numeric], index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Gather the entries of `a` located at positions `index` along dimension `axis`.

    If `a` is not already a torch tensor it is converted with `torch.tensor`
    and moved to the device of `index`. The indices are cast to `long`
    before calling `torch.take_along_dim`, which requires integer (int64)
    indices.
    """
    source = a if torch.is_tensor(a) else torch.tensor(a).to(index.device)  # type: ignore
    long_index = index.long()
    return torch.take_along_dim(source, long_index, axis)

18 

19 

@dispatch
def from_numpy(
    a: B.TorchNumeric, b: Union[List, B.Number, B.NPNumeric, B.TorchNumeric]
):  # type: ignore
    """
    Convert `b` to a tensor of the same backend (torch) as `a`.

    Torch tensors are returned unchanged. Anything else is copied into a new
    tensor placed on `a`'s device. Lists and NumPy arrays/scalars are
    `.copy()`-ed first. Plain Python numbers (e.g. `1.0`) have no `.copy()`
    method, so they are handed to `torch.tensor` directly; previously they
    raised `AttributeError`.

    :param a: reference tensor; only its device is used.
    :param b: value to convert.
    :return: `b` as a torch tensor.
    """
    if torch.is_tensor(b):
        return b
    # Keep the original `.copy()` for containers. Presumably it hands torch a
    # fresh C-contiguous buffer (e.g. for strided NumPy views) — TODO confirm.
    # Python scalars lack `.copy()`, so skip it for them.
    data = b.copy() if hasattr(b, "copy") else b
    return torch.tensor(data).to(a.device)  # type: ignore

30 

31 

@dispatch
def trapz(y: B.TorchNumeric, x: _Numeric, dx: _Numeric = 1.0, axis: int = -1):  # type: ignore
    """
    Approximate the integral of `y` sampled at points `x` along dimension
    `axis` with the trapezoidal rule.

    Note: `dx` is accepted but not used by this implementation; the sample
    spacing always comes from `x`.
    """
    integral = torch.trapz(y, x, dim=axis)
    return integral

38 

39 

@dispatch
def norm(x: _Numeric, ord: Optional[Any] = None, axis: Optional[int] = None):  # type: ignore
    """
    Compute a vector or matrix norm of `x` using `torch.linalg.norm`.

    :param x: input tensor.
    :param ord: order of the norm; `None` uses torch's default.
    :param axis: dimension to reduce over, forwarded as torch's `dim`;
        `None` uses torch's default behaviour.
    """
    result = torch.linalg.norm(x, ord=ord, dim=axis)
    return result

46 

47 

@dispatch
def logspace(start: B.TorchNumeric, stop: B.TorchNumeric, num: int = 50, base: _Numeric = 10.0):  # type: ignore
    """
    Return `num` numbers spaced evenly on a log scale between
    `base ** start` and `base ** stop`.

    `start` and `stop` are torch scalars (read with `.item()`). The result is
    created on the same device as `start`; previously it was always created
    on torch's default device, which mismatched non-CPU inputs.

    :param start: exponent of the first value.
    :param stop: exponent of the last value.
    :param num: number of values to generate.
    :param base: base of the logarithm.
    """
    return torch.logspace(start.item(), stop.item(), num, base, device=start.device)

54 

55 

@dispatch
def degree(a: B.TorchNumeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The column sums of `a` form the main diagonal of the returned matrix;
    all other entries are zero. For a 0/1 adjacency matrix each diagonal
    entry counts the nodes the corresponding node is connected to.
    """
    column_sums = a.sum(dim=0)  # type: ignore
    return torch.diag(column_sums)

66 

67 

@dispatch
def eigenpairs(L: B.TorchNumeric, k: int):
    """
    Return the `k` lowest eigenvalues of the symmetric positive
    semi-definite matrix `L` together with their eigenvectors (as the
    columns of the second output).

    TODO(AR): Replace with torch.lobpcg after sparse matrices are supported
    by torch.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(L)
    # eigh yields eigenvalues in ascending order, so the first k are the lowest.
    lowest_values = eigenvalues[:k]
    lowest_vectors = eigenvectors[:, :k]
    return lowest_values, lowest_vectors

79 

80 

81@dispatch 

82def set_value(a: B.TorchNumeric, index: int, value: float): 

83 """ 

84 Set a[index] = value. 

85 This operation is not done in place and a new array is returned. 

86 """ 

87 a = a.clone() 

88 a[index] = value 

89 return a 

90 

91 

@dispatch
def dtype_double(reference: B.TorchRandomState):  # type: ignore
    """
    Return torch's double-precision floating point dtype.

    The `reference` argument only selects this torch implementation; its
    value is not inspected.
    """
    return torch.float64

98 

99 

@dispatch
def float_like(reference: B.TorchNumeric):
    """
    Return the dtype of `reference` when it is a floating point dtype;
    otherwise return torch's `float64`.
    """
    if not torch.is_floating_point(reference):
        return torch.float64
    return B.dtype(reference)

110 

111 

@dispatch
def dtype_integer(reference: B.TorchRandomState):  # type: ignore
    """
    Return torch's 32-bit integer dtype.

    The `reference` argument only selects this torch implementation; its
    value is not inspected.
    """
    return torch.int32

118 

119 

@dispatch
def int_like(reference: B.TorchNumeric):
    """
    Return the dtype of `reference` if it is one of torch's signed integer
    dtypes (int8, int16, int32, int64).

    For every other dtype (unsigned, boolean, floating point, complex, ...)
    return `torch.int32`.
    """
    signed_integer_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)
    ref_dtype = reference.dtype
    return ref_dtype if ref_dtype in signed_integer_dtypes else torch.int32

127 

128 

@dispatch
def get_random_state(key: B.TorchRandomState):
    """
    Capture the current state of the random generator `key`.

    :param key: the random generator of type `B.TorchRandomState`.
    :return: the value of `key.get_state()`.
    """
    snapshot = key.get_state()
    return snapshot

137 

138 

@dispatch
def restore_random_state(key: B.TorchRandomState, state):
    """
    Return a new random generator whose state is set to `state`.

    The new generator is created on the same device as `key`. Previously it
    was always a CPU generator, so a state captured from a generator on
    another device could not be restored. `key` itself is not modified.

    :param key:
        The random generator of type `B.TorchRandomState`; its device is
        used for the new generator.
    :param state:
        The new random state of the random generator.
    :return: a fresh `torch.Generator` with state `state`.
    """
    gen = torch.Generator(device=key.device)
    gen.set_state(state)
    return gen

153 

154 

@dispatch
def create_complex(real: _Numeric, imag: B.TorchNumeric):
    """
    Combine a real and an imaginary part into a complex value using pytorch,
    computing `real + 1j * imag`.

    :param real:
        real part; a number or a torch tensor.
    :param imag:
        imaginary part; a torch tensor.
    :return: the complex result `real + 1j * imag`.
    """
    imaginary_component = 1j * imag
    return real + imaginary_component

167 

168 

@dispatch
def complex_like(reference: B.TorchNumeric):
    """
    Return a complex dtype suited to `reference`: the result of promoting
    `torch.cfloat` together with the dtype of `reference`.
    """
    reference_dtype = reference.dtype
    return B.promote_dtypes(torch.cfloat, reference_dtype)

175 

176 

@dispatch
def is_complex(reference: B.TorchNumeric):
    """
    Return True if `reference` has a complex dtype, False otherwise.

    Relies on `torch.dtype.is_complex`, so every torch complex dtype is
    recognised. This covers `torch.cfloat` and `torch.cdouble` and also
    e.g. `torch.chalf`, which the previous explicit comparison missed.
    """
    return B.dtype(reference).is_complex

183 

184 

@dispatch
def cumsum(x: B.TorchNumeric, axis=None):
    """
    Return the cumulative sum of `x`, optionally along `axis`.

    :param x: input tensor.
    :param axis: dimension along which to accumulate. If `None` (the
        default), `x` is flattened first and the cumulative sum of the
        flattened tensor is returned, matching `numpy.cumsum`.
    :return: tensor of cumulative sums.
    """
    if axis is None:
        # torch.cumsum requires an integer `dim`; forwarding None raises, so
        # emulate numpy's axis=None by flattening and accumulating along 0.
        return torch.cumsum(x.reshape(-1), dim=0)
    return torch.cumsum(x, dim=axis)

191 

192 

@dispatch
def qr(x: B.TorchNumeric, mode="reduced"):
    """
    Compute the QR decomposition of the matrix `x` with `torch.linalg.qr`.

    :param x: input matrix (or batch of matrices).
    :param mode: decomposition mode forwarded to `torch.linalg.qr`.
    :return: the pair `(Q, R)` as a plain tuple.
    """
    q_factor, r_factor = torch.linalg.qr(x, mode=mode)
    return q_factor, r_factor

200 

201 

@dispatch
def slogdet(x: B.TorchNumeric):
    """
    Compute the sign and the log of the absolute determinant of `x`.

    :return: the pair `(sign, logdet)` as a plain tuple.
    """
    det_sign, log_abs_det = torch.slogdet(x)
    return det_sign, log_abs_det

209 

210 

@dispatch
def eigvalsh(x: B.TorchNumeric):
    """
    Return the eigenvalues of the Hermitian (or real symmetric) matrix `x`,
    computed by `torch.linalg.eigvalsh`.
    """
    eigenvalues = torch.linalg.eigvalsh(x)
    return eigenvalues

217 

218 

@dispatch
def reciprocal_no_nan(x: B.TorchNumeric):
    """
    Return the element-wise reciprocal `1 / x`, with 0 wherever `x == 0`.
    """
    zero_mask = x == 0.0
    # Replace zeros by ones before inverting so no inf is produced, then
    # overwrite those positions with 0. Presumably this double-`where` also
    # keeps gradients finite — TODO confirm intent.
    denominators = torch.where(zero_mask, 1.0, x)
    return torch.where(zero_mask, 0.0, torch.reciprocal(denominators))

226 

227 

@dispatch
def complex_conj(x: B.TorchNumeric):
    """
    Return the element-wise complex conjugate of `x`, via `torch.conj`.
    """
    conjugate = torch.conj(x)
    return conjugate

234 

235 

@dispatch
def logical_xor(x1: B.TorchNumeric, x2: B.TorchNumeric):
    """
    Return the element-wise logical exclusive-or of `x1` and `x2`.
    """
    exclusive_or = torch.logical_xor(x1, x2)
    return exclusive_or

242 

243 

@dispatch
def count_nonzero(x: B.TorchNumeric, axis=None):
    """
    Count the non-zero entries of `x`.

    :param axis: dimension to count along, forwarded as torch's `dim`;
        `None` counts over the whole tensor.
    """
    counts = torch.count_nonzero(x, dim=axis)
    return counts

250 

251 

@dispatch
def dtype_bool(reference: B.TorchRandomState):  # type: ignore
    """
    Return torch's boolean dtype.

    The `reference` argument only selects this torch implementation; its
    value is not inspected.
    """
    boolean_dtype = torch.bool
    return boolean_dtype

258 

259 

@dispatch
def bool_like(reference: B.TorchNumeric):
    """
    Return the boolean dtype of the torch backend.

    If `reference` is already boolean, its dtype is `torch.bool`; for any
    other dtype the answer is also `torch.bool`. The result is therefore
    always `torch.bool`, whatever the dtype of `reference`.
    """
    return torch.bool