Source code for geometric_kernels.feature_maps.probability_densities

"""
This module provides routines for sampling from the Gaussian and Student's t
probability densities in a backend-agnostic way. It also provides the routines
for sampling the non-standard probability densities that arise in relation to
the :class:`~.spaces.Hyperbolic` and
:class:`~.spaces.SymmetricPositiveDefiniteMatrices` spaces.
"""

import operator
from functools import reduce
from typing import List, Tuple

import lab as B
import numpy as np
from beartype.typing import Dict, List, Optional, Tuple
from sympy import Poly, Product, symbols

from geometric_kernels.lab_extras import (
    cumsum,
    dtype_double,
    dtype_integer,
    eigvalsh,
    from_numpy,
)
from geometric_kernels.utils.utils import ordered_pairwise_differences


def student_t_sample(
    key: B.RandomState,
    loc: B.Numeric,
    shape: B.Numeric,
    df: B.Numeric,
    size: Tuple[int, ...],
    dtype: Optional[B.DType] = None,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Draw samples from the multivariate Student's t-distribution with mean
    vector `loc`, shape matrix `shape` and `df` degrees of freedom, using the
    random state `key`.

    The sampler uses the classical representation of a multivariate Student's
    t random vector $T$ with $\nu$ degrees of freedom, mean $\mu$ and shape
    matrix $\Sigma$:

    .. math:: T = \frac{Z}{\sqrt{V/\nu}} + \mu,

    where $Z \sim N(0, \Sigma)$ and $V \sim \chi^2(\nu)$. Because
    $\chi^2(\nu)$ coincides with $\Gamma(\nu / 2, 2)$, each sample requires
    one Gaussian vector and one Gamma-distributed scalar.

    .. warning::

        Despite its name, the `shape` parameter is not a tensor shape: it
        plays the role the covariance matrix plays for a Gaussian.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param loc:
        Mean vector of the distribution, an array of shape (n,).
    :param shape:
        Shape matrix of the distribution, an array of shape (n, n). It is
        Cholesky-factorized, so it must be positive definite.
    :param df:
        Number of degrees of freedom, a (1,)-array of the used backend.
    :param size:
        The returned array `samples` will have the shape `(*size, n)`.
    :param dtype:
        dtype of the returned tensor. When not given (or falsy),
        `dtype_double(key)` is used.

    :return:
        `Tuple(key, samples)` where `samples` is a `(*size, n)`-shaped array
        of samples of type `dtype`, and `key` is the updated random key for
        `jax`, or the similar random state (generator) for any other backend.
    """
    assert B.shape(df) == (1,), "df must be a 1-vector."
    n = int(B.length(loc))
    assert B.shape(loc) == (n,), "loc must be a 1-dim vector"
    assert B.shape(shape) == (n, n), "shape must be a matrix"

    chol_factor = B.chol(shape)
    dtype = dtype if dtype else dtype_double(key)

    # Standard normal draws, correlated through the Cholesky factor:
    # for every leading index, the result is chol_factor @ draw.
    key, gaussian = B.randn(key, dtype, *size, n)
    gaussian = B.einsum("...i,ji->...j", gaussian, chol_factor)

    # chi^2(df) == Gamma(df / 2, scale=2); drop the trailing axis of size one.
    key, chi2 = B.randgamma(
        key,
        dtype,
        *size,
        alpha=df / 2,
        scale=2,
    )
    chi2 = B.squeeze(chi2, axis=-1)

    denominator = B.sqrt(chi2 / df)[..., None]
    samples = gaussian / denominator + loc
    return key, samples
def base_density_sample(
    key: B.RandomState,
    size: Tuple[int, ...],
    params: Dict[str, B.Numeric],
    dim: int,
    rho: B.Numeric,
    shifted_laplacian: bool = True,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Sample from the "base density" $p_{\nu,\kappa}(\lambda)$ of a Matérn
    kernel, where $\nu$ is the smoothness and $\kappa$ the length scale.

    For $\nu = \infty$ (heat kernel) the base density is Gaussian, and the
    sample is a standard normal vector divided by the length scale. For
    finite $\nu$ it is a Student's t density with
    $2\nu + \mathrm{dim} - \mathrm{dim}(\rho)$ degrees of freedom, rescaled so
    that its shape matches $2\nu/\kappa^2$ (shifted Laplacian) or
    $2\nu/\kappa^2 + |\rho|^2$ (non-shifted Laplacian).

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param size:
        The returned array `samples` will have the shape `(*size, dim)`.
    :param params:
        Params of the kernel. Must contain "lengthscale" and "nu", each a
        (1,)-shaped array.
    :param dim:
        Dimensionality of the space the kernel is defined on.
    :param rho:
        $\rho$ vector of the space.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.

    :return:
        `Tuple(key, samples)` where `samples` is a `(*size, dim)`-shaped array
        of samples, and `key` is the updated random key for `jax`, or the
        similar random state (generator) for any other backend.
    """
    for name in ("lengthscale", "nu"):
        assert name in params
        assert params[name].shape == (1,)

    nu = params["nu"]
    L = params["lengthscale"]
    dtype = B.dtype(L)
    rho_size = B.size(rho)
    nu_is_infinite = nu == np.inf

    # Any finite positive placeholder works here: it only keeps the
    # finite-nu branch well-defined, and that branch is discarded by
    # `B.where` whenever nu is infinite.
    safe_nu = B.where(nu_is_infinite, B.cast(dtype, np.r_[1.0]), nu)

    # Heat branch (nu == inf): standard Gaussian draws.
    key, gaussian_draw = B.randn(key, dtype, *size, rho_size)

    # Matérn branch (nu < inf): Student's t with 2 nu + dim - dim(rho)
    # degrees of freedom and identity shape matrix.
    df = 2 * safe_nu + dim - rho_size
    key, student_draw = student_t_sample(
        key,
        B.zeros(dtype, rho_size),
        B.eye(dtype, rho_size),
        df,
        size,
        dtype,
    )
    u = B.where(nu_is_infinite, gaussian_draw, student_draw)

    denominator = 2 * safe_nu / L**2
    if not shifted_laplacian:
        denominator = denominator + B.sum(rho**2)
    scale = B.where(nu_is_infinite, L, B.sqrt(df / denominator))
    scale = B.cast(B.dtype(u), scale)
    return key, u / scale
def _randcat_fix(
    key: B.RandomState, dtype: B.DType, size: int, p: B.Numeric
) -> Tuple[B.RandomState, B.Numeric]:
    """
    Sample from the categorical variable with probabilities `p`.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param dtype:
        dtype of the returned tensor.
    :param size:
        Number of samples returned.
    :param p:
        Vector of (potentially unnormalized) probabilities defining the
        discrete distribution to sample from. Normalization happens along the
        last axis.

    :return:
        `Tuple(key, samples)` where `samples` is a `(size,)`-shaped array of
        samples of type `dtype`, and `key` is the updated random key for
        `jax`, or the similar random state (generator) for any other backend.
    """
    p = p / B.sum(p, axis=-1, squeeze=False)

    # Inverse-CDF sampling: the sampled index is the number of CDF values that
    # do not exceed u (equivalently, the first index whose CDF exceeds u).
    cdf = cumsum(p, axis=-1)
    key, u = B.rand(key, dtype, size, *B.shape(p)[:-1])

    # Only the first K-1 CDF values take part in the comparison. If round-off
    # leaves cdf[-1] slightly below 1, a draw u in (cdf[-1], 1) still maps to
    # the last category instead of silently wrapping to index 0.
    not_above_u = B.cast(dtype_integer(key), cdf[..., :-1][None] <= u[..., None])
    inds = B.sum(not_above_u, axis=-1)
    return key, B.cast(dtype, inds)


def _alphas(n: int) -> B.Numeric:
    r"""
    Compute alphas for Prop. 16 & 17 of :cite:t:`azangulov2024b` for the
    hyperbolic space of dimension `n`.

    The alphas are the coefficients, ordered from the constant term upwards,
    of the polynomial

    - $x \prod_{j=2}^{m} (x^2 + (2j - 3)^2 / 4)$ with $m = n / 2$ for even `n`,
    - $\prod_{j=0}^{m-1} (x^2 + j^2)$ with $m = (n - 1) / 2$ for odd `n`.

    :param n:
        Dimension of the hyperbolic space, n >= 2.

    :return:
        Array of alphas (float64), where entry j multiplies $x^j$.

    .. todo::
        Precompute these, rather than computing in runtime.

    .. todo::
        Update proposition numbers when the paper gets published.
    """
    assert n >= 2
    x, j = symbols("x, j")
    if (n % 2) == 0:
        m = n // 2
        prod = x * Product(x**2 + (2 * j - 3) ** 2 / 4, (j, 2, m)).doit()
    else:
        m = (n - 1) // 2
        prod = Product(x**2 + j**2, (j, 0, m - 1)).doit()
    # `all_coeffs` lists the highest degree first; reverse so index == power.
    return np.array(Poly(prod, x).all_coeffs()).astype(np.float64)[::-1]


def _sample_mixture_heat(
    key: B.RandomState, alpha: B.Numeric, lengthscale: B.Numeric
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Sample from the mixture distribution from Prop. 16 of
    :cite:t:`azangulov2024b` for specific alphas `alpha` and length scale
    ($\kappa$) `lengthscale` using `key` random state.

    Component $j$ of the mixture has density proportional to
    $\lambda^j e^{-\kappa^2 \lambda^2 / 2}$ on $[0, \infty)$, with normalizing
    constant $Z_j = 2^{(j-1)/2} \Gamma((j+1)/2) / \kappa^{j+1}$. It is
    weighted by $\alpha_j Z_j$. A draw from component $j$ is
    $\sqrt{\chi^2(j+1)} / \kappa$.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param alpha:
        Unnormalized coefficients of the mixture.
    :param lengthscale:
        Length scale ($\kappa$).

    :return:
        `Tuple(key, sample)` where `sample` is a single sample from the
        distribution, and `key` is the updated random key for `jax`, or the
        similar random state (generator) for any other backend.

    .. todo::
        Do we need reparameterization trick for hyperparameter optimization?

    .. todo::
        Update proposition numbers when the paper gets published.
    """
    assert B.rank(alpha) == 1
    m = B.shape(alpha)[0] - 1
    assert m >= 0
    dtype = B.dtype(lengthscale)
    js = B.range(dtype, 0, m + 1)

    # beta_j = 1 / Z_j. Gamma((js+1)/2) is positive real, so G = exp(log(G)).
    # The length scale must enter with power j + 1. A single power of
    # `lengthscale` cancels after normalization and makes the mixture weights
    # independent of kappa.
    beta = (
        2 ** ((1 - js) / 2)
        / B.exp(B.loggamma((js + 1) / 2))
        * B.power(lengthscale, js + 1)
    )

    alpha = B.cast(dtype, from_numpy(beta, alpha))
    cs_unnorm = alpha / beta
    cs = cs_unnorm / B.sum(cs_unnorm)
    key, ind = _randcat_fix(key, dtype, 1, cs)

    # Gamma(nu/2, 2) distribution is the same as chi2(nu) distribution
    key, s = B.randgamma(key, dtype, 1, alpha=(ind + 1) / 2, scale=2)
    s = B.sqrt(s) / lengthscale
    return key, s


def _sample_mixture_matern(
    key: B.RandomState,
    alpha: B.Numeric,
    lengthscale: B.Numeric,
    nu: B.Numeric,
    dim: int,
    shifted_laplacian: bool = True,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Sample from the mixture distribution from Prop. 17 of
    :cite:t:`azangulov2024b` for specific alphas `alpha`, length scale
    ($\kappa$) `lengthscale`, smoothness `nu` and dimension `dim`, using `key`
    random state.

    Component $j$ of the mixture has density proportional to
    $\lambda^j (\gamma + \lambda^2)^{-(\nu + \mathrm{dim}/2)}$ on
    $[0, \infty)$. Here $\gamma = 2\nu/\kappa^2$ (shifted Laplacian) or
    $\gamma = 2\nu/\kappa^2 + ((\mathrm{dim}-1)/2)^2$ (non-shifted). Its
    normalizing constant is
    $\gamma^{-(\nu + (\mathrm{dim}-j-1)/2)} B((j+1)/2, \nu + (\mathrm{dim}-j-1)/2) / 2$.
    A draw from component $j$ is $\sqrt{\gamma t}$ with $t$ following a beta
    prime distribution.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param alpha:
        Unnormalized coefficients of the mixture.
    :param lengthscale:
        Length scale ($\kappa$).
    :param nu:
        Smoothness parameter of Matérn kernels ($\nu$).
    :param dim:
        Dimension of the hyperbolic space.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.

    :return:
        `Tuple(key, sample)` where `sample` is a single sample from the
        distribution, and `key` is the updated random key for `jax`, or the
        similar random state (generator) for any other backend.

    .. todo::
        Do we need reparameterization trick for hyperparameter optimization?

    .. todo::
        Update proposition numbers when the paper gets published.
    """
    assert B.rank(alpha) == 1
    m = B.shape(alpha)[0] - 1
    assert m >= 0
    dtype = B.dtype(lengthscale)
    js = B.range(dtype, 0, m + 1)
    if shifted_laplacian:
        gamma = 2 * nu / lengthscale**2
    else:
        gamma = 2 * nu / lengthscale**2 + ((dim - 1) / 2) ** 2

    # B(x, y) = Gamma(x) Gamma(y) / Gamma(x+y)
    beta = B.exp(
        B.loggamma((js + 1) / 2)
        + B.loggamma(nu + (dim - js - 1) / 2)
        - B.loggamma(nu + dim / 2)
    )
    beta = 2.0 / beta
    # beta_j = 1 / Z_j: include the gamma-dependence of each component.
    beta = beta * B.power(gamma, nu + (dim - js - 1) / 2)

    alpha = B.cast(dtype, from_numpy(beta, alpha))
    cs_unnorm = alpha / beta
    cs = cs_unnorm / B.sum(cs_unnorm)
    key, ind = _randcat_fix(key, dtype, 1, cs)

    key, p = B.randbeta(
        key, dtype, 1, alpha=(ind + 1) / 2, beta=nu + (dim - ind - 1) / 2
    )
    p = p / (1 - p)  # beta prime distribution
    s = B.sqrt(gamma * p)
    return key, s
def hyperbolic_density_sample(
    key: B.RandomState,
    size: Tuple[int, ...],
    params: Dict[str, B.Numeric],
    dim: int,
    shifted_laplacian: bool = True,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    This function samples from the full (i.e., including the $c(\lambda)^{-2}$
    factor, where $c$ is the Harish-Chandra's $c$-function) spectral density
    of the heat/Matérn kernel on the hyperbolic space, using rejection
    sampling.

    Proposals come from the heat (`nu` infinite) or Matérn (`nu` finite)
    mixture distributions. For even `dim`, a proposal is accepted with
    probability $\tanh(\pi \cdot \text{proposal})$. For odd `dim`, every
    proposal is accepted. Each accepted sample gets a random sign,
    +1 or -1 with probability 1/2 each.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param size:
        Shape of the returned sample.
    :param params:
        Params of the kernel. Must contain "lengthscale" and "nu", each a
        (1,)-shaped array.
    :param dim:
        Dimensionality of the space the kernel is defined on.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.

    :return:
        `Tuple(key, samples)` where `samples` is a `size`-shaped array of
        samples, and `key` is the updated random key for `jax`, or the similar
        random state (generator) for any other backend.
    """
    assert "lengthscale" in params
    assert params["lengthscale"].shape == (1,)
    assert "nu" in params
    assert params["nu"].shape == (1,)

    nu = params["nu"]
    L = params["lengthscale"]
    dtype = B.dtype(L)

    alpha = _alphas(dim)

    num_samples = reduce(operator.mul, size, 1)
    if num_samples == 0:
        # Nothing to sample: `B.concat` cannot be called with no tensors.
        return key, B.zeros(dtype, *size)

    # Loop-invariant quantities, hoisted out of the rejection loop.
    nu_is_infinite = nu == np.inf
    # Note: 1.0 in safe_nu can be replaced by any finite positive value; it
    # only keeps the finite-nu branch well-defined when nu is infinite.
    safe_nu = B.where(nu_is_infinite, B.cast(dtype, np.r_[1.0]), nu)

    def base_sampler(key):
        # Both branches are drawn; `B.where` then selects the relevant one.
        key, sample_mixture_nu_infinite = _sample_mixture_heat(key, alpha, L)
        key, sample_mixture_nu_finite = _sample_mixture_matern(
            key, alpha, L, safe_nu, dim, shifted_laplacian
        )
        return key, B.where(
            nu_is_infinite, sample_mixture_nu_infinite, sample_mixture_nu_finite
        )

    samples: List[B.Numeric] = []
    while len(samples) < num_samples:
        key, proposal = base_sampler(key)

        # accept with probability tanh(pi*proposal)
        key, u = B.rand(key, dtype, 1)
        acceptance = B.all(u < B.tanh(B.pi * proposal))

        key, sign_z = B.rand(key, dtype, 1)
        sign = B.sign(sign_z - 0.5)  # +1 or -1 with probability 0.5

        # The tanh factor only appears in the even-dimensional density.
        if ((dim % 2) == 1) or acceptance:
            samples.append(sign * proposal)

    result = B.reshape(B.concat(*samples), *size)
    return key, B.cast(dtype, result)
def spd_density_sample(
    key: B.RandomState,
    size: Tuple[int, ...],
    params: Dict[str, B.Numeric],
    degree: int,
    rho: B.Numeric,
    shifted_laplacian: bool = True,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    This function samples from the full (i.e., including the $c(\lambda)^{-2}$
    factor, where $c$ is the Harish-Chandra's $c$-function) spectral density
    of the heat/Matérn kernel on the manifold of symmetric positive definite
    matrices, using rejection sampling.

    Proposals are built from the eigenvalues of a symmetrized standard
    Gaussian D x D matrix, divided by the length scale. For finite `nu` they
    are additionally rescaled by a chi random variable. A proposal is
    accepted with probability
    $\prod \tanh(\pi |\lambda_i - \lambda_j|)$ over its pairwise differences
    (as produced by `ordered_pairwise_differences`).

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param size:
        The returned array `samples` will have the shape `(*size, D)`.
    :param params:
        Params of the kernel. Must contain "lengthscale" and "nu", each a
        (1,)-shaped array.
    :param degree:
        The degree D of the SPD(D) space.
    :param rho:
        $\rho$ vector of the space, D-dimensional.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.

    :return:
        `Tuple(key, samples)` where `samples` is a `(*size, D)`-shaped array
        of samples, and `key` is the updated random key for `jax`, or the
        similar random state (generator) for any other backend.
    """
    assert "lengthscale" in params
    assert params["lengthscale"].shape == (1,)
    assert "nu" in params
    assert params["nu"].shape == (1,)

    nu = params["nu"]
    L = params["lengthscale"]
    dtype = B.dtype(L)

    num_samples = reduce(operator.mul, size, 1)
    if num_samples == 0:
        # Nothing to sample: `B.concat` cannot be called with no tensors.
        return key, B.zeros(dtype, *size, degree)

    # Loop-invariant quantities, hoisted out of the rejection loop.
    nu_is_infinite = nu == np.inf
    # Note: 1.0 in safe_nu can be replaced by any finite positive value; it
    # only keeps the finite-nu branch well-defined when nu is infinite.
    safe_nu = B.where(nu_is_infinite, B.cast(dtype, np.r_[1.0]), nu)
    # Squared scale of the Matérn proposal; the non-shifted Laplacian
    # additionally contributes |rho|^2.
    gamma = 2 * safe_nu / L**2
    if not shifted_laplacian:
        gamma = gamma + B.sum(rho**2)
    sqrt_gamma = B.sqrt(gamma)

    samples: List[B.Numeric] = []
    while len(samples) < num_samples:
        key, X = B.randn(key, dtype, degree, degree)
        M = (X + B.transpose(X)) / 2

        eigv = eigvalsh(M)  # [D]

        # for nu == np.inf
        proposal_nu_infinite = eigv / L

        # for nu < np.inf
        # Gamma(nu, 2) distribution is the same as chi2(2nu) distribution
        key, chi2_sample = B.randgamma(key, dtype, 1, alpha=safe_nu, scale=2)
        chi_sample = B.sqrt(chi2_sample)
        proposal_nu_finite = eigv * sqrt_gamma / chi_sample  # [D]

        proposal = B.where(nu_is_infinite, proposal_nu_infinite, proposal_nu_finite)

        # Acceptance probability prod tanh(pi |diff|), computed in log space.
        diffp = ordered_pairwise_differences(proposal)
        diffp = B.pi * B.abs(diffp)
        logprod = B.sum(B.log(B.tanh(diffp)), axis=-1)
        prod = B.exp(logprod)
        assert B.all(prod > 0)

        # accept with probability `prod`
        key, u = B.rand(key, dtype, 1)
        if B.all(u < prod):
            samples.append(proposal)

    result = B.reshape(B.concat(*samples), *size, degree)
    return key, result