Coverage for geometric_kernels/feature_maps/random_phase.py: 70%
50 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
r"""
This module provides the random phase-based feature maps.

Specifically, it provides a random phase-based feature map for
:class:`~.spaces.DiscreteSpectrumSpace`\ s for which the
:doc:`addition theorem </theory/addition_theorem>`-like basis functions
are explicitly available while the actual eigenpairs may remain implicit.

It also provides a basic random phase-based feature map for
:class:`~.spaces.NoncompactSymmetricSpace`\ s. It should be used unless a more
specialized per-space implementation is available, like the ones in the module
:mod:`geometric_kernels.feature_maps.rejection_sampling`.
"""
15import lab as B
16from beartype.typing import Dict, Tuple
18from geometric_kernels.feature_maps.base import FeatureMap
19from geometric_kernels.feature_maps.probability_densities import base_density_sample
20from geometric_kernels.lab_extras import complex_like, from_numpy, is_complex
21from geometric_kernels.spaces import DiscreteSpectrumSpace, NoncompactSymmetricSpace
class RandomPhaseFeatureMapCompact(FeatureMap):
    r"""
    Random phase feature map for :class:`~.spaces.DiscreteSpectrumSpace`\ s for
    which the :doc:`addition theorem </theory/addition_theorem>`-like basis
    functions are explicitly available while actual eigenpairs may be implicit.

    :param space:
        A :class:`~.spaces.DiscreteSpectrumSpace` space.
    :param num_levels:
        Number of levels in the kernel approximation.
    :param num_random_phases:
        Number of random phases used in the generalized
        random phase Fourier features technique. Defaults to 3000.
    """

    def __init__(
        self,
        space: DiscreteSpectrumSpace,
        num_levels: int,
        num_random_phases: int = 3000,
    ):
        self.space = space
        self.num_levels = num_levels
        self.num_random_phases = num_random_phases
        # Fetched once at construction and reused by every call below.
        self.eigenfunctions = space.get_eigenfunctions(num_levels)

    def __call__(
        self,
        X: B.Numeric,
        params: Dict[str, B.Numeric],
        *,
        key: B.RandomState,
        normalize: bool = True,
        **kwargs,
    ) -> Tuple[B.RandomState, B.Numeric]:
        """
        Evaluate the random phase feature map on the points `X`.

        :param X:
            [N, ...] points in the space to evaluate the map on.

        :param params:
            Parameters of the kernel (length scale and smoothness). The keys
            ``"nu"`` and ``"lengthscale"`` are read directly; the whole dict
            is also forwarded as keyword arguments to the eigenfunctions'
            ``phi_product``.

        :param key:
            Random state, either `np.random.RandomState`,
            `tf.random.Generator`, `torch.Generator` or `jax.tensor` (which
            represents a random state).

            .. note::
                For any backend other than `jax`, passing the same `key` twice
                does not guarantee that the feature map will be the same each
                time. This is because these backends' random state has... a
                state. To evaluate the same (including randomness) feature map
                on different inputs, you can either save/restore state manually
                each time or use the helper function
                :func:`~.utils.make_deterministic` which
                does this for you.

        :param normalize:
            If True (the default), each row of the feature matrix is divided
            by its own Euclidean norm. The approximate kernel value
            <features(x), features(x)> then equals 1 for every input x, so
            every point has unit variance.

        :param ``**kwargs``:
            Unused.

        :return:
            `Tuple(key, features)`, where `features` is an [N, F] array,
            N is the number of inputs and F is the dimension of the feature
            map. With O = ``num_random_phases`` and L the number of levels,
            F = O*L for real-valued `X`. For complex-valued `X`, F = 2*O*L,
            because the real and imaginary parts are concatenated. `key` is
            the updated random key for `jax`, or the similar random state
            (generator) for any other backend.
        """
        # Local import, presumably to avoid a circular import between the
        # feature_maps and kernels packages — TODO confirm.
        from geometric_kernels.kernels.karhunen_loeve import MaternKarhunenLoeveKernel

        key, random_phases = self.space.random(key, self.num_random_phases)  # [O, D]
        eigenvalues = self.space.get_eigenvalues(self.num_levels)

        spectrum = MaternKarhunenLoeveKernel.spectrum(
            eigenvalues,
            nu=params["nu"],
            lengthscale=params["lengthscale"],
            dimension=self.space.dimension,
        )

        # Complex inputs get a complex working dtype matching the precision
        # of the lengthscale; real inputs use the lengthscale's dtype directly.
        if is_complex(X):
            dtype = complex_like(params["lengthscale"])
        else:
            dtype = B.dtype(params["lengthscale"])

        # Square root of the spectrum, so that the inner product of two
        # feature vectors carries the spectrum itself as per-level weights.
        weights = B.power(spectrum, 0.5)  # [L, 1]

        # from_numpy(X, ...) presumably moves the sampled phases onto X's
        # backend — confirm against lab_extras.
        random_phases_b = B.cast(dtype, from_numpy(X, random_phases))

        phi_product = self.eigenfunctions.phi_product(
            X, random_phases_b, **params
        )  # [N, O, L]

        embedding = B.cast(dtype, phi_product)  # [N, O, L]
        weights_t = B.cast(dtype, B.transpose(weights))  # [1, L]

        # Weight each level, then flatten the (phase, level) axes.
        features = B.reshape(embedding * weights_t, B.shape(X)[0], -1)  # [N, O*L]
        if is_complex(features):
            # Split into real and imaginary parts so the output is real.
            features = B.concat(B.real(features), B.imag(features), axis=1)

        if normalize:
            # Per-row normalization: every point ends up with unit variance.
            normalizer = B.sqrt(B.sum(features**2, axis=-1, squeeze=False))
            features = features / normalizer

        return key, features
