Coverage for geometric_kernels / utils / utils.py: 64%
163 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-15 13:41 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-15 13:41 +0000
1"""
2Convenience utilities.
3"""
5import inspect
6import sys
7from contextlib import contextmanager
8from importlib import resources as impresources
9from itertools import combinations
11import einops
12import lab as B
13import numpy as np
14from beartype.typing import Callable, Generator, List, Set, Tuple
16from geometric_kernels import resources
17from geometric_kernels.lab_extras import (
18 count_nonzero,
19 get_random_state,
20 restore_random_state,
21)
def chain(elements: B.Numeric, repetitions: List[int]) -> B.Numeric:
    """
    Repeat every entry of `elements` as many times as the matching entry of
    `repetitions` says, and concatenate the pieces in order. The lengths of
    `elements` and `repetitions` should be equal.

    :param elements:
        An [N,]-shaped array of elements to repeat.
    :param repetitions:
        A list giving, for each element of `elements`, the number of times
        it is repeated. Its length should be N.

    :return:
        An [M,]-shaped array, where M is the sum of `repetitions`.

    EXAMPLE:

    .. code-block:: python

        elements = np.array([1, 2, 3])
        repetitions = [2, 1, 3]
        out = chain(elements, repetitions)
        print(out)  # [1, 1, 2, 3, 3, 3]
    """
    pieces = []
    for idx, count in enumerate(repetitions):
        # Slicing (not indexing) keeps a length-1, rank-1 array, which einops
        # treats as the "j" axis and tiles `count` times.
        pieces.append(
            einops.repeat(elements[idx : idx + 1], "j -> (tile j)", tile=count)
        )
    return B.concat(*pieces, axis=0)
def make_deterministic(f: Callable, key: B.RandomState) -> Callable:
    """
    Returns a deterministic version of a function that uses a random
    number generator.

    :param f:
        The function to make deterministic.
    :param key:
        The key used to generate the random state.

    :return:
        A function representing the deterministic version of the input function.

    :raises TypeError:
        If the signature of `f` cannot be inspected.

    .. note::
        This function assumes that the input function has a 'key' argument
        or keyword-only argument that is used to generate random numbers.
        Otherwise, the function is returned as is.

    .. note::
        When 'key' is a positional argument of `f`, callers of the returned
        function leave it out: it is inserted into the positional arguments
        at the position 'key' occupies in the signature of `f`. The implicit
        first argument of bound methods, classes and callable objects
        (`self`/`cls`) is not counted when determining that position.

    EXAMPLE:

    .. code-block:: python

        key = tf.random.Generator.from_seed(1234)
        feature_map = default_feature_map(kernel=base_kernel)
        sample_paths = make_deterministic(sampler(feature_map), key)
        _, ys_train = sample_paths(xs_train, params)
        key, ys_test = sample_paths(xs_test, params)
    """
    # `inspect.getfullargspec` keeps the already-bound `self`/`cls` of bound
    # methods and callable objects, which would shift the computed position
    # of "key" by one relative to what callers pass. `inspect.signature`
    # drops it; `follow_wrapped=False` keeps getfullargspec's behaviour of not
    # unwrapping decorated callables.
    try:
        signature = inspect.signature(f, follow_wrapped=False)
    except (TypeError, ValueError) as err:
        # Same exception type getfullargspec raised for such callables.
        raise TypeError("unsupported callable") from err

    parameters = signature.parameters.values()
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    positional_names = [p.name for p in parameters if p.kind in positional_kinds]
    keyword_only_names = [
        p.name for p in parameters if p.kind is inspect.Parameter.KEYWORD_ONLY
    ]

    if "key" in positional_names:
        key_position = positional_names.index("key")
    elif "key" in keyword_only_names:
        key_position = None  # "key" will be passed by keyword
    else:
        return f  # already deterministic

    # The state of `key` is captured once, when the wrapper is created.
    saved_random_state = get_random_state(key)

    def deterministic_f(*args, **kwargs):
        # Every call gets a key restored from the same saved state.
        restored_key = restore_random_state(key, saved_random_state)
        if key_position is None:
            kwargs["key"] = restored_key
        else:
            args = args[:key_position] + (restored_key,) + args[key_position:]
        return f(*args, **kwargs)

    if hasattr(f, "__name__"):
        f_name = f.__name__
    elif hasattr(f, "__class__") and hasattr(f.__class__, "__name__"):
        f_name = f.__class__.__name__
    else:
        f_name = "<anonymous>"

    # `__doc__` exists on practically every object but is None when there is
    # no docstring; use "" then instead of embedding the text "None".
    f_doc = getattr(f, "__doc__", None) or ""

    new_docstring = f"""
    This is a deterministic version of the function {f_name}.

    The original docstring follows.

    {f_doc}
    """
    deterministic_f.__doc__ = new_docstring

    return deterministic_f
