Coverage for geometric_kernels/spaces/product.py: 94%
159 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1"""
This module provides the :class:`ProductDiscreteSpectrumSpace` space and the
respective :class:`~.eigenfunctions.Eigenfunctions` subclass
:class:`ProductEigenfunctions`.

See :doc:`this page </theory/product_spaces>` for a brief account of the
theory behind product spaces and the :doc:`Torus.ipynb </examples/Torus>`
notebook for a tutorial on how to use them.
9"""
11import itertools
12import math
14import lab as B
15import numpy as np
16from beartype.typing import List, Optional, Tuple
18from geometric_kernels.lab_extras import from_numpy, int_like, take_along_axis
19from geometric_kernels.spaces.base import DiscreteSpectrumSpace
20from geometric_kernels.spaces.eigenfunctions import Eigenfunctions
21from geometric_kernels.utils.product import make_product, project_product
22from geometric_kernels.utils.utils import chain
def _find_lowest_sum_combinations(arr: B.Numeric, k: int) -> B.Numeric:
    """
    For an [N, D] array `arr`, with each row `arr[i, :]` assumed sorted, find
    the `k` smallest sums of one element per row, i.e. sums of the form
    `sum_i arr[i, j_i]`, and return the array of indices `j_i` of the summands.

    The search starts from the all-zeros index tuple and keeps a frontier
    `curr_idx` of candidate index tuples. On each iteration, every candidate
    tied for the lowest sum is accepted, and each accepted tuple is replaced
    in the frontier by its successors (one coordinate incremented by one,
    clamped to `D - 1`). Tuples that were already accepted are filtered out.

    :param arr:
        An [N, D]-shaped array.
    :param k:
        Number of smallest sums to find. The caller is expected to make sure
        that `k` does not exceed the number of distinct index combinations.

    :return:
        A [k, N]-shaped array of the same backend as `arr`.
    """
    N, D = arr.shape

    index_array = B.stack(*[B.zeros(int_like(arr), N) for _ in range(k)], axis=0)
    # prebuild array for indexing the first index of the eigenvalue array
    first_index = B.linspace(0, N - 1, N).astype(int)
    i = 0

    # first eigenvalue is the sum of the first eigenvalues of the individual spaces
    curr_idx = B.zeros(int, N)[None, :]
    while i < k:
        # compute eigenvalues of the proposals
        sum_values = arr[first_index, curr_idx].sum(axis=1)

        # Compute tied smallest new eigenvalues
        lowest_sum_index = int(sum_values.argmin())
        tied_sums = sum_values == sum_values[lowest_sum_index]
        tied_sums_indexes = B.linspace(0, len(tied_sums) - 1, len(tied_sums)).astype(
            int
        )[tied_sums]

        # Add new eigenvalues to indexing array
        for index in tied_sums_indexes:
            index_array[i, :] = curr_idx[index]
            i += 1
            if i >= k:
                break

        if i >= k:
            # Everything requested has been stored; building the next
            # frontier would be discarded work.
            break

        # keep unaccepted ones around
        old_indices = curr_idx[~tied_sums]
        # mutate just accepted ones by adding one to each eigenindex
        new_indices = curr_idx[tied_sums][..., None, :] + B.eye(int, curr_idx.shape[-1])
        new_indices = new_indices.reshape((-1, new_indices.shape[-1]))
        # incrementing past the last column is clamped back onto it
        new_indices = B.minimum(new_indices, D - 1)
        curr_idx = B.concat(old_indices, new_indices, axis=0)
        # drop duplicate candidates (this also sorts rows lexicographically)
        curr_idx = np.unique(
            B.to_numpy(curr_idx.reshape((-1, curr_idx.shape[-1]))),
            axis=0,
        )
        # Filter out already searched combinations.
        # See https://stackoverflow.com/a/40056251.
        # Each index tuple is encoded as a single flat integer so that set
        # membership can be tested on 1-D arrays.
        dims = np.maximum(curr_idx.max(0), index_array.max(0)) + 1
        curr_idx = curr_idx[
            ~np.isin(
                np.ravel_multi_index(curr_idx.T, dims),
                np.ravel_multi_index(index_array.T, dims),
            )
        ]

    return index_array