class RandomPhaseFeatureMapNoncompact(FeatureMap):
    r"""
    Basic random phase feature map for
    :class:`~.spaces.NoncompactSymmetricSpace`\ s (importance sampling based).

    This feature map should not be used if a space-specific alternative exists.

    :param space:
        A :class:`~.spaces.NoncompactSymmetricSpace` space.
    :param num_random_phases:
        Number of random phases to use. Defaults to 3000.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.
    """

    def __init__(
        self,
        space: NoncompactSymmetricSpace,
        num_random_phases: int = 3000,
        shifted_laplacian: bool = True,
    ):
        self.space = space
        self.num_random_phases = num_random_phases
        self.shifted_laplacian = shifted_laplacian

    def __call__(
        self,
        X: B.Numeric,
        params: Dict[str, B.Numeric],
        *,
        key: B.RandomState,
        normalize: bool = True,
        **kwargs,
    ) -> Tuple[B.RandomState, B.Numeric]:
        """
        Evaluate the random phase feature map on the points `X`.

        :param X:
            [N, ...] points in the space to evaluate the map on.

        :param params:
            Parameters of the kernel (length scale and smoothness). The whole
            dict is passed to ``base_density_sample``; ``"lengthscale"`` is
            also read here to determine the working dtype.

        :param key:
            Random state, either `np.random.RandomState`,
            `tf.random.Generator`, `torch.Generator` or `jax.tensor` (which
            represents a random state).

            .. note::
                For any backend other than `jax`, passing the same `key` twice
                does not guarantee that the feature map will be the same each
                time. This is because these backends' random state has... a
                state. To evaluate the same (including randomness) feature map
                on different inputs, you can either save/restore state manually
                each time or use the helper function
                :func:`~.utils.make_deterministic` which
                does this for you.

        :param normalize:
            If True (the default), each row of the feature matrix is divided
            by its own Euclidean norm. The approximate kernel value
            <features(x), features(x)> then equals 1 for every input x, so
            every point has unit variance.

        :param ``**kwargs``:
            Unused.

        :return:
            `Tuple(key, features)`, where `features` is an [N, 2*O] array,
            N is the number of inputs and O = ``num_random_phases``. The
            real and imaginary parts of the O complex features are
            concatenated. `key` is the updated random key for `jax`, or the
            similar random state (generator) for any other backend.
        """
        key, random_phases = self.space.random_phases(
            key, self.num_random_phases
        )  # [O, <axes_p>]

        # One spectral sample per random phase, drawn from the base density.
        key, random_lambda = base_density_sample(
            key,
            (self.num_random_phases,),  # [O, 1]
            params,
            self.space.dimension,
            self.space.rho,
            self.shifted_laplacian,
        )  # [O, P]

        # Move the samples onto X's backend (presumably, via from_numpy —
        # confirm), cast to the lengthscale dtype and prepend a broadcast axis.
        random_phases_b = B.expand_dims(
            B.cast(B.dtype(params["lengthscale"]), from_numpy(X, random_phases))
        )  # [1, O, <axes_p>]
        random_lambda_b = B.expand_dims(
            B.cast(B.dtype(params["lengthscale"]), from_numpy(X, random_lambda))
        )  # [1, O, P]
        # Insert a phase axis in front of the per-point axes of X.
        X_b = B.expand_dims(X, axis=-1 - self.space.num_axes)  # [N, 1, <axes>]

        p = self.space.power_function(random_lambda_b, X_b, random_phases_b)  # [N, O]
        # Per-sample weight computed by the space from the spectral samples.
        c = self.space.inv_harish_chandra(random_lambda_b)  # [1, O]

        features = B.concat(B.real(p) * c, B.imag(p) * c, axis=-1)  # [N, 2*O]
        if normalize:
            # Per-row normalization: every point ends up with unit variance.
            normalizer = B.sqrt(B.sum(features**2, axis=-1, squeeze=False))
            features = features / normalizer

        return key, features