Source code for geometric_kernels.sampling.samplers

"""
Samplers.
"""

from __future__ import annotations  # By https://stackoverflow.com/a/62136491

from functools import partial

import lab as B
from beartype.typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

# By https://stackoverflow.com/a/62136491
if TYPE_CHECKING:
    from geometric_kernels.feature_maps import FeatureMap



def sample_at(
    feature_map: FeatureMap,
    s: int,
    X: B.Numeric,
    params: Dict[str, B.Numeric],
    key: B.RandomState = None,
    normalize: bool = None,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Given a `feature_map` $\phi_{\nu, \kappa}: X \to \mathbb{R}^n$, where
    $\nu, \kappa$ are determined by `params["nu"]` and
    `params["lengthscale"]`, respectively, compute `s` samples of the
    Gaussian process with kernel

    .. math ::
        k_{\nu, \kappa}(x, y) = \langle \phi_{\nu, \kappa}(x), \phi_{\nu, \kappa}(y) \rangle_{\mathbb{R}^n}

    at input locations `X` and using the random state `key`.

    Generating a sample from $GP(0, k_{\nu, \kappa})$ is as simple as computing

    .. math ::
        \sum_{j=1}^n w_j \cdot (\phi_{\nu, \kappa}(x))_j,
        \qquad
        w_j \stackrel{IID}{\sim} N(0, 1).

    .. note::
        Fixing the $w_j$, and treating $x \to (\phi_{\nu, \kappa}(x))_j$ as
        basis functions while letting $x$ vary, you get an actual *function*
        as a sample, i.e. something that can be evaluated at any $x \in X$.
        The way to fix the $w_j$ in code is to apply the
        :func:`~.make_deterministic` utility function.

    :param feature_map:
        The feature map $\phi_{\nu, \kappa}$ that defines the Gaussian
        process $GP(0, k_{\nu, \kappa})$ to sample from.
    :param s:
        The number of samples to generate.
    :param X:
        An [N, <axis>]-shaped array containing N elements of the space
        `feature_map` is defined on. <axis> is the shape of these elements.
        These are the points to evaluate the samples at.
    :param params:
        Parameters of the kernel (length scale and smoothness).
    :param key:
        Random state, either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (which represents a random state).
    :param normalize:
        Passed down to `feature_map` directly. Controls whether to force the
        average variance of the Gaussian process to be around 1 or not. If
        None, follows the standard behavior, typically the same as
        normalize=True. Defaults to None.

    :return:
        [N, s]-shaped array containing s samples of the
        $GP(0, k_{\nu, \kappa})$ evaluated at X.
    """
    if key is None:
        key = B.global_random_state(B.dtype(X))

    _context, features = feature_map(X, params, key=key, normalize=normalize)  # [N, M]

    if _context is not None:
        key = _context

    num_features = B.shape(features)[-1]

    key, random_weights = B.randn(
        key, B.dtype(params["lengthscale"]), num_features, s
    )  # [M, S]

    random_sample = B.matmul(features, random_weights)  # [N, S]

    return key, random_sample
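
# A minimal usage sketch for `sample_at` with the NumPy backend. The imports
# of `Hypersphere`, `MaternGeometricKernel` and `default_feature_map`, and the
# `kernel.init_params()` call, are assumptions about the public
# geometric_kernels API and may differ between versions; only `sample_at`
# itself is defined in this module.
#
#   import numpy as np
#
#   from geometric_kernels.spaces import Hypersphere
#   from geometric_kernels.kernels import MaternGeometricKernel, default_feature_map
#
#   space = Hypersphere(dim=2)
#   kernel = MaternGeometricKernel(space)
#   feature_map = default_feature_map(kernel=kernel)
#
#   params = kernel.init_params()  # assumed to give {"nu": ..., "lengthscale": ...}
#   key = np.random.RandomState(1234)
#   X = np.array([[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]])  # two points on the 2-sphere
#
#   key, samples = sample_at(feature_map, 3, X, params, key=key)  # samples: [2, 3]
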

def sampler(
    feature_map: FeatureMap, s: Optional[int] = 1, **kwargs
) -> Callable[[Any], Any]:
    """
    A helper wrapper around :func:`sample_at` that fixes `feature_map`, `s`
    and the keyword arguments in ``**kwargs``, but lets `X`, `params` and
    the other keyword arguments vary.

    :param feature_map:
        The feature map to fix.
    :param s:
        The number of samples parameter to fix. Defaults to 1.
    :param ``**kwargs``:
        Keyword arguments to fix.

    :return:
        The version of :func:`sample_at` with parameters `feature_map` and
        `s` fixed, together with the keyword arguments in ``**kwargs``.
    """
    sample_f = partial(sample_at, feature_map, s, **kwargs)
    new_docstring = f"""
    This is a version of the `{sample_at.__name__}` function with

    - feature_map={feature_map},
    - s={s},
    - and additional keyword arguments {kwargs}.

    The original docstring follows.

    {sample_at.__doc__}
    """
    sample_f.__name__ = sample_at.__name__  # type: ignore[attr-defined]
    sample_f.__doc__ = new_docstring
    return sample_f
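
# A minimal usage sketch for `sampler`, continuing the hedged example above:
# the returned function pre-binds `feature_map` and `s`, so only `X`,
# `params` and the remaining keyword arguments need to be passed.
#
#   sample_paths = sampler(feature_map, s=10)
#   key, samples = sample_paths(X, params, key=key)  # samples: [N, 10]
#
# To turn samples into reusable *functions* (i.e. with the random weights
# w_j fixed), the note in `sample_at` points to the library's
# `make_deterministic` utility, e.g. make_deterministic(sampler(...), key);
# its exact import path is version-dependent, so treat that as an assumption.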