Coverage for geometric_kernels/lab_extras/torch/extras.py: 93%
109 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1import lab as B
2import torch
3from beartype.typing import Any, List, Optional
4from lab import dispatch
5from plum import Union
# Shorthand for "lab `Number` or lab `TorchNumeric`", used in the dispatch
# signatures below for arguments that may be either kind of value.
_Numeric = Union[B.Number, B.TorchNumeric]
@dispatch
def take_along_axis(a: Union[_Numeric, B.Numeric], index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Pick values out of `a` along dimension `axis` at the positions in `index`.

    :param a:
        Source values. Anything that is not already a torch tensor is turned
        into one and moved to the device of `index`.
    :param index:
        Positions to gather; cast to int64 before use.
    :param axis:
        Dimension along which to gather.
    :return:
        A tensor holding the gathered values.
    """
    source = a if torch.is_tensor(a) else torch.tensor(a).to(index.device)  # type: ignore
    # torch.take_along_dim insists on int64 ("long") indices.
    positions = index.long()
    return torch.take_along_dim(source, positions, axis)
@dispatch
def from_numpy(
    a: B.TorchNumeric, b: Union[List, B.Number, B.NPNumeric, B.TorchNumeric]
):  # type: ignore
    """
    Converts `b` to a torch tensor placed on the same device as `a`.

    Tensors are returned unchanged. Lists, numbers and numpy arrays are copied
    into a new tensor.

    :param a:
        Reference tensor; only its device is used.
    :param b:
        Value to convert.
    :return:
        `b` as a torch tensor.
    """
    if torch.is_tensor(b):
        return b
    # Plain Python numbers (int/float/...) have no `.copy()`, so only copy
    # objects that support it (lists, numpy arrays/scalars). The copy is
    # presumably there so that numpy views with negative strides become
    # contiguous before torch ingests them.
    data = b.copy() if hasattr(b, "copy") else b
    return torch.tensor(data).to(a.device)  # type: ignore
@dispatch
def trapz(y: B.TorchNumeric, x: _Numeric, dx: _Numeric = 1.0, axis: int = -1):  # type: ignore
    """
    Approximate the integral of `y` over the sample points `x` along `axis`
    using the trapezoidal rule.

    :param y:
        Values to integrate.
    :param x:
        Sample points corresponding to `y`.
    :param dx:
        Not used by this implementation. `x` is always supplied and the spacing
        is taken from it; the parameter only keeps the signature aligned with
        the numpy-style API.
    :param axis:
        Dimension along which to integrate.
    """
    integration_dim = axis
    return torch.trapz(y, x, dim=integration_dim)
@dispatch
def norm(x: _Numeric, ord: Optional[Any] = None, axis: Optional[int] = None):  # type: ignore
    """
    Compute a vector or matrix norm of `x` via `torch.linalg.norm`.

    :param x:
        Input values.
    :param ord:
        Order of the norm, forwarded unchanged to torch.
    :param axis:
        Dimension(s) to reduce over, forwarded as torch's `dim`.
    """
    reduce_dim = axis
    return torch.linalg.norm(x, ord=ord, dim=reduce_dim)
@dispatch
def logspace(start: B.TorchNumeric, stop: B.TorchNumeric, num: int = 50, base: _Numeric = 10.0):  # type: ignore
    """
    Return `num` numbers spaced evenly on a log scale, from `base ** start`
    to `base ** stop`.

    :param start:
        Single-element tensor holding the exponent of the first value.
    :param stop:
        Single-element tensor holding the exponent of the last value.
    :param num:
        Number of samples to generate.
    :param base:
        Base of the log space; a number or a single-element tensor.
    :return:
        A 1-D tensor of length `num`, created on the device of `start`. The
        dtype is left to torch's default.
    """
    # Place the result next to the inputs instead of on the default device.
    # float() accepts plain numbers as well as single-element tensors.
    return torch.logspace(
        start.item(), stop.item(), num, float(base), device=start.device
    )
@dispatch
def degree(a: B.TorchNumeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The column sums of `a` are placed on the main diagonal of an otherwise
    zero matrix. For a 0/1 adjacency matrix this counts the neighbours of
    each node.

    :param a:
        Adjacency matrix.
    :return:
        Diagonal matrix of column sums.
    """
    column_totals = a.sum(dim=0)
    return torch.diag(column_totals)
@dispatch
def eigenpairs(L: B.TorchNumeric, k: int):
    """
    Return the `k` smallest eigenvalues of the symmetric positive
    semi-definite matrix `L` together with their eigenvectors.

    TODO(AR): Replace with torch.lobpcg after sparse matrices are supported
    by torch.

    :param L:
        Symmetric positive semi-definite matrix.
    :param k:
        Number of eigenpairs to keep.
    :return:
        Tuple `(eigenvalues[:k], eigenvectors[:, :k])`.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(L)
    # eigh returns eigenvalues in ascending order, so the leading entries and
    # columns belong to the smallest eigenvalues.
    return eigenvalues[:k], eigenvectors[:, :k]
@dispatch
def set_value(a: B.TorchNumeric, index: int, value: float):
    """
    Return a copy of `a` in which position `index` holds `value`.

    The input tensor is never modified; the assignment is applied to a clone.
    """
    updated = torch.clone(a)
    updated[index] = value
    return updated
@dispatch
def dtype_double(reference: B.TorchRandomState):  # type: ignore
    """
    Return torch's double-precision dtype; `reference` only selects the
    torch backend during dispatch.
    """
    # torch.double and torch.float64 are the same dtype object.
    return torch.float64
@dispatch
def float_like(reference: B.TorchNumeric):
    """
    Return the dtype of `reference` when it is a floating point tensor,
    and `torch.float64` otherwise.
    """
    return B.dtype(reference) if torch.is_floating_point(reference) else torch.float64
@dispatch
def dtype_integer(reference: B.TorchRandomState):  # type: ignore
    """
    Return torch's default `int` dtype; `reference` only selects the torch
    backend during dispatch.
    """
    # torch.int is an alias of torch.int32.
    return torch.int32
@dispatch
def int_like(reference: B.TorchNumeric):
    """
    Return the dtype of `reference` when it is one of the signed integer
    dtypes int8, int16, int32 or int64, and `torch.int32` otherwise.

    Note that unsigned dtypes such as `torch.uint8` also fall back to
    `torch.int32`.
    """
    signed_integer_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)
    current = reference.dtype
    return current if current in signed_integer_dtypes else torch.int32
@dispatch
def get_random_state(key: B.TorchRandomState):
    """
    Return a snapshot of the internal state of the random generator `key`.

    :param key:
        The random generator of type `B.TorchRandomState`.
    :return:
        Whatever `key.get_state()` returns.
    """
    snapshot = key.get_state()
    return snapshot
@dispatch
def restore_random_state(key: B.TorchRandomState, state):
    """
    Return a new random generator whose state is set to `state`.

    The new generator lives on the same device as `key`, so a state captured
    from a non-CPU generator can be restored. `key` itself is not modified.

    :param key:
        The random generator of type `B.TorchRandomState`; its device is used.
    :param state:
        The new random state, e.g. as returned by `get_random_state`.
    :return:
        A fresh `torch.Generator` with state `state`.
    """
    # A bare torch.Generator() is always a CPU generator, which cannot take
    # the state of a generator on another device.
    gen = torch.Generator(device=key.device)
    gen.set_state(state)
    return gen
@dispatch
def create_complex(real: _Numeric, imag: B.TorchNumeric):
    """
    Combine real and imaginary parts into a complex value using pytorch.

    :param real:
        Real part of the complex number.
    :param imag:
        Imaginary part of the complex number (a torch tensor).
    :return:
        `real + i * imag`.
    """
    return real + 1j * imag
@dispatch
def complex_like(reference: B.TorchNumeric):
    """
    Return the complex dtype obtained by promoting `torch.complex64` with the
    dtype of `reference`.
    """
    # torch.cfloat is the same dtype as torch.complex64.
    reference_dtype = reference.dtype
    return B.promote_dtypes(torch.complex64, reference_dtype)
@dispatch
def is_complex(reference: B.TorchNumeric):
    """
    Return True if `reference` has a complex dtype.

    Every torch complex dtype counts (complex32, complex64 and complex128),
    not only `cfloat`/`cdouble`.
    """
    return torch.is_complex(reference)
@dispatch
def cumsum(x: B.TorchNumeric, axis=None):
    """
    Return the cumulative sum of `x` along `axis`.

    When `axis` is None, `x` is flattened first and the cumulative sum of the
    flattened values is returned as a 1-D tensor, as `numpy.cumsum` does.

    :param x:
        Input tensor.
    :param axis:
        Dimension to accumulate along, or None to accumulate over all
        elements.
    """
    if axis is None:
        # torch.cumsum requires an integer dim; emulate numpy's flattening.
        return torch.cumsum(x.reshape(-1), dim=0)
    return torch.cumsum(x, dim=axis)
@dispatch
def qr(x: B.TorchNumeric, mode="reduced"):
    """
    Compute the QR decomposition of the matrix `x`.

    :param mode:
        Forwarded unchanged to `torch.linalg.qr`.
    :return:
        Plain tuple `(Q, R)`.
    """
    q_factor, r_factor = torch.linalg.qr(x, mode=mode)
    return q_factor, r_factor
@dispatch
def slogdet(x: B.TorchNumeric):
    """
    Compute the sign and the log of the absolute determinant of `x`.

    :return:
        Plain tuple `(sign, logdet)`.
    """
    # torch.slogdet is an alias of torch.linalg.slogdet.
    det_sign, log_abs_det = torch.linalg.slogdet(x)
    return det_sign, log_abs_det
@dispatch
def eigvalsh(x: B.TorchNumeric):
    """
    Return only the eigenvalues of the Hermitian (or real symmetric)
    matrix `x`, as computed by `torch.linalg.eigvalsh`.
    """
    eigenvalues = torch.linalg.eigvalsh(x)
    return eigenvalues
@dispatch
def reciprocal_no_nan(x: B.TorchNumeric):
    """
    Compute 1/x element-wise, yielding 0 wherever x is 0 instead of inf.
    """
    is_zero = x == 0.0
    # Substitute 1 at the zeros so the reciprocal itself never divides by 0.
    denominator = torch.where(is_zero, 1.0, x)
    return torch.where(is_zero, 0.0, torch.reciprocal(denominator))
@dispatch
def complex_conj(x: B.TorchNumeric):
    """
    Return the complex conjugate of `x` via `torch.conj`.
    """
    conjugated = torch.conj(x)
    return conjugated
@dispatch
def logical_xor(x1: B.TorchNumeric, x2: B.TorchNumeric):
    """
    Compute the element-wise logical exclusive-or of `x1` and `x2`.
    """
    result = torch.logical_xor(x1, x2)
    return result
@dispatch
def count_nonzero(x: B.TorchNumeric, axis=None):
    """
    Count the non-zero entries of `x`, over all elements when `axis` is
    None or along the given dimension(s) otherwise.
    """
    reduce_dim = axis
    return torch.count_nonzero(x, dim=reduce_dim)
@dispatch
def dtype_bool(reference: B.TorchRandomState):  # type: ignore
    """
    Return torch's boolean dtype; `reference` only selects the torch backend
    during dispatch.
    """
    bool_dtype = torch.bool
    return bool_dtype
@dispatch
def bool_like(reference: B.TorchNumeric):
    """
    Return the boolean dtype of the torch backend.

    If `reference` is already boolean its dtype is `torch.bool`; otherwise
    `torch.bool` is the backend's boolean dtype. Either way the answer is the
    same, so no inspection of `reference` is needed. The argument is kept for
    dispatch.
    """
    return torch.bool