Coverage for geometric_kernels / lab_extras / torch / extras.py: 92%
109 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 15:49 +0000
1from typing import TypeAlias
3import lab as B
4import torch
5from beartype.typing import Any
6from lab import dispatch
# Union of lab's scalar type (`B.Number`) and torch tensors (`B.TorchNumeric`).
# Used in the dispatch annotations below for arguments that may be either.
_Numeric: TypeAlias = B.Number | B.TorchNumeric
@dispatch
def take_along_axis(a: _Numeric | B.Numeric, index: B.TorchNumeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Select values from `a` along dimension `axis` at the positions in `index`.

    A non-tensor `a` is first turned into a tensor on the device of `index`.

    :param a: source values (a tensor, or anything `torch.tensor` accepts).
    :param index: tensor of positions; cast to int64 because
        `torch.take_along_dim` only accepts `long` indices.
    :param axis: dimension along which values are gathered.
    :return: tensor of gathered values.
    """
    source = a if torch.is_tensor(a) else torch.tensor(a).to(index.device)  # type: ignore
    positions = index.long()
    return torch.take_along_dim(source, positions, axis)
@dispatch
def from_numpy(
    a: B.TorchNumeric, b: list | B.Number | B.NPNumeric | B.TorchNumeric
):  # type: ignore
    """
    Convert `b` to a torch tensor on the same device as `a`.

    Tensors are returned unchanged. Other inputs are copied first, when they
    provide a `.copy()` method (lists and NumPy arrays do), and then passed
    to `torch.tensor`.

    :param a: reference tensor; only its device is used.
    :param b: list, scalar, NumPy array or tensor to convert.
    :return: `b` as a torch tensor.
    """
    if torch.is_tensor(b):
        return b
    # Plain Python scalars (e.g. `int`, `float`) have no `.copy()`. Calling it
    # unconditionally raised AttributeError for them.
    # NOTE(review): the copy presumably normalises NumPy arrays (e.g. negative
    # strides) before `torch.tensor`; it is kept for every input that has it.
    if hasattr(b, "copy"):
        b = b.copy()
    return torch.tensor(b).to(a.device)  # type: ignore
@dispatch
def trapz(y: B.TorchNumeric, x: _Numeric, dx: _Numeric = 1.0, axis: int = -1):  # type: ignore
    """
    Approximate the integral of `y` over the sample points `x` with the
    trapezoidal rule, along dimension `axis`.

    `dx` is accepted but not used: the spacing always comes from `x`.
    """
    # `torch.trapz` is documented as an alias of `torch.trapezoid`.
    return torch.trapezoid(y, x, dim=axis)
@dispatch
def norm(x: _Numeric, ord: Any | None = None, axis: int | None = None):  # type: ignore
    """
    Compute a matrix or vector norm of `x` with `torch.linalg.norm`.

    :param x: input tensor.
    :param ord: order of the norm, forwarded unchanged.
    :param axis: dimension to reduce over (torch's `dim`); `None` keeps
        `torch.linalg.norm`'s default behaviour.
    """
    return torch.linalg.norm(x, dim=axis, ord=ord)
@dispatch
def logspace(start: B.TorchNumeric, stop: B.TorchNumeric, num: int = 50, base: _Numeric = 10.0):  # type: ignore
    """
    Return `num` numbers spaced evenly on a log scale, from `base**start`
    to `base**stop`.

    :param start: scalar tensor holding the exponent of the first value.
    :param stop: scalar tensor holding the exponent of the last value.
    :param num: number of samples.
    :param base: base of the log space; a tensor is converted to a Python
        scalar first.
    :return: tensor created on the same device as `start`.
    """
    # `torch.logspace` is called with Python numbers for endpoints and base.
    if torch.is_tensor(base):
        base = base.item()
    # Create the result next to the inputs rather than on the default device,
    # matching the device handling of the other helpers in this module.
    return torch.logspace(start.item(), stop.item(), num, base, device=start.device)
@dispatch
def degree(a: B.TorchNumeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The column sums of `a` are placed on the main diagonal of an otherwise
    zero matrix. For a 0/1 adjacency matrix each sum is the number of nodes
    the corresponding node is connected to; for a weighted one it is the
    weighted degree.
    """
    return torch.diag(torch.sum(a, dim=0))
@dispatch
def eigenpairs(L: B.TorchNumeric, k: int):
    """
    Return the `k` smallest eigenvalues of the symmetric positive
    semi-definite matrix `L`, together with their eigenvectors.

    `torch.linalg.eigh` returns eigenvalues in ascending order, so the first
    `k` of them (and the matching eigenvector columns) are the lowest ones.

    TODO(AR): Replace with torch.lobpcg after sparse matrices are supported
    by torch.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(L)
    lowest = slice(None, k)
    return eigenvalues[lowest], eigenvectors[:, lowest]
@dispatch
def set_value(a: B.TorchNumeric, index: int, value: float):
    """
    Return a copy of `a` whose entry at `index` is replaced by `value`.

    `a` itself is not modified: the write goes to a clone.
    """
    updated = torch.clone(a)
    updated[index] = value
    return updated
@dispatch
def dtype_double(reference: B.TorchRandomState):  # type: ignore
    """
    Torch's double-precision floating point dtype.

    `reference` only selects this torch implementation through dispatch; its
    value is not inspected.
    """
    # `torch.float64` and `torch.double` are the same dtype object.
    return torch.float64
@dispatch
def float_like(reference: B.TorchNumeric):
    """
    Choose a floating point dtype for `reference`.

    :return: the dtype of `reference` when it is already floating point,
        otherwise `torch.float64`.
    """
    return B.dtype(reference) if torch.is_floating_point(reference) else torch.float64
@dispatch
def dtype_integer(reference: B.TorchRandomState):  # type: ignore
    """
    Torch's default 32-bit integer dtype.

    `reference` only selects this torch implementation through dispatch; its
    value is not inspected.
    """
    # `torch.int32` and `torch.int` are the same dtype object.
    return torch.int32
@dispatch
def int_like(reference: B.TorchNumeric):
    """
    Choose an integer dtype for `reference`.

    :return: the dtype of `reference` when it is one of the signed integer
        dtypes `int8`, `int16`, `int32` or `int64`; otherwise (including
        `uint8`, floating point, complex and boolean tensors) `torch.int32`.
    """
    signed_integer_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)
    dtype = reference.dtype
    if dtype not in signed_integer_dtypes:
        return torch.int32
    return dtype
@dispatch
def get_random_state(key: B.TorchRandomState):
    """
    Snapshot the internal state of the random generator `key`.

    :param key: the random generator of type `B.TorchRandomState`.
    :return: the state object produced by `key.get_state()`.
    """
    snapshot = key.get_state()
    return snapshot
@dispatch
def restore_random_state(key: B.TorchRandomState, state):
    """
    Create a new random generator whose state is `state`.

    `key` itself is not modified. The new generator lives on the same device
    as `key`.

    :param key:
        The random generator of type `B.TorchRandomState`.
    :param state:
        The new random state of the random generator.
    :return: a new generator with state `state`.
    """
    # A bare `torch.Generator()` is always a CPU generator. A state captured
    # from a generator on another device (e.g. CUDA) must be restored into a
    # generator on that same device, otherwise `set_state` rejects it.
    gen = torch.Generator(device=key.device)
    gen.set_state(state)
    return gen
@dispatch
def create_complex(real: _Numeric, imag: B.TorchNumeric):
    """
    Combine a real and an imaginary part into a complex value using torch.

    :param real:
        real part (a scalar or a tensor).
    :param imag:
        imaginary part (a tensor).
    :return: `real + 1j * imag`, using torch broadcasting and type promotion.
    """
    imaginary_part = imag * 1j
    return real + imaginary_part
@dispatch
def complex_like(reference: B.TorchNumeric):
    """
    Complex dtype compatible with `reference`.

    The result is `torch.cfloat` promoted together with the dtype of
    `reference` via `B.promote_dtypes`.
    """
    reference_dtype = reference.dtype
    return B.promote_dtypes(torch.cfloat, reference_dtype)
@dispatch
def is_complex(reference: B.TorchNumeric):
    """
    Return True if `reference` has a complex dtype.

    Delegates to `torch.is_complex`, which covers every torch complex dtype.
    The previous comparison against only `cfloat`/`cdouble` missed
    `complex32` (`chalf`).
    """
    return torch.is_complex(reference)
@dispatch
def cumsum(x: B.TorchNumeric, axis=None):
    """
    Return the cumulative sum of `x`, optionally along `axis`.

    :param x: input tensor.
    :param axis: dimension to accumulate along. With `None` (the default),
        `x` is flattened first and a 1-D cumulative sum is returned.
        `torch.cumsum` itself requires a dimension and rejects `dim=None`.
    :return: tensor of cumulative sums.
    """
    if axis is None:
        return torch.cumsum(torch.flatten(x), dim=0)
    return torch.cumsum(x, dim=axis)
@dispatch
def qr(x: B.TorchNumeric, mode="reduced"):
    """
    QR decomposition of the matrix `x` via `torch.linalg.qr`.

    :param mode: forwarded unchanged to `torch.linalg.qr`.
    :return: the tuple `(Q, R)`.
    """
    decomposition = torch.linalg.qr(x, mode=mode)
    return decomposition.Q, decomposition.R
@dispatch
def slogdet(x: B.TorchNumeric):
    """
    Sign and log of the absolute value of the determinant of the matrix `x`.

    :return: the tuple `(sign, logabsdet)`.
    """
    # `torch.slogdet` is documented as an alias of `torch.linalg.slogdet`.
    sign_and_logabsdet = torch.linalg.slogdet(x)
    return tuple(sign_and_logabsdet)
@dispatch
def eigvalsh(x: B.TorchNumeric):
    """
    Eigenvalues of the Hermitian (or real symmetric) matrix `x`, as
    returned by `torch.linalg.eigvalsh` (ascending order).
    """
    eigenvalues = torch.linalg.eigvalsh(x)
    return eigenvalues
@dispatch
def reciprocal_no_nan(x: B.TorchNumeric):
    """
    Element-wise reciprocal `1/x` where zero entries map to 0 instead of inf.
    """
    is_zero = x == 0.0
    # Replace zeros before dividing so `torch.reciprocal` never produces inf.
    # The masked branch of `torch.where` would otherwise still carry inf,
    # which can leak into gradients.
    denominator = torch.where(is_zero, 1.0, x)
    return torch.where(is_zero, 0.0, torch.reciprocal(denominator))
@dispatch
def complex_conj(x: B.TorchNumeric):
    """
    Complex conjugate of `x`.

    Same as `torch.conj`: for complex inputs the result may be a lazily
    conjugated view that shares memory with `x`.
    """
    return x.conj()
@dispatch
def logical_xor(x1: B.TorchNumeric, x2: B.TorchNumeric):
    """
    Element-wise logical exclusive-or of `x1` and `x2`.
    """
    return x1.logical_xor(x2)
@dispatch
def count_nonzero(x: B.TorchNumeric, axis=None):
    """
    Number of non-zero entries of `x`, over all entries or along `axis`.
    """
    return x.count_nonzero(dim=axis)
@dispatch
def dtype_bool(reference: B.TorchRandomState):  # type: ignore
    """
    Torch's boolean dtype.

    `reference` only selects this torch implementation through dispatch; its
    value is not inspected.
    """
    boolean_dtype = torch.bool
    return boolean_dtype
@dispatch
def bool_like(reference: B.TorchNumeric):
    """
    Boolean dtype for the backend of `reference`.

    Torch has a single boolean dtype, so the answer is always `torch.bool`,
    whether or not `reference` is already boolean. The former if/else
    returned `torch.bool` from both branches and has been removed.
    """
    return torch.bool