Coverage for geometric_kernels/utils/utils.py: 64%
163 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1"""
2Convenience utilities.
3"""
5import inspect
6import sys
7from contextlib import contextmanager
8from importlib import resources as impresources
9from itertools import combinations
11import einops
12import lab as B
13import numpy as np
14from beartype.typing import Callable, Generator, List, Set, Tuple
16from geometric_kernels import resources
17from geometric_kernels.lab_extras import (
18 count_nonzero,
19 get_random_state,
20 logical_xor,
21 restore_random_state,
22)
def chain(elements: B.Numeric, repetitions: List[int]) -> B.Numeric:
    """
    Repeats each element in `elements` by a certain number of repetitions as
    specified in `repetitions`. The length of `elements` and `repetitions`
    should match.

    :param elements:
        An [N,]-shaped array of elements to repeat.
    :param repetitions:
        A list specifying the number of times to repeat each of the elements
        in `elements`. The length of `repetitions` should be equal to N.

    :return:
        An [M,]-shaped array, where M = sum(repetitions). If `repetitions` is
        empty, an empty ([0,]-shaped) slice of `elements` is returned.

    EXAMPLE:

    .. code-block:: python

        elements = np.array([1, 2, 3])
        repetitions = [2, 1, 3]
        out = chain(elements, repetitions)
        print(out)  # [1, 1, 2, 3, 3, 3]
    """
    if len(repetitions) == 0:
        # `B.concat` needs at least one array to concatenate; an empty slice
        # keeps the backend and dtype of `elements`.
        return elements[:0]

    values = [
        # Slicing (instead of indexing) keeps a rank-1 [1,]-shaped array,
        # which einops then tiles `repetitions[i]` times along that axis.
        einops.repeat(elements[i : i + 1], "j -> (tile j)", tile=repetitions[i])
        for i in range(len(repetitions))
    ]
    return B.concat(*values, axis=0)