def ordered_pairwise_differences(X: B.Numeric) -> B.Numeric:
    """
    Compute the ordered pairwise differences between elements of a vector.

    :param X:
        A [..., D]-shaped array, a batch of D-dimensional vectors.

    :return:
        A [..., C]-shaped array, where C = D*(D-1)//2, containing the ordered
        pairwise differences between the elements of X. That is, the array
        containing differences X[...,i] - X[...,j] for all i < j.
    """
    as_row = B.expand_dims(X, -2)  # [..., 1, D]: entry (a, b) is X[..., b]
    as_col = B.expand_dims(X, -1)  # [..., D, 1]: entry (a, b) is X[..., a]
    # After broadcasting to [..., D, D], entry (a, b) equals X[..., b] - X[..., a].
    # Its strict lower triangle (a > b) therefore holds X[..., i] - X[..., j]
    # with i < j; offset=-1 leaves out the all-zero diagonal.
    return B.tril_to_vec(as_row - as_col, offset=-1)
def fixed_length_partitions(  # noqa: C901
    n: int, L: int
) -> Generator[List[int], None, None]:
    """
    A generator for integer partitions of n into L parts, in colex order.

    Each partition is yielded as a fresh list of L positive integers, in
    non-increasing order, summing to n. Nothing is yielded when no such
    partition exists (e.g. when n < L, or when L is negative). For
    n == L == 0 the single empty partition [] is yielded.

    :param n:
        The number to partition.
    :param L:
        Size of partitions.

    Developed by D. Eppstein in 2005, taken from
    https://www.ics.uci.edu/~eppstein/PADS/IntegerPartitions.py
    The algorithm follows Knuth v4 fasc3 p38 in rough outline;
    Knuth credits it to Hindenburg, 1779.
    """

    # guard against special cases
    if L < 0:
        # No partition has a negative number of parts. Without this guard the
        # general case below would build a 1-element list, yield it, and then
        # fail on partition[1].
        return
    if L == 0:
        if n == 0:
            yield []
        return
    if L == 1:
        if n > 0:
            yield [n]
        return
    if n < L:
        return

    # Colex-first partition: one large part followed by L - 1 ones.
    partition = [n - L + 1] + (L - 1) * [1]
    while True:
        # Yield a copy so callers may keep/mutate it while we keep iterating.
        yield partition.copy()
        if partition[0] - 1 > partition[1]:
            # Cheap step: move one unit from the first part to the second.
            partition[0] -= 1
            partition[1] += 1
            continue
        # Find the first part j >= 2 that is smaller than partition[0] - 1,
        # accumulating in s the total mass of the parts before it.
        j = 2
        s = partition[0] + partition[1] - 1
        while j < L and partition[j] >= partition[0] - 1:
            s += partition[j]
            j += 1
        if j >= L:
            return  # no such part: this was the last partition
        # Increase part j by one and reset all earlier parts (except the
        # first) to that same value; the first part absorbs the remainder s.
        partition[j] = x = partition[j] + 1
        j -= 1
        while j > 0:
            partition[j] = x
            s -= x
            j -= 1
        partition[0] = s
def partition_dominance_cone(partition: Tuple[int, ...]) -> Set[Tuple[int, ...]]:
    """
    Calculates partitions dominated by a given one and having the same number
    of parts (including the zero parts of the original).

    The cone is explored breadth-first. Starting from `partition`, each step
    moves one unit from a part i to a later part j whenever the result is
    still non-increasing. This is repeated until no new tuples appear.

    :param partition:
        A partition.

    :return:
        A set of partitions (including `partition` itself).
    """
    cone = {partition}
    frontier = {partition}
    while frontier:
        discovered: Set[Tuple[int, ...]] = set()
        for current in frontier:
            size = len(current)
            for i in range(size - 1):
                if current[i] <= current[i + 1]:
                    continue  # part i cannot shrink without breaking the order
                for j in range(i + 1, size):
                    # Part i must stay >= part j after the move, and part j
                    # must stay <= part j - 1 after it grows.
                    if current[i] > current[j] + 1 and current[j] < current[j - 1]:
                        moved = list(current)
                        moved[i] -= 1
                        moved[j] += 1
                        candidate = tuple(moved)
                        if candidate not in cone:
                            discovered.add(candidate)
        cone |= discovered
        frontier = discovered
    return cone
