Coverage for geometric_kernels/kernels/base.py: 79%
19 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
"""
This module provides the :class:`BaseGeometricKernel` kernel, the base class
for all geometric kernels.
"""
import abc

import lab as B
from beartype.typing import Dict, List, Optional, Union

from geometric_kernels.spaces import Space
class BaseGeometricKernel(abc.ABC):
    """
    Abstract base class for geometric kernels.

    Subclasses must implement :meth:`init_params`, :meth:`K` and
    :meth:`K_diag`; because these are abstract, this class cannot be
    instantiated directly.

    :param space:
        The space on which the kernel is defined.
    """

    def __init__(self, space: Space):
        self._space = space

    @property
    def space(self) -> Union[Space, List[Space]]:
        """
        The space on which the kernel is defined.

        Returns the value stored in `self._space`, which the constructor
        sets to its `space` argument.

        .. note::
            NOTE(review): the return annotation also admits a list of
            spaces, while the constructor is annotated with a single
            space. Presumably some subclasses store several spaces;
            confirm against the subclasses.
        """
        return self._space

    @abc.abstractmethod
    def init_params(self) -> Dict[str, B.NPNumeric]:
        """
        Initializes the dict of the trainable parameters of the kernel.

        It typically contains only two keys: `"nu"` and `"lengthscale"`.

        This dict can be modified and is passed around into such methods as
        :meth:`~.K` or :meth:`~.K_diag`, as the `params` argument.

        .. note::
            The values in the returned dict are always of the NumPy array type.
            Thus, if you want to use some other backend for internal
            computations when calling :meth:`~.K` or :meth:`~.K_diag`, you
            need to replace the values with the analogs typed as arrays of
            the desired backend.

        :return:
            A dict mapping parameter names to their initial values.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def K(
        self,
        params: Dict[str, B.Numeric],
        X: B.Numeric,
        X2: Optional[B.Numeric] = None,
        **kwargs,
    ) -> B.Numeric:
        """
        Compute the cross-covariance matrix between two batches of vectors of
        inputs, or batches of matrices of inputs, depending on the space.

        :param params:
            A dict of kernel parameters, typically containing two keys:
            `"lengthscale"` for length scale and `"nu"` for smoothness.

            The types of values in the params dict determine the output type
            and the backend used for the internal computations, see the
            warning below for more details.

            .. note::
                The values `params["lengthscale"]` and `params["nu"]` are
                typically (1,)-shaped arrays of the suitable backend. This
                serves to point at the backend to be used for internal
                computations.

                In some cases, for example, when the kernel is
                :class:`~.kernels.ProductGeometricKernel`, the values of
                `params` may be (s,)-shaped arrays instead, where `s` is the
                number of factors.

            .. note::
                Finite values of `params["nu"]` typically correspond to the
                generalized (geometric) Matérn kernels.

                Infinite `params["nu"]` typically corresponds to the heat
                kernel (a.k.a. diffusion kernel, generalized squared
                exponential kernel, generalized Gaussian kernel,
                generalized RBF kernel). Although it is often considered to be
                a separate entity, we treat the heat kernel as a member of
                the Matérn family, with smoothness parameter equal to infinity.

        :param X:
            A batch of N inputs, each of which is a vector or a matrix,
            depending on how the elements of the `self.space` are represented.
        :param X2:
            A batch of M inputs, each of which is a vector or a matrix,
            depending on how the elements of the `self.space` are represented.

            `X2=None` sets `X2=X`.

            Defaults to None.
        :param kwargs:
            Additional keyword arguments; their meaning is left to the
            concrete subclass.

        :return:
            The N x M cross-covariance matrix.

        .. warning::
            The types of values in the `params` dict determine the backend
            used for internal computations and the output type.

            Even if, say, `geometric_kernels.jax` is imported but the values in
            the `params` dict are NumPy arrays, the output type will be a NumPy
            array, and NumPy will be used for internal computations. To get a
            JAX array as an output and use JAX for internal computations, all
            the values in the `params` dict must be JAX arrays.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def K_diag(self, params: Dict[str, B.Numeric], X: B.Numeric, **kwargs) -> B.Numeric:
        """
        Returns the diagonal of the covariance matrix `self.K(params, X, X)`,
        typically in a more efficient way than actually computing the full
        covariance matrix with `self.K(params, X, X)` and then extracting its
        diagonal.

        :param params:
            Same as for :meth:`~.K`.

        :param X:
            A batch of N inputs, each of which is a vector or a matrix,
            depending on how the elements of the `self.space` are represented.
        :param kwargs:
            Additional keyword arguments; their meaning is left to the
            concrete subclass.

        :return:
            The N-dimensional vector representing the diagonal of the
            covariance matrix `self.K(params, X, X)`.
        """
        raise NotImplementedError