def make_deterministic(f: Callable, key: B.RandomState) -> Callable:
    """
    Returns a deterministic version of a function that uses a random
    number generator.

    Every call of the returned function restores the random state that `key`
    had when `make_deterministic` was called, and passes the restored key to
    `f`. Repeated calls therefore see the same randomness.

    :param f:
        The function to make deterministic. Bound methods and callable objects
        are supported: their already-bound `self` is not counted when locating
        a positional `key` argument.
    :param key:
        The key used to generate the random state.

    :return:
        A function representing the deterministic version of the input function.

    .. note::
        This function assumes that the input function has a 'key' argument
        or keyword-only argument that is used to generate random numbers.
        Otherwise, the function is returned as is.

    EXAMPLE:

    .. code-block:: python

        key = tf.random.Generator.from_seed(1234)
        feature_map = default_feature_map(kernel=base_kernel)
        sample_paths = make_deterministic(sampler(feature_map), key)
        _, ys_train = sample_paths(xs_train, params)
        key, ys_test = sample_paths(xs_test, params)
    """
    # `inspect.signature` (unlike `inspect.getfullargspec`) omits the already
    # bound first parameter of bound methods / callable objects, so the index
    # computed below matches the arguments callers actually pass.
    # `follow_wrapped=False` mirrors `getfullargspec`, which ignores `__wrapped__`.
    parameters = list(inspect.signature(f, follow_wrapped=False).parameters.values())
    positional_names = [
        p.name
        for p in parameters
        if p.kind
        in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    keyword_only_names = [
        p.name for p in parameters if p.kind == inspect.Parameter.KEYWORD_ONLY
    ]

    if "key" in positional_names:
        key_position = positional_names.index("key")
    elif "key" in keyword_only_names:
        key_position = None  # `key` is passed by keyword
    else:
        return f  # already deterministic

    saved_random_state = get_random_state(key)

    def deterministic_f(*args, **kwargs):
        # Start every call from the same saved random state.
        restored_key = restore_random_state(key, saved_random_state)
        if key_position is None:
            kwargs["key"] = restored_key
            return f(*args, **kwargs)
        new_args = args[:key_position] + (restored_key,) + args[key_position:]
        return f(*new_args, **kwargs)

    f_name = getattr(f, "__name__", None) or type(f).__name__
    # `__doc__` always exists, but is `None` for undocumented callables.
    f_doc = getattr(f, "__doc__", None) or ""

    new_docstring = f"""
    This is a deterministic version of the function {f_name}.

    The original docstring follows.

    {f_doc}
    """
    deterministic_f.__doc__ = new_docstring

    return deterministic_f
def ordered_pairwise_differences(X: B.Numeric) -> B.Numeric:
    """
    Compute the ordered pairwise differences between elements of a vector.

    :param X:
        A [..., D]-shaped array, a batch of D-dimensional vectors.

    :return:
        A [..., C]-shaped array, where C = D*(D-1)//2, containing the ordered
        pairwise differences between the elements of X. That is, the array
        containing differences X[...,i] - X[...,j] for all i < j.
    """
    as_columns = B.expand_dims(X, -2)  # [..., 1, D]; entry (0, j) is X[..., j]
    as_rows = B.expand_dims(X, -1)  # [..., D, 1]; entry (i, 0) is X[..., i]
    # all_diffs[..., i, j] = X[..., j] - X[..., i]. Strictly below the
    # diagonal (i > j) this is X[..., k] - X[..., l] with k = j < l = i.
    all_diffs = as_columns - as_rows  # [..., D, D]
    # offset=-1 skips the (all-zero) diagonal.
    return B.tril_to_vec(all_diffs, offset=-1)
def fixed_length_partitions(  # noqa: C901
    n: int, L: int
) -> Generator[List[int], None, None]:
    """
    A generator for integer partitions of n into exactly L positive parts, in
    colex order.

    Each partition is yielded as a fresh list of L positive integers in
    non-increasing order; e.g. n=6, L=3 yields [4, 1, 1], [3, 2, 1],
    [2, 2, 2]. The case n = L = 0 yields the single empty partition [].
    Nothing is yielded when no such partition exists (including L < 0).

    :param n:
        The number to partition.
    :param L:
        Number of parts in each partition.

    Developed by D. Eppstein in 2005, taken from
    https://www.ics.uci.edu/~eppstein/PADS/IntegerPartitions.py
    The algorithm follows Knuth v4 fasc3 p38 in rough outline;
    Knuth credits it to Hindenburg, 1779.
    """

    # guard against special cases
    if L < 0:
        # Without this guard, a negative L builds the one-element list
        # [n - L + 1], yields it, and then fails on `partition[1]`.
        return
    if L == 0:
        if n == 0:
            yield []
        return
    if L == 1:
        if n > 0:
            yield [n]
        return
    if n < L:
        return

    # Start from the lexicographically largest partition [n-L+1, 1, ..., 1].
    partition = [n - L + 1] + (L - 1) * [1]
    while True:
        yield partition.copy()
        # Fast path: move one unit from part 0 to part 1 while that keeps
        # the parts non-increasing.
        if partition[0] - 1 > partition[1]:
            partition[0] -= 1
            partition[1] += 1
            continue
        # Find the first j >= 2 whose part is smaller than partition[0] - 1.
        # `s` accumulates sum(partition[:j]) minus the unit that part j gains.
        j = 2
        s = partition[0] + partition[1] - 1
        while j < L and partition[j] >= partition[0] - 1:
            s += partition[j]
            j += 1
        if j >= L:
            return
        # Increase part j by one, set parts 1..j-1 equal to it, and give
        # whatever remains of the total to part 0.
        partition[j] = x = partition[j] + 1
        j -= 1
        while j > 0:
            partition[j] = x
            s -= x
            j -= 1
        partition[0] = s
def partition_dominance_cone(partition: Tuple[int, ...]) -> Set[Tuple[int, ...]]:
    """
    Calculates partitions dominated by a given one and having the same number
    of parts (including the zero parts of the original).

    The cone is built as a closure. Starting from `partition`, one box is
    repeatedly moved from a row `i` to a later row `j > i`, but only when
    the move is allowed by the conditions below. This continues until no
    new tuples appear.

    :param partition:
        A partition. Presumably a non-increasing tuple of non-negative
        integers; this is not validated here.

    :return:
        A set of partitions, `partition` itself included.
    """
    cone = {partition}
    frontier = [partition]
    while frontier:
        discovered: Set[Tuple[int, ...]] = set()
        for current in frontier:
            size = len(current)
            for i in range(size - 1):
                # Only the last row of a run of equal rows may lose a box.
                if current[i] <= current[i + 1]:
                    continue
                for j in range(i + 1, size):
                    # Row j must be the first of its run of equal rows, and
                    # row i must exceed row j by at least 2.
                    if current[i] <= current[j] + 1 or (
                        current[j] >= current[j - 1]
                    ):
                        continue
                    moved = list(current)
                    moved[i] -= 1
                    moved[j] += 1
                    candidate = tuple(moved)
                    if candidate not in cone:
                        discovered.add(candidate)
        cone |= discovered
        frontier = list(discovered)
    return cone
def partition_dominance_or_subpartition_cone(
    partition: Tuple[int, ...],
) -> Set[Tuple[int, ...]]:
    """
    Calculates subpartitions and partitions dominated by a given one and having
    the same number of parts (including zero parts of the original).

    The cone is built as a closure over two kinds of steps, repeated until no
    new tuples appear:

    - remove one box from a row `i` (subpartition step);
    - move one box from a row `i` to a later row `j > i` (dominance step).

    NOTE(review): rows are only considered for `i < len(partition) - 1`, so a
    box is never removed from the last row. For example, `(1, 1)` yields only
    `{(1, 1)}`. Confirm with callers that this is intended.

    :param partition:
        A partition. Presumably a non-increasing tuple of non-negative
        integers; this is not validated here.

    :return:
        A set of partitions, `partition` itself included.
    """
    cone = {partition}
    frontier = [partition]
    while frontier:
        discovered: Set[Tuple[int, ...]] = set()
        for current in frontier:
            size = len(current)
            for i in range(size - 1):
                # Only the last row of a run of equal rows may lose a box.
                if current[i] <= current[i + 1]:
                    continue

                # Subpartition step: drop one box from row i.
                shrunk = list(current)
                shrunk[i] -= 1
                shrunk_key = tuple(shrunk)
                if shrunk_key not in cone:
                    discovered.add(shrunk_key)

                # Dominance step: move one box from row i to row j.
                for j in range(i + 1, size):
                    if current[i] <= current[j] + 1 or (
                        current[j] >= current[j - 1]
                    ):
                        continue
                    moved = list(current)
                    moved[i] -= 1
                    moved[j] += 1
                    candidate = tuple(moved)
                    if candidate not in cone:
                        discovered.add(candidate)
        cone |= discovered
        frontier = list(discovered)
    return cone
@contextmanager
def get_resource_file_path(filename: str):
    """
    Context manager yielding a filesystem path to a file bundled in the
    `geometric_kernels.resources` package.

    It wraps `impresources` behind a single interface for both Python 3.8 and
    Python >= 3.9.

    :param filename:
        The name of the file.
    """
    if sys.version_info < (3, 9):
        # Python 3.8 only provides the `path` helper.
        resource_context = impresources.path(resources, filename)
    else:
        resource_context = impresources.as_file(
            impresources.files(resources) / filename
        )
    with resource_context as path:
        yield path
def hamming_distance(x1: B.Bool, x2: B.Bool):
    """
    Hamming distance between two batches of boolean vectors.

    :param x1:
        Array of any backend, of shape [N, D].
    :param x2:
        Array of any backend, of shape [M, D].

    :return:
        An array of shape [N, M] whose entry n, m contains the Hamming distance
        between x1[n, :] and x2[m, :]. It is of the same backend as x1 and x2.
    """
    # Insert singleton axes so the XOR broadcasts to [N, M, D].
    left = x1[:, None, :]
    right = x2[None, :, :]
    mismatches = logical_xor(left, right)  # True where the bits differ
    return count_nonzero(mismatches, axis=-1)
def log_binomial(n: B.Int, k: B.Int) -> B.Float:
    """
    Compute the logarithm of the binomial coefficient.

    `n` and `k` may be scalars or arrays. For arrays, the validity check is
    applied to every element.

    :param n:
        The number of elements in the set.
    :param k:
        The number of elements to choose.

    :return:
        The logarithm of the binomial coefficient binom(n, k).

    :raises ValueError:
        If `0 <= k <= n` does not hold (for any element, in the array case).
    """
    # A chained comparison `0 <= k <= n` means `(0 <= k) and (k <= n)`. The
    # `and` calls `bool()` on an array, which raises for arrays with more
    # than one element, so combine the elementwise conditions with `&`.
    if not B.all((0 <= k) & (k <= n)):
        raise ValueError("Incorrect parameters of the binomial coefficient.")

    return B.loggamma(n + 1) - B.loggamma(k + 1) - B.loggamma(n - k + 1)
def binary_vectors_and_subsets(d: int):
    r"""
    Generates all possible binary vectors of size d and all possible subsets of
    the set $\{0, .., d-1\}$ as a byproduct.

    Subsets are listed by increasing size, and in `itertools.combinations`
    order within each size. Row `r` of the returned matrix is the indicator
    vector of the `r`-th subset.

    :param d:
        The dimension of binary vectors and the size of the set to take subsets of.

    :return:
        A tuple (x, combs), where x is a matrix of size (2**d, d) whose rows
        are all possible binary vectors of size d, and combs is a list of all
        possible subsets of the set $\{0, .., d-1\}$, each subset being
        represented by a list of integers itself.
    """
    x = np.zeros((2**d, d), dtype=bool)
    combs = [
        list(subset)
        for subset_size in range(d + 1)
        for subset in combinations(range(d), subset_size)
    ]
    for row, members in enumerate(combs):
        x[row, members] = 1
    return x, combs
362def _check_field_in_params(params, field):
363 """
364 Raise an error if `params` does not contain a `field`.
365 """
366 if field not in params:
367 raise ValueError(f"`params` must contain `{field}`.")
def _check_1_vector(x, desc):
    """
    Raise a `ValueError` unless `x` has shape exactly `(1,)`.

    :param x:
        The array to check.
    :param desc:
        Name of the checked argument, used in the error message.
    """
    shape = B.shape(x)
    if shape == (1,):
        return
    raise ValueError(
        f"`{desc}` must have shape `[1,]`, but has shape {shape}."
    )
def _check_rank_1_array(x, desc):
    """
    Raise a `ValueError` unless `x` has exactly one dimension.

    :param x:
        The array to check.
    :param desc:
        Name of the checked argument, used in the error message.
    """
    if B.rank(x) == 1:
        return
    raise ValueError(
        f"`{desc}` must have 1 dimension (`ndim` == 1), but has shape {B.shape(x)}."
    )
def _check_matrix(x, desc):
    """
    Raise a `ValueError` unless `x` has exactly two dimensions.

    :param x:
        The array to check.
    :param desc:
        Name of the checked argument, used in the error message.
    """
    if B.rank(x) == 2:
        return
    raise ValueError(f"`{desc}` must be a matrix, but has shape {B.shape(x)}.")