Coverage for geometric_kernels/sampling/samplers.py: 92% (25 statements)
« prev ^ index » next — coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1"""
2Samplers.
3"""
5from __future__ import annotations # By https://stackoverflow.com/a/62136491
7from functools import partial
9import lab as B
10from beartype.typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
12# By https://stackoverflow.com/a/62136491
13if TYPE_CHECKING:
14 from geometric_kernels.feature_maps import FeatureMap
def sample_at(
    feature_map: FeatureMap,
    s: int,
    X: B.Numeric,
    params: Dict[str, B.Numeric],
    key: Optional[B.RandomState] = None,
    normalize: Optional[bool] = None,
    hodge_type: Optional[str] = None,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Given a `feature_map` $\phi_{\nu, \kappa}: X \to \mathbb{R}^n$, where
    $\nu, \kappa$ are determined by `params["nu"]` and `params["lengthscale"]`,
    respectively, compute `s` samples of the Gaussian process with kernel

    .. math :: k_{\nu, \kappa}(x, y) = \langle \phi_{\nu, \kappa}(x), \phi_{\nu, \kappa}(y) \rangle_{\mathbb{R}^n}

    at input locations `X` and using the random state `key`.

    Generating a sample from $GP(0, k_{\nu, \kappa})$ is as simple as computing

    .. math :: \sum_{j=1}^n w_j \cdot (\phi_{\nu, \kappa}(x))_j \qquad w_j \stackrel{IID}{\sim} N(0, 1).

    .. note::
        Fixing $w_j$, and treating $x \to (\phi_{\nu, \kappa}(x))_j$ as basis
        functions, while letting $x$ vary, you get an actual *function* as a
        sample, meaning something that can be evaluated at any $x \in X$.
        The way to fix $w_j$ in code is to apply :func:`~.make_deterministic`
        utility function.

    :param feature_map:
        The feature map $\phi_{\nu, \kappa}$ that defines the Gaussian process
        $GP(0, k_{\nu, \kappa})$ to sample from.
    :param s:
        The number of samples to generate.
    :param X:
        An [N, <axis>]-shaped array containing N elements of the space
        `feature_map` is defined on. <axis> is the shape of these elements.
        These are the points to evaluate the samples at.
    :param params:
        Parameters of the kernel (length scale and smoothness).
    :param key: random state, either `np.random.RandomState`,
        `tf.random.Generator`, `torch.Generator` or `jax.tensor` (which
        represents a random state). If None, the global random state of the
        backend of `X` is used.
    :param normalize:
        Passed down to `feature_map` directly. Controls whether to force the
        average variance of the Gaussian process to be around 1 or not. If None,
        follows the standard behavior, typically same as normalize=True.

        Defaults to None.
    :param hodge_type:
        The type of Hodge component to sample. Only used when using
        Hodge-compositional edge kernels.

        Defaults to None.

    :return:
        [N, s]-shaped array containing s samples of the $GP(0, k_{\nu, \kappa})$
        evaluated at X.
    """
    if key is None:
        key = B.global_random_state(B.dtype(X))

    # Only pass hodge_type when explicitly requested: plain feature maps do
    # not accept this keyword argument.
    if hodge_type is None:
        _context, features = feature_map(
            X, params, key=key, normalize=normalize
        )  # [N, M]
    else:
        _context, features = feature_map(
            X, params, key=key, normalize=normalize, hodge_type=hodge_type
        )

    # Feature maps that consume randomness return the advanced random state.
    if _context is not None:
        key = _context

    num_features = B.shape(features)[-1]

    # Infer the backend dtype from the lengthscale parameter; Hodge-style
    # parameter dicts nest it under the "gradient" sub-dictionary.
    if "lengthscale" in params:
        dtype = B.dtype(params["lengthscale"])
    else:
        dtype = B.dtype(params["gradient"]["lengthscale"])

    key, random_weights = B.randn(key, dtype, num_features, s)  # [M, S]

    random_sample = B.matmul(features, random_weights)  # [N, S]

    return key, random_sample
def sampler(
    feature_map: FeatureMap, s: Optional[int] = 1, **kwargs
) -> Callable[[Any], Any]:
    """
    Build a specialization of :func:`sample_at` in which `feature_map`, `s`
    and the keyword arguments in ``**kwargs`` are frozen, while `X`, `params`
    and any remaining keyword arguments stay free.

    :param feature_map:
        The feature map to fix.
    :param s:
        The number of samples parameter to fix.

        Defaults to 1.
    :param ``**kwargs``:
        Keyword arguments to fix.

    :return:
        The version of :func:`sample_at` with parameters `feature_map` and `s`
        fixed, together with the keyword arguments in ``**kwargs``.
    """
    fixed_sampler = partial(sample_at, feature_map, s, **kwargs)
    # Carry over the wrapped function's name and attach an augmented
    # docstring, so that interactive help on the returned callable still
    # describes `sample_at` and shows which arguments were fixed.
    fixed_sampler.__name__ = sample_at.__name__  # type: ignore[attr-defined]
    fixed_sampler.__doc__ = f"""
    This is a version of the `{sample_at.__name__}` function with
    - feature_map={feature_map},
    - s={s},
    - and additional keyword arguments {kwargs}.

    The original docstring follows.

    {sample_at.__doc__}
    """
    return fixed_sampler