Coverage for geometric_kernels/kernels/product.py: 55%
60 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
"""
This module provides the :class:`ProductGeometricKernel` kernel for
constructing product kernels from a sequence of kernels.

See :doc:`this page </theory/product_kernels>` for a brief account of the
theory behind product kernels and the :doc:`Torus.ipynb </examples/Torus>`
notebook for a tutorial on how to use them.
"""
10import math
12import lab as B
13from beartype.typing import Dict, List, Optional
15from geometric_kernels.kernels.base import BaseGeometricKernel
16from geometric_kernels.spaces import Space
17from geometric_kernels.utils.product import params_to_params_list, project_product
class ProductGeometricKernel(BaseGeometricKernel):
    r"""
    Product kernel, defined as the product of a sequence of kernels.

    See :doc:`this page </theory/product_kernels>` for a brief account of the
    theory behind product kernels and the :doc:`Torus.ipynb </examples/Torus>`
    notebook for a tutorial on how to use them.

    :param ``*kernels``:
        A sequence of kernels to compute the product of. Cannot contain another
        instance of :class:`ProductGeometricKernel`. We denote the number of
        factors, i.e. the length of the "sequence", by s.
    :param dimension_indices:
        Determines how a product kernel input vector `x` is to be mapped into
        the inputs `xi` for the factor kernels. `xi` are assumed to be equal to
        `x[dimension_indices[i]]`, possibly up to a reshape. Such a reshape
        might be necessary to accommodate the spaces whose elements are matrices
        rather than vectors, as determined by `element_shapes`. The
        transformation of `x` into the list of `xi`\ s is performed
        by :func:`~.project_product`.

        If None, assumes that each input is laid out flattened and
        concatenated, in the same order as the factor spaces. In this case,
        the inverse to :func:`~.project_product` is :func:`~.make_product`.

        Defaults to None.

    :raises NotImplementedError:
        If the `space` of one of the `kernels` is not a single
        :class:`~.Space` instance (e.g. it is a list of spaces, as for a
        :class:`ProductGeometricKernel`).
    :raises ValueError:
        If `dimension_indices` is given but its length differs from the number
        of kernels, or if it contains a negative index.

    .. note::
        `params` of a :class:`ProductGeometricKernel` are such that
        `params["lengthscale"]` and `params["nu"]` are (s,)-shaped arrays, where
        `s` is the number of factors.

        Basically, `params["lengthscale"][i]` stores the length scale parameter
        for the `i`-th factor kernel. Same goes for `params["nu"]`. Importantly,
        this enables *automatic relevance determination*-like behavior.
    """

    def __init__(
        self,
        *kernels: BaseGeometricKernel,
        dimension_indices: Optional[List[List[int]]] = None,
    ):
        self.kernels = kernels
        self.spaces: List[Space] = []
        for kernel in self.kernels:
            # Nested product kernels are not supported: a product kernel's
            # `space` is a List[Space] rather than a single Space.
            if not isinstance(kernel.space, Space):
                raise NotImplementedError(
                    "One of the provided kernels is a product kernel itself."
                )
            self.spaces.append(kernel.space)
        self.element_shapes = [space.element_shape for space in self.spaces]
        self.element_dtypes = [space.element_dtype for space in self.spaces]

        if dimension_indices is None:
            # Default layout: each factor's (flattened) input occupies a
            # contiguous run of prod(element_shape) columns, placed right
            # after the previous factor's run.
            self.dimension_indices: List[List[int]] = []
            start = 0
            for shape in self.element_shapes:
                dim = math.prod(shape)
                self.dimension_indices.append(list(range(start, start + dim)))
                start += dim
        else:
            if len(dimension_indices) != len(self.kernels):
                raise ValueError(
                    f"`dimension_indices` must correspond to `kernels`, but got {len(kernels)} kernels and {len(dimension_indices)} dimension indices."
                )
            if any(idx < 0 for idx_list in dimension_indices for idx in idx_list):
                raise ValueError(
                    "Expected all `dimension_indices` to be non-negative."
                )

            self.dimension_indices = dimension_indices

    @property
    def space(self) -> List[Space]:
        """
        The list of spaces upon which the factor kernels are defined.
        """
        return self.spaces

    def init_params(self) -> Dict[str, B.NPNumeric]:
        r"""
        Returns a dict `params` where `params["lengthscale"]` is the
        concatenation of all `self.kernels[i].init_params()["lengthscale"]` and
        same for `params["nu"]`.

        :raises ValueError:
            If some factor kernel's initial `lengthscale` or `nu` does not
            have shape (1,).
        """
        nu_list: List[B.NPNumeric] = []
        lengthscale_list: List[B.NPNumeric] = []

        for kernel_idx, kernel in enumerate(self.kernels):
            cur_params = kernel.init_params()
            # Each factor must contribute exactly one entry so that the
            # concatenated arrays have shape (s,).
            if B.shape(cur_params["lengthscale"]) != (1,):
                raise ValueError(
                    f"All kernels' `lengthscale`s must be have shape [1,], but {kernel_idx}th kernel "
                    f"({kernel}) violates this with shape {B.shape(cur_params['lengthscale'])}."
                )
            if B.shape(cur_params["nu"]) != (1,):
                raise ValueError(
                    f"All kernels' `nu`s must be have [1,], but {kernel_idx}th kernel "
                    f"({kernel}) violates this with shape {B.shape(cur_params['nu'])}."
                )

            nu_list.append(cur_params["nu"])
            lengthscale_list.append(cur_params["lengthscale"])

        params: Dict[str, B.NPNumeric] = {}
        params["nu"] = B.concat(*nu_list)
        params["lengthscale"] = B.concat(*lengthscale_list)
        return params

    def K(self, params: Dict[str, B.Numeric], X, X2=None, **kwargs) -> B.Numeric:
        """
        Compute the product kernel matrix between `X` and `X2`.

        The inputs are split into per-factor inputs by
        :func:`~.project_product`; `params` are split by
        :func:`~.params_to_params_list`. The result is the elementwise product
        of the factor kernels' `K` matrices.

        :param params: Parameters of the product kernel (see the class note).
        :param X: Inputs laid out according to `self.dimension_indices`.
        :param X2: Second set of inputs; if None, `X` is used.
        :param ``**kwargs``: Accepted for interface compatibility; not
            forwarded to the factor kernels.
        :return: The elementwise product of the factor kernel matrices.
        """
        Xs = project_product(
            X, self.dimension_indices, self.element_shapes, self.element_dtypes
        )
        if X2 is None:
            # Same inputs on both sides: reuse the projection rather than
            # computing it a second time.
            X2s = Xs
        else:
            X2s = project_product(
                X2, self.dimension_indices, self.element_shapes, self.element_dtypes
            )
        params_list = params_to_params_list(len(self.kernels), params)

        factor_matrices = [
            kernel.K(p, x1, x2)
            for kernel, x1, x2, p in zip(self.kernels, Xs, X2s, params_list)
        ]
        # Stack factors along a new trailing axis and multiply over it.
        return B.prod(B.stack(*factor_matrices, axis=-1), axis=-1)

    def K_diag(self, params: Dict[str, B.Numeric], X, **kwargs) -> B.Numeric:
        """
        Compute the diagonal of the product kernel matrix at `X`.

        This is the elementwise product of the factor kernels' `K_diag`
        outputs, with inputs and params split as in :meth:`K`.

        :param params: Parameters of the product kernel (see the class note).
        :param X: Inputs laid out according to `self.dimension_indices`.
        :param ``**kwargs``: Accepted for consistency with :meth:`K`; not
            forwarded to the factor kernels.
        :return: The elementwise product of the factor kernel diagonals.
        """
        Xs = project_product(
            X, self.dimension_indices, self.element_shapes, self.element_dtypes
        )
        params_list = params_to_params_list(len(self.kernels), params)

        factor_diags = [
            kernel.K_diag(p, x) for kernel, x, p in zip(self.kernels, Xs, params_list)
        ]
        # Stack factors along a new trailing axis and multiply over it.
        return B.prod(B.stack(*factor_diags, axis=-1), axis=-1)