90def _num_per_level_to_mapping(num_per_level: List[int]) -> List[List[int]]:
91 """
92 Given a list of numbers of eigenfunctions per level, return a list `mapping`
93 such that `mapping[i][j]` is the index of the `j`-th eigenfunction of the
94 `i`-th level.
96 :param num_per_level:
97 A `num_eigenfunctions_per_level` list of some space.
98 """
99 mapping = []
100 i = 0
101 for num in num_per_level:
102 mapping.append([i + j for j in range(num)])
103 i += num
104 return mapping
def _eigenlevelindices_to_eigenfunctionindices(
    eigenlevelindices: B.Numeric, nums_per_level: List[List[int]]
) -> B.Numeric:
    """
    Expand a table of per-level factor level indices into a table of
    per-eigenfunction factor eigenfunction indices.

    For every level of the product space, the eigenfunctions of that level
    are all the combinations (Cartesian product) of the eigenfunctions of the
    corresponding levels of the factor spaces.

    :param eigenlevelindices:
        A [L, S]-shaped array, where `S` is the number of factor spaces and `L`
        is the number of levels, such that `eigenlevelindices[i, :]` are the
        indices of the levels of the factor spaces that correspond to the
        `i`'th level of the product space.
    :param nums_per_level:
        Each `nums_per_level[j]` represents the `num_eigenfunctions_per_level`
        list of the `j`-th factor space.

    :return:
        A [J, S]-shaped array, where `J` is the total number of eigenfunctions.
    """
    num_levels, num_factors = eigenlevelindices.shape[0], eigenlevelindices.shape[1]

    # For each factor: level index -> list of that level's eigenfunction indices.
    per_factor_mappings: List[List[List[int]]] = [
        _num_per_level_to_mapping(npl) for npl in nums_per_level
    ]

    product_indices: List[Tuple[int, ...]] = []
    for level in range(num_levels):
        factor_choices: List[List[int]] = [
            per_factor_mappings[s][eigenlevelindices[level, s]]
            for s in range(num_factors)
        ]
        # All combinations of factor eigenfunctions belonging to this level.
        product_indices.extend(itertools.product(*factor_choices))

    return from_numpy(eigenlevelindices, np.array(product_indices))
