Coverage for geometric_kernels/feature_maps/rejection_sampling.py: 100%
40 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
"""
This module provides the :class:`RejectionSamplingFeatureMapHyperbolic` and the
:class:`RejectionSamplingFeatureMapSPD`, rejection sampling-based feature maps
for :class:`~.spaces.Hyperbolic` and
:class:`~.spaces.SymmetricPositiveDefiniteMatrices`, respectively.
"""
import lab as B
from beartype.typing import Dict, Tuple

from geometric_kernels.feature_maps.base import FeatureMap
from geometric_kernels.feature_maps.probability_densities import (
    hyperbolic_density_sample,
    spd_density_sample,
)
from geometric_kernels.lab_extras import from_numpy
from geometric_kernels.spaces import Hyperbolic, SymmetricPositiveDefiniteMatrices
class RejectionSamplingFeatureMapHyperbolic(FeatureMap):
    """
    Random phase feature map for the :class:`~.spaces.Hyperbolic` space based
    on the rejection sampling algorithm.

    :param space:
        A :class:`~.spaces.Hyperbolic` space.
    :param num_random_phases:
        Number of random phases to use.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.
    """

    def __init__(
        self,
        space: Hyperbolic,
        num_random_phases: int = 3000,
        shifted_laplacian: bool = True,
    ):
        self.space = space
        self.num_random_phases = num_random_phases
        self.shifted_laplacian = shifted_laplacian

    def __call__(
        self,
        X: B.Numeric,
        params: Dict[str, B.Numeric],
        *,
        key: B.RandomState,
        normalize: bool = True,
        **kwargs,
    ) -> Tuple[B.RandomState, B.Numeric]:
        """
        :param X:
            [N, D] points in the space to evaluate the map on.
        :param params:
            Parameters of the feature map (length scale and smoothness).
        :param key:
            Random state, either `np.random.RandomState`,
            `tf.random.Generator`, `torch.Generator` or `jax.tensor` (which
            represents a random state).

            .. note::
                For any backend other than `jax`, passing the same `key` twice
                does not guarantee that the feature map will be the same each
                time. This is because these backends' random state has... a
                state. To evaluate the same (including randomness) feature map
                on different inputs, you can either save/restore state manually
                each time or use the helper function
                :func:`~.utils.make_deterministic` which does this for you.

        :param normalize:
            Normalize to have unit average variance (`True` by default).
        :param ``**kwargs``:
            Unused.

        :return:
            `Tuple(key, features)` where `features` is an [N, O] array, N
            is the number of inputs and O is the dimension of the feature map;
            `key` is the updated random key for `jax`, or the similar random
            state (generator) for any other backends.
        """
        # Sample random phases on the space and eigenvalues ("lambdas") via
        # rejection sampling from the appropriate spectral density.
        key, phases = self.space.random_phases(key, self.num_random_phases)  # [O, D]

        key, lambdas = hyperbolic_density_sample(
            key,
            (self.num_random_phases, B.rank(self.space.rho)),
            params,
            self.space.dimension,
            self.shifted_laplacian,
        )  # [O, 1]

        # Cast the samples to the parameters' dtype and add a leading batch
        # axis so they broadcast against the inputs.
        dtype = B.dtype(params["lengthscale"])
        phases_b = B.expand_dims(B.cast(dtype, from_numpy(X, phases)))  # [1, O, D]
        lambdas_b = B.expand_dims(B.cast(dtype, from_numpy(X, lambdas)))  # [1, O, 1]
        X_b = B.expand_dims(X, axis=-2)  # [N, 1, D]

        # Evaluate the (complex-valued) power function and stack its real and
        # imaginary parts to form real features.
        p = self.space.power_function(lambdas_b, X_b, phases_b)  # [N, O]
        features = B.concat(B.real(p), B.imag(p), axis=-1)  # [N, 2*O]

        if normalize:
            # Divide each row by its Euclidean norm.
            features = features / B.sqrt(
                B.sum(features**2, axis=-1, squeeze=False)
            )

        return key, features
class RejectionSamplingFeatureMapSPD(FeatureMap):
    """
    Random phase feature map for the
    :class:`~.spaces.SymmetricPositiveDefiniteMatrices` space based on the
    rejection sampling algorithm.

    :param space:
        A :class:`~.spaces.SymmetricPositiveDefiniteMatrices` space.
    :param num_random_phases:
        Number of random phases to use.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.
    """

    def __init__(
        self,
        space: SymmetricPositiveDefiniteMatrices,
        num_random_phases: int = 3000,
        shifted_laplacian: bool = True,
    ):
        self.space = space
        self.num_random_phases = num_random_phases
        self.shifted_laplacian = shifted_laplacian

    def __call__(
        self,
        X: B.Numeric,
        params: Dict[str, B.Numeric],
        *,
        key: B.RandomState,
        normalize: bool = True,
        **kwargs,
    ) -> Tuple[B.RandomState, B.Numeric]:
        """
        :param X:
            [N, D, D] points in the space to evaluate the map on.
        :param params:
            Parameters of the feature map (length scale and smoothness).
        :param key:
            Random state, either `np.random.RandomState`,
            `tf.random.Generator`, `torch.Generator` or `jax.tensor` (which
            represents a random state).

            .. note::
                For any backend other than `jax`, passing the same `key` twice
                does not guarantee that the feature map will be the same each
                time. This is because these backends' random state has... a
                state. To evaluate the same (including randomness) feature map
                on different inputs, you can either save/restore state manually
                each time or use the helper function
                :func:`~.utils.make_deterministic` which does this for you.

        :param normalize:
            Normalize to have unit average variance (`True` by default).
        :param ``**kwargs``:
            Unused.

        :return:
            `Tuple(key, features)` where `features` is an [N, O] array, N
            is the number of inputs and O is the dimension of the feature map;
            `key` is the updated random key for `jax`, or the similar random
            state (generator) for any other backends.
        """
        # Sample random phases on the space and eigenvalues ("lambdas") via
        # rejection sampling from the appropriate spectral density.
        key, phases = self.space.random_phases(key, self.num_random_phases)  # [O, D, D]

        key, lambdas = spd_density_sample(
            key,
            (self.num_random_phases,),
            params,
            self.space.degree,
            self.space.rho,
            self.shifted_laplacian,
        )  # [O, D]

        # Cast the samples to the parameters' dtype and add a leading batch
        # axis so they broadcast against the inputs.
        dtype = B.dtype(params["lengthscale"])
        phases_b = B.expand_dims(B.cast(dtype, from_numpy(X, phases)))  # [1, O, D, D]
        lambdas_b = B.expand_dims(B.cast(dtype, from_numpy(X, lambdas)))  # [1, O, D]
        X_b = B.expand_dims(X, axis=-3)  # [N, 1, D, D]

        # Evaluate the (complex-valued) power function and stack its real and
        # imaginary parts to form real features.
        p = self.space.power_function(lambdas_b, X_b, phases_b)  # [N, O]
        features = B.concat(B.real(p), B.imag(p), axis=-1)  # [N, 2*O]

        if normalize:
            # Divide each row by its Euclidean norm.
            features = features / B.sqrt(
                B.sum(features**2, axis=-1, squeeze=False)
            )

        return key, features