def partition_dominance_or_subpartition_cone(
    partition: Tuple[int, ...],
) -> Set[Tuple[int, ...]]:
    """
    Calculates subpartitions and partitions dominated by a given one and having
    the same number of parts (including zero parts of the original).

    The cone is explored breadth-first. Starting from `partition`, each step
    either removes one unit from a part i, or moves one unit from part i to a
    later part j, provided the result is still non-increasing. This is
    repeated until no new tuples appear.

    :param partition:
        A partition.

    :return:
        A set of partitions (including `partition` itself).
    """
    cone = {partition}
    frontier = {partition}
    while frontier:
        discovered: Set[Tuple[int, ...]] = set()
        for current in frontier:
            size = len(current)
            # NOTE(review): i stops at size - 2, so the last part is never
            # decremented on its own; e.g. (1, 1) yields only {(1, 1)}.
            # Callers presumably pad partitions with a trailing zero — confirm.
            for i in range(size - 1):
                if current[i] <= current[i + 1]:
                    continue  # part i cannot shrink without breaking the order
                # Subpartition step: drop one unit from part i.
                lowered = list(current)
                lowered[i] -= 1
                shrunk = tuple(lowered)
                if shrunk not in cone:
                    discovered.add(shrunk)
                # Dominance step: move one unit from part i to a later part j.
                for j in range(i + 1, size):
                    if current[i] > current[j] + 1 and current[j] < current[j - 1]:
                        moved = list(current)
                        moved[i] -= 1
                        moved[j] += 1
                        candidate = tuple(moved)
                        if candidate not in cone:
                            discovered.add(candidate)
        cone |= discovered
        frontier = discovered
    return cone
@contextmanager
def get_resource_file_path(filename: str):
    """
    Context manager yielding a filesystem path to the file `filename` inside
    the `geometric_kernels.resources` package. It works the same way on
    Python 3.8 and on Python >= 3.9, even though `importlib.resources`
    differs between them.

    :param filename:
        The name of the file.
    """
    if sys.version_info < (3, 9):
        # `importlib.resources.files`/`as_file` only exist from Python 3.9 on.
        path_context = impresources.path(resources, filename)
    else:
        path_context = impresources.as_file(impresources.files(resources) / filename)
    with path_context as path:
        yield path
def hamming_distance(x1: B.Int, x2: B.Int):
    """
    Pairwise Hamming distance between two batches of boolean vectors.

    :param x1:
        Array of any backend, of shape [N, D].
    :param x2:
        Array of any backend, of shape [M, D].

    :return:
        An [N, M]-shaped array, of the same backend as x1 and x2, whose entry
        (n, m) is the number of coordinates where x1[n, :] and x2[m, :] differ.
    """
    # Broadcast to [N, M, D]; True wherever x1[n, d] != x2[m, d].
    differs = x1[:, None, :] != x2[None, :, :]
    return count_nonzero(differs, axis=-1)
def log_binomial(n: B.Int, k: B.Int) -> B.Float:
    """
    Compute the logarithm of the binomial coefficient.

    :param n:
        The number of elements in the set.
    :param k:
        The number of elements to choose.

    :return:
        The logarithm of the binomial coefficient binom(n, k).

    :raises ValueError:
        If 0 <= k <= n does not hold (for array inputs: for any entry).
    """
    # Two separate reductions: a chained `0 <= k <= n` evaluates
    # `bool(0 <= k)`, which is ambiguous (and raises) for multi-element arrays.
    if not (B.all(0 <= k) and B.all(k <= n)):
        raise ValueError("Incorrect parameters of the binomial coefficient.")

    return B.loggamma(n + 1) - B.loggamma(k + 1) - B.loggamma(n - k + 1)
def binary_vectors_and_subsets(d: int):
    r"""
    Enumerate every binary vector of length d, and, along the way, every
    subset of the set $\{0, .., d-1\}$.

    Subsets are listed by increasing size, each size in
    `itertools.combinations` order. Row i of the matrix is the indicator
    vector of the i-th subset.

    :param d:
        The dimension of binary vectors and the size of the set to take subsets of.

    :return:
        A tuple (x, combs). Here x is a boolean matrix of size (2**d, d)
        whose rows are all possible binary vectors of size d. combs is a
        list of all possible subsets of the set $\{0, .., d-1\}$, each given
        as a list of integers.
    """
    vectors = np.zeros((2**d, d), dtype=bool)
    subsets = [
        list(subset)
        for size in range(d + 1)
        for subset in combinations(range(d), size)
    ]
    for row, subset in enumerate(subsets):
        vectors[row, subset] = True
    return vectors, subsets
def _check_field_in_params(params, field):
    """
    Raise a ValueError unless `field` is a key (member) of `params`.
    """
    if field in params:
        return
    raise ValueError(f"`params` must contain `{field}`.")
def _check_1_vector(x, desc):
    """
    Raise a ValueError unless `x` has shape exactly [1,]; `desc` names `x`
    in the error message.
    """
    shape = B.shape(x)
    if shape == (1,):
        return
    raise ValueError(f"`{desc}` must have shape `[1,]`, but has shape {shape}.")
def _check_rank_1_array(x, desc):
    """
    Raise a ValueError unless `x` has exactly one dimension; `desc` names `x`
    in the error message.
    """
    if B.rank(x) == 1:
        return
    raise ValueError(
        f"`{desc}` must have 1 dimension (`ndim` == 1), but has shape {B.shape(x)}."
    )
def _check_matrix(x, desc):
    """
    Raise a ValueError unless `x` has exactly two dimensions; `desc` names `x`
    in the error message.
    """
    if B.rank(x) == 2:
        return
    raise ValueError(f"`{desc}` must be a matrix, but has shape {B.shape(x)}.")