class ProductEigenfunctions(Eigenfunctions):
    r"""
    Eigenfunctions of the Laplacian on the product space are outer products
    of the eigenfunctions on factors.

    Levels correspond to tuples of levels of the factors.

    :param element_shapes:
        Shapes of the elements in each factor. Can be obtained as properties
        `space.element_shape` of any given factor `space`.
    :param element_dtypes:
        Abstract lab data types of the elements in each factor. Can be obtained
        as properties `space.element_dtype` of any given factor `space`.
    :param eigenindicies:
        A [L, S]-shaped array, where `S` is the number of factor spaces and `L`
        is the number of levels, such that `eigenindicies[i, :]` are the indices
        of the levels of the factor spaces that correspond to the `i`'th level
        of the product space.

        This parameter implicitly determines the number of levels.
    :param ``*eigenfunctions``:
        The :class:`~.Eigenfunctions` subclasses corresponding to each of
        the factor spaces.
    :param dimension_indices:
        Determines how a vector `x` representing a product space element is to
        be mapped into the arrays `xi` that represent elements of the factor
        spaces. `xi` are assumed to be equal to `x[dimension_indices[i]]`,
        possibly up to a reshape. Such a reshape might be necessary to
        accommodate the spaces whose elements are matrices rather than vectors,
        as determined by `element_shapes`. The transformation of `x` into the
        list of `xi`\ s is performed by :func:`~.project_product`.

        If None, assumes each input is laid out flattened and concatenated,
        in the same order as the factor spaces. In this case, the inverse to
        :func:`~.project_product` is :func:`~.make_product`.

        Defaults to None.

    :raises ValueError:
        If the number of columns of `eigenindicies` differs from the number
        of `eigenfunctions` passed.
    """

    def __init__(
        self,
        element_shapes: List[List[int]],
        element_dtypes: List[B.DType],
        eigenindicies: B.Numeric,
        *eigenfunctions: Eigenfunctions,
        dimension_indices: Optional[B.Numeric] = None,
    ):
        # Validate up front: the index bookkeeping below indexes the factors
        # by the columns of `eigenindicies` and would otherwise fail with a
        # much less informative error on a mismatch.
        if eigenindicies.shape[-1] != len(eigenfunctions):
            raise ValueError(
                "Expected to have S `eigenfunctions` and `eigenindicies` of shape [L, S], "
                "where S is the number of spaces and L is the number of levels, "
                f"but got S1={len(eigenfunctions)} eigenfunctions and "
                f"the shape of `eigenindicies` is {eigenindicies.shape}, which is incompatible."
            )

        self._num_levels = eigenindicies.shape[0]
        self.element_shapes = element_shapes
        self.element_dtypes = element_dtypes

        if dimension_indices is None:
            # Default layout: factors are flattened and concatenated in order,
            # so factor `i` occupies a contiguous range of coordinates.
            dimension_indices = []
            offset = 0
            for element_shape in self.element_shapes:
                dim = math.prod(element_shape)
                dimension_indices.append([*range(offset, offset + dim)])
                offset += dim
        # Stored in both cases, so that a user-supplied layout is honored.
        self.dimension_indices = dimension_indices

        self.eigenindicies = eigenindicies
        self.eigenfunctions = eigenfunctions

        self.nums_per_level: List[List[int]] = [
            eigenfunctions_factor.num_eigenfunctions_per_level
            for eigenfunctions_factor in self.eigenfunctions
        ]  # [S, LFactor] where `LFactor` is eigenfunctions[0].num_levels

        # [J, S] table: product eigenfunction index -> factor eigenfunction indices.
        self._eigenfunctionindices = _eigenlevelindices_to_eigenfunctionindices(
            self.eigenindicies, self.nums_per_level
        )

    def __call__(self, X: B.Numeric, **kwargs) -> B.Numeric:
        """
        Evaluate the individual eigenfunctions at a batch of input locations.

        .. warning::
            Will `raise NotImplementedError` if some of the factors do not
            implement their `__call__`.

        :param X:
            Points to evaluate the eigenfunctions at, an array of shape [N, D],
            where `N` is the number of points and `D` is the dimension of the
            vectors that represent points in the product space.

            Each point `x` in the product space is a `D`-dimensional vector such
            that `x[dimension_indices[i]]` is the vector that represents the
            flattened element of the `i`-th factor space.

            .. note::
                If the instance was created with `dimension_indices=None`, then
                you can use the :func:`~.make_product` function to convert a
                list of (batches of) points in factor spaces into a (batch of)
                points in the product space.

        :param ``**kwargs``:
            Any additional parameters.

        :return:
            An [N, J]-shaped array, where `J` is the number of eigenfunctions.
        """
        Xs = project_product(
            X, self.dimension_indices, self.element_shapes, self.element_dtypes
        )  # List of S arrays, each of shape [N, *element_shape_s]

        factor_eigenfunction_values = [
            eigenfunction(X_factor, **kwargs)  # [N, Js], Js different for each factor
            for eigenfunction, X_factor in zip(self.eigenfunctions, Xs)
        ]  # List of length S, contains arrays of different shapes (not stackable)

        num_factors = len(self.eigenfunctions)

        eigenfunction_values = []
        for j in range(self.num_eigenfunctions):
            # The j-th product eigenfunction is the pointwise product of one
            # eigenfunction per factor, selected by row j of the index table.
            eigenfunction_values.append(
                B.prod(
                    B.stack(
                        *(
                            factor_eigenfunction_values[s][
                                :, self._eigenfunctionindices[j, s]
                            ]  # [N,]
                            for s in range(num_factors)
                        ),
                        axis=-1,
                    ),  # [N, S]
                    axis=1,
                )  # [N,]
            )
        eigenfunction_values = B.stack(*eigenfunction_values, axis=-1)  # [N, J]

        return eigenfunction_values

    @property
    def num_eigenfunctions(self) -> int:
        """
        :return:
            The total number `J` of eigenfunctions, i.e. the number of rows of
            the precomputed eigenfunction index table.
        """
        return self._eigenfunctionindices.shape[0]

    @property
    def num_levels(self) -> int:
        """
        :return:
            The number `L` of levels, i.e. the number of rows of `eigenindicies`.
        """
        return self._num_levels

    def phi_product(
        self, X: B.Numeric, X2: Optional[B.Numeric] = None, **kwargs
    ) -> B.Numeric:
        """
        Compute, for each level of the product space, the product over the
        factors of the factors' `phi_product` values at the corresponding
        factor levels.

        :param X:
            An [N, D]-shaped array of points in the product space.
        :param X2:
            An [N2, D]-shaped array of points in the product space.
            If None, `X` is used in its place.
        :param ``**kwargs``:
            Any additional parameters, passed on to the factors.

        :return:
            An [N, N2, L]-shaped array, where `L` is the number of levels.
        """
        if X2 is None:
            X2 = X
        Xs = project_product(
            X, self.dimension_indices, self.element_shapes, self.element_dtypes
        )  # List of S arrays, each of shape [N, *element_shape_s]
        Xs2 = project_product(
            X2, self.dimension_indices, self.element_shapes, self.element_dtypes
        )  # List of S arrays, each of shape [N2, *element_shape_s]

        factor_phi_products = [
            take_along_axis(
                eigenfunction.phi_product(X_s, X2_s, **kwargs),  # [N, N2, L_s]
                from_numpy(X_s, self.eigenindicies[None, None, :, s]),
                -1,
            )  # [N, N2, L]
            for s, (eigenfunction, X_s, X2_s) in enumerate(
                zip(self.eigenfunctions, Xs, Xs2)
            )
        ]
        # Factors may produce different dtypes; cast to a common one so that
        # the results can be stacked.
        common_dtype = B.promote_dtypes(*[B.dtype(x) for x in factor_phi_products])

        phis = B.stack(
            *[B.cast(common_dtype, x) for x in factor_phi_products], axis=-1
        )  # [N, N2, L, S]

        prod_phis = B.prod(phis, axis=-1)  # [N, N2, L, S] -> [N, N2, L]

        return prod_phis

    def phi_product_diag(self, X: B.Numeric, **kwargs):
        """
        Compute, for each level of the product space, the product over the
        factors of the factors' `phi_product_diag` values at the corresponding
        factor levels.

        :param X:
            An [N, D]-shaped array of points in the product space.
        :param ``**kwargs``:
            Any additional parameters, passed on to the factors.

        :return:
            An [N, L]-shaped array, where `L` is the number of levels.
        """
        Xs = project_product(
            X, self.dimension_indices, self.element_shapes, self.element_dtypes
        )

        phis = B.stack(
            *[
                take_along_axis(
                    eigenfunction.phi_product_diag(X_s, **kwargs),  # [N, L_s]
                    from_numpy(X_s, self.eigenindicies[None, :, s]),
                    -1,
                )  # [N, L]
                for s, (eigenfunction, X_s) in enumerate(zip(self.eigenfunctions, Xs))
            ],
            axis=-1,
        )  # [N, L, S]

        prod_phis = B.prod(phis, axis=-1)  # [N, L]

        return prod_phis

    @property
    def num_eigenfunctions_per_level(self) -> List[int]:
        """
        :return:
            A list of length `L` whose `i`-th entry is the number of
            eigenfunctions in the `i`-th level of the product space, equal to
            the product of the sizes of the corresponding factor levels.
        """
        L = self.eigenindicies.shape[0]
        S = self.eigenindicies.shape[1]

        return [
            math.prod(
                self.nums_per_level[s][self.eigenindicies[cur_level, s]]
                for s in range(S)
            )
            for cur_level in range(L)
        ]
