Coverage for geometric_kernels / lab_extras / tensorflow / extras.py: 91%
106 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 15:49 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 15:49 +0000
1from typing import TypeAlias
3import lab as B
4import tensorflow as tf
5import tensorflow_probability as tfp
6from beartype.typing import Any
7from lab import dispatch
# Union of lab's `B.Number`, `B.TFNumeric` and `B.NPNumeric` types; used as
# the argument annotation for the helpers below that accept any of them.
_Numeric: TypeAlias = B.Number | B.TFNumeric | B.NPNumeric
@dispatch
def take_along_axis(a: _Numeric, index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Pick values from `a` at the positions listed in `index` along `axis`.

    Delegates to `tf.experimental.numpy.take_along_axis`.

    :param a: Source array.
    :param index: Indices to gather along `axis`.
    :param axis: Axis along which to gather. Defaults to 0.
    :return: The gathered values.
    """
    tnp = tf.experimental.numpy
    gathered = tnp.take_along_axis(a, index, axis=axis)
    return gathered
@dispatch
def from_numpy(_: B.TFNumeric, b: list | B.Numeric | B.NPNumeric | B.TFNumeric):  # type: ignore
    """
    Convert `b` into a TensorFlow tensor.

    The first argument only selects this TensorFlow implementation through
    dispatch; its value is never read.

    :param _: Any TensorFlow tensor, used as the backend reference.
    :param b: A list, number, NumPy array or TensorFlow tensor.
    :return: The result of `tf.convert_to_tensor(b)`.
    """
    converted = tf.convert_to_tensor(b)
    return converted
@dispatch
def trapz(y: _Numeric, x: _Numeric, dx=None, axis=-1):  # type: ignore
    """
    Integrate `y` over `x` along `axis` with the composite trapezoidal rule.

    :param y: Values of the integrand.
    :param x: Sample points corresponding to `y`.
    :param dx: Sample spacing; forwarded unchanged to `tfp.math.trapz`.
    :param axis: Axis of integration. Defaults to -1 (the last axis).
    :return: The approximate integral computed by `tfp.math.trapz`.
    """
    return tfp.math.trapz(y, x=x, dx=dx, axis=axis)
@dispatch
def norm(x: _Numeric, ord: Any | None = None, axis: int | None = None):  # type: ignore
    """
    Matrix or vector norm.

    `ord=None` follows the `numpy.linalg.norm` convention: the 2-norm of
    vectors along `axis`, or of the whole flattened tensor when `axis` is
    None (the Frobenius norm for matrices).

    :param x: Input tensor.
    :param ord: Order of the norm, as accepted by `tf.norm`; None selects
        the Euclidean norm.
    :param axis: Axis along which to compute the norm; None uses all entries.
    :return: The norm as computed by `tf.norm`.
    """
    # `tf.norm` has no `None` option for `ord` (its own default is
    # "euclidean") and raises when given one, so translate numpy's default.
    if ord is None:
        ord = "euclidean"
    return tf.norm(x, ord=ord, axis=axis)
@dispatch
def logspace(start: _Numeric, stop: _Numeric, num: int = 50, base: _Numeric = 10.0):  # type: ignore
    """
    Return numbers spaced evenly on a log scale.

    Computes `base ** linspace(start, stop, num)`, like `numpy.logspace`.

    :param start: Exponent of the first value, `base ** start`.
    :param stop: Exponent of the last value, `base ** stop`.
    :param num: Number of samples to generate. Defaults to 50.
    :param base: Base of the log space. Defaults to 10.0, the same default
        as `numpy.logspace`, so backend-generic callers get consistent values.
    :return: A tensor holding `num` values.
    """
    exponents = tf.linspace(start, stop, num)
    return tf.math.pow(base, exponents)
@dispatch
def degree(a: B.TFNumeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The result is a diagonal matrix whose main diagonal holds the column
    sums of `a`. For an adjacency matrix this counts, for each node, the
    nodes it is connected to.
    """
    column_sums = tf.reduce_sum(a, axis=0)  # type: ignore
    degree_matrix = tf.linalg.diag(column_sums)
    return degree_matrix
@dispatch
def eigenpairs(L: B.TFNumeric, k: int):
    """
    Return the eigenpairs for the `k` smallest eigenvalues of the symmetric
    positive semi-definite matrix `L`.

    :return: A tuple `(eigenvalues, eigenvectors)` holding the first `k`
        eigenvalues and, as columns, their matching eigenvectors.
    """
    # `tf.linalg.eigh` yields eigenvalues in non-decreasing order, so
    # slicing the first `k` keeps the smallest ones.
    eigenvalues, eigenvectors = tf.linalg.eigh(L)
    smallest_values = eigenvalues[:k]
    smallest_vectors = eigenvectors[:, :k]
    return smallest_values, smallest_vectors
@dispatch
def set_value(a: B.TFNumeric, index: int, value: float):
    """
    Return a copy of `a` with the entry at `index` replaced by `value`.

    The update is not done in place: a new tensor is built with `tf.where`
    and `a` is left untouched. Assumes `a` is one-dimensional, since the
    mask below has length `len(a)` — TODO confirm for callers.
    """
    is_target = tf.range(len(a)) == index
    return tf.where(is_target, value, a)
@dispatch
def dtype_double(reference: B.TFRandomState):  # type: ignore
    """
    Double-precision floating point dtype of the TensorFlow backend.

    `reference` only selects this backend via dispatch; it is not inspected.
    """
    double_dtype = tf.float64
    return double_dtype
@dispatch
def float_like(reference: B.TFNumeric):
    """
    Floating point dtype matching `reference`.

    :return: `reference.dtype` when it is already a floating point dtype,
        otherwise `tf.float64`.
    """
    dtype = reference.dtype
    return dtype if dtype.is_floating else tf.float64
@dispatch
def dtype_integer(reference: B.TFRandomState):  # type: ignore
    """
    Integer dtype of the TensorFlow backend.

    `reference` only selects this backend via dispatch; it is not inspected.
    """
    integer_dtype = tf.int32
    return integer_dtype
@dispatch
def int_like(reference: B.TFNumeric):
    """
    Integer dtype matching `reference`.

    :return: `reference.dtype` when it is already an integer dtype,
        otherwise `tf.int32`.
    """
    dtype = reference.dtype
    return dtype if dtype.is_integer else tf.int32
@dispatch
def get_random_state(key: B.TFRandomState):
    """
    Capture the state of a TensorFlow random generator.

    :param key: The random generator of type `B.TFRandomState`.
    :return: A tuple `(state, algorithm)`. `state` is `key.state` passed
        through `tf.identity`, presumably to take a snapshot rather than keep
        a reference to the generator's state variable.
    """
    snapshot = tf.identity(key.state)
    return snapshot, key.algorithm
@dispatch
def restore_random_state(key: B.TFRandomState, state):
    """
    Build a new random generator whose state is `state`.

    `key` is not read or modified; it only selects this TensorFlow
    implementation via dispatch.

    :param key: A random generator of type `B.TFRandomState`.
    :param state: A `(state, algorithm)` pair, as returned by
        `get_random_state`.
    :return: A new `tf.random.Generator` created from that state.
    """
    state_value, algorithm = state[0], state[1]
    return tf.random.Generator.from_state(state=tf.identity(state_value), alg=algorithm)
@dispatch
def create_complex(real: _Numeric, imag: B.TFNumeric):
    """
    Combine `real` and `imag` into a complex TensorFlow tensor.

    `real` is converted to a tensor and cast to the dtype of `imag` before
    both are passed to `tf.complex`.

    :param real: Real part (number, array or tensor).
    :param imag: Imaginary part, a TensorFlow tensor.
    :return: The complex tensor built by `tf.complex`.
    """
    real_part = B.cast(B.dtype(imag), from_numpy(imag, real))
    return tf.complex(real_part, imag)
@dispatch
def complex_like(reference: B.TFNumeric):
    """
    Complex dtype suited to `reference`.

    :return: `tf.complex64` promoted together with `reference.dtype`
        via `B.promote_dtypes`.
    """
    reference_dtype = reference.dtype
    return B.promote_dtypes(tf.complex64, reference_dtype)
@dispatch
def is_complex(reference: B.TFNumeric):
    """
    Check whether `reference` has a complex dtype.

    :return: True if its dtype is `tf.complex64` or `tf.complex128`.
    """
    dtype = B.dtype(reference)
    return dtype == tf.complex64 or dtype == tf.complex128
@dispatch
def cumsum(x: B.TFNumeric, axis=None):
    """
    Return the cumulative sum of `x`, optionally along `axis`.

    Follows `numpy.cumsum`: with `axis=None` the input is flattened (row-major)
    and a 1-D cumulative sum is returned.

    :param x: Input tensor.
    :param axis: Axis along which to accumulate; None flattens first.
    :return: The cumulative sums.
    """
    if axis is None:
        # `tf.math.cumsum` cannot take `axis=None` (it must become a tensor),
        # so emulate numpy by flattening and summing along the only axis.
        return tf.math.cumsum(tf.reshape(x, [-1]), axis=0)
    return tf.math.cumsum(x, axis=axis)
@dispatch
def qr(x: B.TFNumeric, mode="reduced"):
    """
    QR decomposition of the matrix `x`.

    :param mode: "complete" requests full-size factors
        (`full_matrices=True`); any other value gives the reduced form.
    :return: The pair `(Q, R)` as a plain tuple.
    """
    want_full = mode == "complete"
    q_factor, r_factor = tf.linalg.qr(x, full_matrices=want_full)
    return q_factor, r_factor
@dispatch
def slogdet(x: B.TFNumeric):
    """
    Sign and log of the absolute determinant of the matrix `x`.

    :return: The pair `(sign, logdet)` as a plain tuple.
    """
    determinant_sign, log_abs_det = tf.linalg.slogdet(x)
    return determinant_sign, log_abs_det
@dispatch
def eigvalsh(x: B.TFNumeric):
    """
    Eigenvalues of the Hermitian (or real symmetric) matrix `x`.
    """
    eigenvalues = tf.linalg.eigvalsh(x)
    return eigenvalues
@dispatch
def reciprocal_no_nan(x: B.TFNumeric):
    """
    Element-wise `1 / x`, with 0 returned wherever `x` is 0.
    """
    reciprocal = tf.math.reciprocal_no_nan(x)
    return reciprocal
@dispatch
def complex_conj(x: B.TFNumeric):
    """
    Element-wise complex conjugate of `x`.
    """
    conjugate = tf.math.conj(x)
    return conjugate
@dispatch
def logical_xor(x1: B.TFNumeric, x2: B.TFNumeric):
    """
    Element-wise logical exclusive OR of `x1` and `x2`.
    """
    exclusive_or = tf.math.logical_xor(x1, x2)
    return exclusive_or
@dispatch
def count_nonzero(x: B.TFNumeric, axis=None):
    """
    Number of non-zero entries of `x`, optionally counted along `axis`.

    The result dtype is whatever `tf.math.count_nonzero` returns by default.
    """
    nonzero_count = tf.math.count_nonzero(x, axis=axis)
    return nonzero_count
@dispatch
def dtype_bool(reference: B.TFRandomState):  # type: ignore
    """
    Boolean dtype of the TensorFlow backend.

    `reference` only selects this backend via dispatch; it is not inspected.
    """
    boolean_dtype = tf.bool
    return boolean_dtype
@dispatch
def bool_like(reference: B.TFNumeric):
    """
    Boolean dtype matching `reference`.

    Registered for TensorFlow tensors, like `float_like` and `int_like` in this
    module. It previously used a NumPy annotation, which hooked TensorFlow
    logic onto NumPy arrays, whose dtypes have no `is_bool` attribute.

    :return: `reference.dtype` when it is already boolean, otherwise `tf.bool`.
    """
    reference_dtype = reference.dtype
    if reference_dtype.is_bool:
        return reference_dtype
    return tf.bool