Coverage for geometric_kernels/lab_extras/tensorflow/extras.py: 91%
109 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1import sys
3import lab as B
4import tensorflow as tf
5import tensorflow_probability as tfp
6from beartype.typing import Any, List, Optional
7from lab import dispatch
8from plum import Union
# Dispatch type used in the signatures below: accepts any of `B.Number`,
# `B.TFNumeric` or `B.NPNumeric`.
_Numeric = Union[B.Number, B.TFNumeric, B.NPNumeric]
@dispatch
def take_along_axis(a: _Numeric, index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Gathers elements of `a` along `axis` at `index` locations.

    Thin wrapper around `tf.experimental.numpy.take_along_axis`.

    :param a:
        Array to gather from.
    :param index:
        Integer indices to gather, in the layout expected by
        `take_along_axis`.
    :param axis:
        Axis along which to gather. Defaults to 0.

    :return:
        The gathered values.
    """
    if sys.version_info[:2] <= (3, 9):
        # On Python 3.9 and below, calling `take_along_axis` without an
        # explicit cast of `index` to int32 causes an error.
        index = tf.cast(index, tf.int32)
    return tf.experimental.numpy.take_along_axis(a, index, axis=axis)
@dispatch
def from_numpy(_: B.TFNumeric, b: Union[List, B.Numeric, B.NPNumeric, B.TFNumeric]):  # type: ignore
    """
    Converts `b` to a TensorFlow tensor, i.e. to the same backend as the
    first argument.

    :param _:
        A TensorFlow tensor. Only its type matters (it selects this
        backend); its value is not used.
    :param b:
        A list, number, NumPy array or TensorFlow tensor to convert.

    :return:
        `tf.convert_to_tensor(b)`.
    """
    return tf.convert_to_tensor(b)
@dispatch
def trapz(y: _Numeric, x: _Numeric, dx=None, axis=-1):  # type: ignore
    """
    Integrate `y` along `axis` using the composite trapezoidal rule.

    All arguments are forwarded unchanged to `tfp.math.trapz`: `x` holds the
    sample points and `dx` the sample spacing.
    """
    return tfp.math.trapz(y=y, x=x, dx=dx, axis=axis)
@dispatch
def norm(x: _Numeric, ord: Optional[Any] = None, axis: Optional[int] = None):  # type: ignore
    """
    Matrix or vector norm.

    :param x:
        Input array.
    :param ord:
        Order of the norm, forwarded to `tf.norm`. `None` (the default)
        selects the Euclidean norm over `axis`. When `axis` is `None` as
        well, this is the norm over all entries, i.e. the Frobenius norm of
        a matrix, mirroring NumPy's `ord=None`.
    :param axis:
        Axis to reduce over; `None` means all entries.

    :return:
        The norm of `x`.
    """
    if ord is None:
        # `tf.norm` does not accept `ord=None` (its own default is
        # "euclidean"), so translate the NumPy-style default explicitly.
        ord = "euclidean"
    return tf.norm(x, ord=ord, axis=axis)
@dispatch
def logspace(start: _Numeric, stop: _Numeric, num: int = 50, base: _Numeric = 50.0):  # type: ignore
    """
    Return `num` numbers spaced evenly on a log scale.

    Computes `base ** t` for `t` taken from `tf.linspace(start, stop, num)`.

    NOTE(review): the default `base` is 50.0, whereas `numpy.logspace`
    defaults to 10.0. Confirm this is intentional; other backends may rely
    on the same default.
    """
    return tf.math.pow(base, tf.linspace(start, stop, num))
@dispatch
def degree(a: B.TFNumeric):  # type: ignore
    """
    Build the degree matrix of an adjacency matrix `a`.

    The result is a diagonal matrix whose main diagonal holds the column
    sums of `a`. For a 0/1 adjacency matrix, these are the numbers of nodes
    each node is connected to.
    """
    return tf.linalg.diag(tf.reduce_sum(a, axis=0))  # type: ignore
@dispatch
def eigenpairs(L: B.TFNumeric, k: int):
    """
    Obtain the eigenpairs that correspond to the `k` lowest eigenvalues
    of a symmetric positive semi-definite matrix `L`.

    Also accepts a batch of matrices (shape `[..., n, n]`), in which case
    the `k` lowest eigenpairs of each matrix are returned.

    :param L:
        Symmetric matrix (or batch of matrices).
    :param k:
        Number of eigenpairs to keep.

    :return:
        A tuple `(eigenvalues, eigenvectors)`:
        - the eigenvalues in ascending order, with shape `[..., k]`;
        - the matching eigenvectors as columns, with shape `[..., n, k]`.
    """
    eigenvalues, eigenvectors = tf.linalg.eigh(L)
    # `tf.linalg.eigh` returns eigenvalues in ascending order, so slicing
    # the last axis keeps the `k` lowest. Slicing with `...` is identical to
    # `l[:k]` / `u[:, :k]` for a single matrix, and also correct for
    # batched input, where those would have sliced the batch axis.
    return eigenvalues[..., :k], eigenvectors[..., :k]
@dispatch
def set_value(a: B.TFNumeric, index: int, value: float):
    """
    Set a[index] = value.
    This operation is not done in place and a new array is returned.

    :param a:
        Tensor to update. Presumably one-dimensional: the position mask is
        built over its first axis only.
    :param index:
        Position to overwrite. It is compared with `==` against
        `0, ..., len(a) - 1`, so an out-of-range (including negative) index
        leaves `a` unchanged.
    :param value:
        New value, converted to the dtype of `a`.

    :return:
        A new tensor equal to `a` except at position `index`.
    """
    # Use the dynamic length: `len(a)` fails inside `tf.function` when the
    # leading dimension is not statically known.
    positions = tf.range(tf.shape(a)[0])
    # Convert `value` to `a`'s dtype explicitly. Eager `tf.where` does this
    # implicitly, but in graph mode a Python float becomes float32 and
    # clashes with, e.g., a float64 `a`.
    new_value = tf.convert_to_tensor(value, dtype=a.dtype)
    return tf.where(positions == index, new_value, a)
@dispatch
def dtype_double(reference: B.TFRandomState):  # type: ignore
    """
    Double-precision floating point dtype of the TensorFlow backend.

    `reference` only selects this backend; its value is not inspected.
    """
    double_dtype = tf.float64
    return double_dtype
@dispatch
def float_like(reference: B.TFNumeric):
    """
    Floating point dtype matching `reference`.

    Returns `reference.dtype` when it is already a floating point type, and
    the backend's double dtype (`tf.float64`) otherwise.
    """
    dtype = reference.dtype
    return dtype if dtype.is_floating else tf.float64
@dispatch
def dtype_integer(reference: B.TFRandomState):  # type: ignore
    """
    Integer dtype of the TensorFlow backend (`tf.int32`).

    `reference` only selects this backend; its value is not inspected.
    """
    integer_dtype = tf.int32
    return integer_dtype
@dispatch
def int_like(reference: B.TFNumeric):
    """
    Integer dtype matching `reference`.

    Returns `reference.dtype` when it is already an integer type, and
    `tf.int32` otherwise.
    """
    dtype = reference.dtype
    return dtype if dtype.is_integer else tf.int32
@dispatch
def get_random_state(key: B.TFRandomState):
    """
    Return the random state of a random generator.

    :param key:
        The random generator of type `B.TFRandomState`.

    :return:
        A tuple `(state, algorithm)`: `tf.identity(key.state)` and
        `key.algorithm`.
    """
    state_snapshot = tf.identity(key.state)
    algorithm = key.algorithm
    return state_snapshot, algorithm
@dispatch
def restore_random_state(key: B.TFRandomState, state):
    """
    Create a random generator with the given state.

    :param key:
        The random generator. It selects this backend; its own state is not
        read.
    :param state:
        A pair whose first element is the state tensor and whose second
        element is the generator algorithm, e.g. the tuple returned by
        `get_random_state`.

    :return:
        A new generator built with `tf.random.Generator.from_state`.
    """
    state_tensor = tf.identity(state[0])
    algorithm = state[1]
    return tf.random.Generator.from_state(state=state_tensor, alg=algorithm)
@dispatch
def create_complex(real: _Numeric, imag: B.TFNumeric):
    """
    Return a complex number with the given real and imaginary parts using tensorflow.

    :param real:
        float, real part of the complex number. It is converted to a
        TensorFlow tensor and cast to the dtype of `imag`.
    :param imag:
        float, imaginary part of the complex number.
    """
    real_part = from_numpy(imag, real)
    real_part = B.cast(B.dtype(imag), real_part)
    return tf.complex(real_part, imag)
@dispatch
def complex_like(reference: B.TFNumeric):
    """
    Complex dtype for `reference`: `tf.complex64` promoted together with
    `reference.dtype` via `B.promote_dtypes`.
    """
    reference_dtype = reference.dtype
    return B.promote_dtypes(tf.complex64, reference_dtype)
@dispatch
def is_complex(reference: B.TFNumeric):
    """
    Whether `reference` has a complex dtype (`tf.complex64` or
    `tf.complex128`).
    """
    return B.dtype(reference) in (tf.complex64, tf.complex128)
@dispatch
def cumsum(x: B.TFNumeric, axis=None):
    """
    Return cumulative sum (optionally along axis).

    :param x:
        Input tensor.
    :param axis:
        Axis along which to accumulate. If `None` (the default), `x` is
        flattened first and the cumulative sum of the flat tensor is
        returned, mirroring `numpy.cumsum`.

    :return:
        Tensor of cumulative sums.
    """
    if axis is None:
        # `tf.math.cumsum` cannot take `axis=None` (it would try to convert
        # `None` to a tensor), so implement NumPy's flattening semantics.
        return tf.math.cumsum(tf.reshape(x, [-1]), axis=0)
    return tf.math.cumsum(x, axis=axis)
@dispatch
def qr(x: B.TFNumeric, mode="reduced"):
    """
    Return a QR decomposition of a matrix x.

    :param mode:
        "complete" requests full-size factors (`full_matrices=True`); any
        other value yields the reduced decomposition.

    :return:
        A plain tuple `(Q, R)`.
    """
    decomposition = tf.linalg.qr(x, full_matrices=(mode == "complete"))
    return decomposition[0], decomposition[1]
@dispatch
def slogdet(x: B.TFNumeric):
    """
    Sign and logarithm of the absolute determinant of a matrix `x`.

    :return:
        A plain tuple `(sign, logdet)` from `tf.linalg.slogdet`.
    """
    result = tf.linalg.slogdet(x)
    return result[0], result[1]
@dispatch
def eigvalsh(x: B.TFNumeric):
    """
    Eigenvalues of a Hermitian (or real symmetric) matrix `x`, computed
    with `tf.linalg.eigvalsh`.
    """
    eigenvalues = tf.linalg.eigvalsh(x)
    return eigenvalues
@dispatch
def reciprocal_no_nan(x: B.TFNumeric):
    """
    Element-wise `1 / x`, with entries where `x == 0` set to 0 instead of
    inf/nan (delegates to `tf.math.reciprocal_no_nan`).
    """
    reciprocal = tf.math.reciprocal_no_nan(x)
    return reciprocal
@dispatch
def complex_conj(x: B.TFNumeric):
    """
    Element-wise complex conjugate of `x` (via `tf.math.conj`).
    """
    conjugate = tf.math.conj(x)
    return conjugate
@dispatch
def logical_xor(x1: B.TFNumeric, x2: B.TFNumeric):
    """
    Element-wise logical exclusive-or of `x1` and `x2`
    (via `tf.math.logical_xor`).
    """
    exclusive_or = tf.math.logical_xor(x1, x2)
    return exclusive_or
@dispatch
def count_nonzero(x: B.TFNumeric, axis=None):
    """
    Number of non-zero entries of `x`, over all entries or along `axis`
    (via `tf.math.count_nonzero`).
    """
    nonzero_count = tf.math.count_nonzero(x, axis=axis)
    return nonzero_count
@dispatch
def dtype_bool(reference: B.TFRandomState):  # type: ignore
    """
    Boolean dtype of the TensorFlow backend (`tf.bool`).

    `reference` only selects this backend; its value is not inspected.
    """
    boolean_dtype = tf.bool
    return boolean_dtype
@dispatch
def bool_like(reference: B.TFNumeric):
    """
    Return the type of the reference if it is of boolean type.
    Otherwise return `bool` dtype of a backend based on the reference.

    :param reference:
        A TensorFlow tensor. This method is dispatched on `B.TFNumeric`.
        The previous `B.NPNumeric` annotation meant TF tensors never reached
        it, and NumPy arrays would have been routed here, where their dtype
        has no `is_bool` attribute.

    :return:
        `reference.dtype` if it is boolean, otherwise `tf.bool`.
    """
    reference_dtype = reference.dtype
    if reference_dtype.is_bool:
        return reference_dtype
    return tf.bool