class ProductDiscreteSpectrumSpace(DiscreteSpectrumSpace):
    r"""
    The GeometricKernels space representing a product

    .. math:: \mathcal{S} = \mathcal{S}_1 \times \ldots \times \mathcal{S}_S

    of :class:`~.spaces.DiscreteSpectrumSpace`-s $\mathcal{S}_i$.

    Levels are indexed by tuples of levels of the factors.

    .. admonition:: Precomputing optimal levels

        The eigenvalue corresponding to a level of the product space
        represented by a tuple $(j_1, .., j_S)$ is the sum

        .. math:: \lambda^{(1)}_{j_1} + \ldots + \lambda^{(S)}_{j_S}

        of the eigenvalues $\lambda^{(s)}_{j_s}$ of the factors.

        Computing top `num_levels` smallest eigenvalues is thus a combinatorial
        problem, the solution to which we precompute. To make this
        precomputation possible you need to provide the `num_levels` parameter
        when constructing the :class:`ProductDiscreteSpectrumSpace`, unlike for
        other spaces. What is more, the `num` parameter of the
        :meth:`get_eigenfunctions` cannot be larger than the `num_levels` value.

    .. note::
        See :doc:`this page </theory/product_spaces>` for a brief account of
        the theory behind product spaces and the :doc:`Torus.ipynb
        </examples/Torus>` notebook for a tutorial on how to use them.

        An alternative to using :class:`ProductDiscreteSpectrumSpace` is to use
        the :class:`~.kernels.ProductGeometricKernel` kernel. The latter is more
        flexible, allowing more general spaces as factors and automatic
        relevance determination -like behavior.

    :param spaces:
        The factors, subclasses of :class:`~.spaces.DiscreteSpectrumSpace`.
    :param num_levels:
        The number of levels to pre-compute for this product space.
    :param num_levels_per_space:
        Number of levels to fetch for each of the factor spaces, to compute
        the product-space levels. This is a single number rather than a list,
        because we currently only support fetching the same number of levels
        for all the factor spaces.

        If not given, `num_levels` levels will be fetched for each factor.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please consider
        citing :cite:t:`borovitskiy2020`.
    """

    def __init__(
        self,
        *spaces: DiscreteSpectrumSpace,
        num_levels: int = 25,
        num_levels_per_space: Optional[int] = None,
    ):
        if any(not isinstance(factor, DiscreteSpectrumSpace) for factor in spaces):
            raise ValueError(
                "One of the spaces is not an instance of DiscreteSpectrumSpace."
            )

        self.factor_spaces = spaces  # List of length S
        self.num_levels = num_levels

        # The levels of the product space are found by a breadth-first search
        # for the smallest eigenvalues. Since the eigenvalues of each factor
        # come sorted, the next biggest eigenvalue is always one index step
        # away, in some direction, from the current edge of the search space.

        levels_per_factor = (
            num_levels if num_levels_per_space is None else num_levels_per_space
        )
        if levels_per_factor ** len(spaces) < num_levels:
            raise ValueError(
                "Cannot have more levels than there are possible combinations."
            )

        # prefetch the eigenvalues of the subspaces
        per_factor_eigenvalues = [
            factor.get_eigenvalues(levels_per_factor)[:, 0]
            for factor in self.factor_spaces
        ]
        factor_space_eigenvalues = B.stack(
            *per_factor_eigenvalues, axis=0
        )  # [S, num_levels_per_space]

        self.factor_space_eigenindices = _find_lowest_sum_combinations(
            factor_space_eigenvalues, self.num_levels
        )  # [self.num_levels, S]
        self.factor_space_eigenvalues = factor_space_eigenvalues

        # Eigenvalue of a product level = sum of the selected factor eigenvalues.
        factor_rows = range(len(self.factor_spaces))
        selected_eigenvalues = factor_space_eigenvalues[
            factor_rows,
            self.factor_space_eigenindices,
        ]
        self._eigenvalues = selected_eigenvalues.sum(axis=1)

    def __str__(self):
        factor_names = ", ".join(str(factor) for factor in self.factor_spaces)
        return f"ProductDiscreteSpectrumSpace({factor_names})"

    @property
    def dimension(self) -> int:
        """
        Returns the dimension of the product space, equal to the sum of
        dimensions of the factors.
        """
        return sum(factor.dimension for factor in self.factor_spaces)

    def random(self, key, number):
        """
        Sample random points on the product space by concatenating random points
        on the factor spaces via :func:`~.make_product`.

        The key is threaded through the factors: each factor's `random` is
        called with the key returned by the previous one.

        :param key:
            Either `np.random.RandomState`, `tf.random.Generator`,
            `torch.Generator` or `jax.tensor` (representing random state).
        :param number:
            Number of samples to draw.

        :return:
            A tuple of the updated `key` and an array of `number` random
            samples on the space.
        """
        samples_per_factor = []
        for factor in self.factor_spaces:
            key, factor_samples = factor.random(key, number)
            samples_per_factor.append(factor_samples)

        return key, make_product(samples_per_factor)

    def get_eigenfunctions(self, num: int) -> Eigenfunctions:
        """
        Returns the :class:`~.ProductEigenfunctions` object with `num` levels.

        :param num:
            Number of levels. Cannot be larger than the `num_levels` parameter
            of the constructor.
        """
        if num > self.num_levels:
            raise ValueError(
                "`num` cannot be larger than the `num_levels` provided in the constructor."
            )

        level_indices = self.factor_space_eigenindices[:num]

        # Every factor needs enough levels to cover the largest level index
        # referenced by any of the first `num` product levels.
        max_level = int(self.factor_space_eigenindices[:num, :].max() + 1)
        factor_space_eigenfunctions = [
            factor.get_eigenfunctions(max_level) for factor in self.factor_spaces
        ]

        element_shapes = [factor.element_shape for factor in self.factor_spaces]
        element_dtypes = [factor.element_dtype for factor in self.factor_spaces]

        return ProductEigenfunctions(
            element_shapes,
            element_dtypes,
            level_indices,
            *factor_space_eigenfunctions,
        )

    def get_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the Laplacian corresponding to the first `num` levels.

        :param num:
            Number of levels. Cannot be larger than the `num_levels` parameter
            of the constructor.
        :return:
            (num, 1)-shaped array containing the eigenvalues.
        """
        if num > self.num_levels:
            raise ValueError(
                "`num` cannot be larger than the `num_levels` provided in the constructor."
            )

        return self._eigenvalues[:num, None]

    def get_repeated_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the Laplacian corresponding to the first `num` levels,
        repeated according to their multiplicity within levels.

        :param num:
            Number of levels. Cannot be larger than the `num_levels` parameter
            of the constructor.

        :return:
            (J, 1)-shaped array containing the repeated eigenvalues, `J` is
            the resulting number of the repeated eigenvalues.
        """
        if num > self.num_levels:
            raise ValueError(
                "`num` cannot be larger than the `num_levels` provided in the constructor."
            )

        multiplicities = self.get_eigenfunctions(num).num_eigenfunctions_per_level
        repeated_eigenvalues = chain(self._eigenvalues[:num], multiplicities)
        return B.reshape(repeated_eigenvalues, -1, 1)  # [J, 1]

    @property
    def element_shape(self):
        """
        :return:
            Sum of the products of the element shapes of the factor spaces.
        """
        flattened_sizes = (
            math.prod(factor.element_shape) for factor in self.factor_spaces
        )
        return sum(flattened_sizes)

    @property
    def element_dtype(self):
        """
        :return:
            B.Numeric.
        """
        return B.Numeric