Coverage for geometric_kernels/feature_maps/probability_densities.py: 83%
152 statements
1"""
2This module provide the routines for sampling from the Gaussian and Student-t
3probability densities in a backend-agnostic way. It also provides the routines
4for sampling the non-standard probability densities that arise in relation to
5the :class:`~.spaces.Hyperbolic` and
6:class:`~.spaces.SymmetricPositiveDefiniteMatrices` spaces.
7"""
9import operator
10from functools import reduce
11from typing import List, Tuple
13import lab as B
14import numpy as np
15from beartype.typing import Dict, List, Optional, Tuple
16from sympy import Poly, Product, symbols
18from geometric_kernels.lab_extras import (
19 cumsum,
20 dtype_double,
21 dtype_integer,
22 eigvalsh,
23 from_numpy,
24)
25from geometric_kernels.utils.utils import (
26 _check_1_vector,
27 _check_field_in_params,
28 _check_matrix,
29 _check_rank_1_array,
30 ordered_pairwise_differences,
31)


def student_t_sample(
    key: B.RandomState,
    loc: B.Numeric,
    shape: B.Numeric,
    df: B.Numeric,
    size: Tuple[int, ...],
    dtype: Optional[B.DType] = None,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Sample from the multivariate Student's t-distribution with mean vector
    `loc`, shape matrix `shape` and `df` degrees of freedom, using the `key`
    random state, and returning a sample of shape `size`.

    A multivariate Student's t random vector $T$ with $\nu$ degrees of freedom,
    a mean vector $\mu$ and a shape matrix $\Sigma$ can be represented as

    .. math:: T = \frac{Z}{\sqrt{V/\nu}} + \mu,

    where $Z \sim N(0, \Sigma)$ and $V \sim \chi^2(\nu)$. The $\chi^2(\nu)$
    distribution, in its turn, is the same as the $\Gamma(\nu / 2, 2)$
    distribution, and therefore $V/\nu \sim \Gamma(\nu / 2, 2 / \nu)$. We use
    these properties to sample a multivariate Student's t random vector by
    sampling a Gaussian random vector and a Gamma random variable.

    .. warning:: The `shape` parameter has little to do with tensor shapes;
        it is similar to the covariance matrix in the Gaussian distribution.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param loc:
        Mean vector of the multivariate Student's t-distribution.
        An array of shape (n,).
    :param shape:
        Shape matrix of the multivariate Student's t-distribution.
        An array of shape (n, n).
    :param df:
        The number of degrees of freedom of the Student's t-distribution,
        represented as a (1,)-array of the used backend.
    :param size:
        The returned array `samples` will have the shape `(*size, n)`.
    :param dtype:
        dtype of the returned tensor.

    :return:
        `Tuple(key, samples)` where `samples` is a `(*size, n)`-shaped array of
        samples of type `dtype`, and `key` is the updated random key for `jax`,
        or the similar random state (generator) for any other backend.
    """
    _check_1_vector(df, "df")
    _check_rank_1_array(loc, "loc")
    _check_matrix(shape, "shape")

    n = B.shape(loc)[0]

    if tuple(B.shape(shape)) != (n, n):
        raise ValueError(
            f"Expected `shape` matrix to have shape [{n}, {n}], but got {B.shape(shape)}."
        )

    shape_sqrt = B.chol(shape)

    dtype = dtype or dtype_double(key)
    key, z = B.randn(key, dtype, *size, n)
    z = B.einsum("...i,ji->...j", z, shape_sqrt)

    key, g = B.randgamma(
        key,
        dtype,
        *size,
        alpha=df / 2,
        scale=2,
    )
    g = B.squeeze(g, axis=-1)

    u = (z / B.sqrt(g / df)[..., None]) + loc

    return key, u
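
# A minimal usage sketch (an assumption, not part of the library docs): with
# the NumPy backend, `key` is a `np.random.RandomState` and all inputs are
# NumPy arrays. This draws 5 samples from a bivariate Student-t with 4
# degrees of freedom.
#
#     key = np.random.RandomState(0)
#     loc = np.zeros(2)             # mean vector, shape (2,)
#     Sigma = np.eye(2)             # shape matrix, shape (2, 2)
#     df = np.array([4.0])          # degrees of freedom, a (1,)-array
#     key, samples = student_t_sample(key, loc, Sigma, df, (5,))
#     assert samples.shape == (5, 2)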


def base_density_sample(
    key: B.RandomState,
    size: Tuple[int, ...],
    params: Dict[str, B.Numeric],
    dim: int,
    rho: B.Numeric,
    shifted_laplacian: bool = True,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    The Matérn kernel's spectral density is $p_{\nu,\kappa}(\lambda)$, where
    $\nu$ is the smoothness parameter, $\kappa$ is the length scale and
    $p_{\nu,\kappa}$ is the Student's t or Gaussian density, depending on
    the smoothness.

    We call it the "base density" and this function returns a sample from it.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param size:
        The returned array `samples` will have the shape `(*size, n)`,
        where `n = B.size(rho)`.
    :param params:
        Params of the kernel.
    :param dim:
        Dimensionality of the space the kernel is defined on.
    :param rho:
        $\rho$ vector of the space.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.

    :return:
        `Tuple(key, samples)` where `samples` is a `(*size, n)`-shaped array
        of samples, and `key` is the updated random key for `jax`, or the
        similar random state (generator) for any other backend.
    """
    _check_field_in_params(params, "lengthscale")
    _check_1_vector(params["lengthscale"], 'params["lengthscale"]')
    _check_field_in_params(params, "nu")
    _check_1_vector(params["nu"], 'params["nu"]')

    nu = params["nu"]
    L = params["lengthscale"]

    rho_size = B.size(rho)

    # Note: 1.0 in safe_nu can be replaced by any finite positive value.
    safe_nu = B.where(nu == np.inf, B.cast(B.dtype(L), np.r_[1.0]), nu)

    # For nu == np.inf, sample from the Gaussian.
    key, u_nu_infinite = B.randn(key, B.dtype(L), *size, rho_size)

    # For nu < np.inf, sample from the Student-t with
    # 2*nu + dim(space) - dim(rho) degrees of freedom.
    df = 2 * safe_nu + dim - rho_size

    dtype = B.dtype(L)

    key, u_nu_finite = student_t_sample(
        key,
        B.zeros(dtype, rho_size),
        B.eye(dtype, rho_size),
        df,
        size,
        dtype,
    )

    u = B.where(nu == np.inf, u_nu_infinite, u_nu_finite)

    scale_nu_infinite = L
    if shifted_laplacian:
        scale_nu_finite = B.sqrt(df / (2 * safe_nu / L**2))
    else:
        scale_nu_finite = B.sqrt(df / (2 * safe_nu / L**2 + B.sum(rho**2)))

    scale = B.where(nu == np.inf, scale_nu_infinite, scale_nu_finite)
    scale = B.cast(B.dtype(u), scale)
    return key, u / scale
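
# A minimal usage sketch (an assumption): NumPy backend, Matérn-5/2
# parameters, and a rank-1 space, so `rho` has a single entry and the
# samples have shape `(*size, 1)`. With `nu = 2.5`, `dim = 2` and
# `B.size(rho) = 1`, the Student-t above gets `df = 2*2.5 + 2 - 1 = 6`
# degrees of freedom.
#
#     key = np.random.RandomState(0)
#     params = {"nu": np.array([2.5]), "lengthscale": np.array([1.0])}
#     key, samples = base_density_sample(key, (10,), params, dim=2, rho=np.array([0.5]))
#     assert samples.shape == (10, 1)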


def _randcat_fix(
    key: B.RandomState, dtype: B.DType, size: int, p: B.Numeric
) -> Tuple[B.RandomState, B.Numeric]:
    """
    Sample from the categorical variable with probabilities `p`.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param dtype:
        dtype of the returned tensor.
    :param size:
        Number of samples returned.
    :param p:
        Vector of (potentially unnormalized) probabilities
        defining the discrete distribution to sample from.

    :return:
        `Tuple(key, samples)` where `samples` is a `(size,)`-shaped array of
        samples of type `dtype`, and `key` is the updated random key for `jax`,
        or the similar random state (generator) for any other backend.
    """
    p = p / B.sum(p, axis=-1, squeeze=False)

    # Inverse-CDF sampling: draw u ~ U[0, 1] and take the index of the first
    # entry of the CDF that exceeds u.
    cdf = cumsum(p, axis=-1)
    key, u = B.rand(key, dtype, size, *B.shape(p)[:-1])
    inds = B.argmax(B.cast(dtype_integer(key), u[..., None] < cdf[None]), axis=-1)
    return key, B.cast(dtype, inds)
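
# A minimal usage sketch (an assumption): NumPy backend. Unnormalized
# probabilities are normalized internally, so `[1, 2, 7]` behaves like
# `[0.1, 0.2, 0.7]`.
#
#     key = np.random.RandomState(0)
#     key, inds = _randcat_fix(key, np.float64, 1000, np.array([1.0, 2.0, 7.0]))
#     # Roughly 70% of the entries of `inds` should be index 2.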


def _alphas(n: int) -> B.Numeric:
    r"""
    Compute alphas for Prop. 16 & 17 of :cite:t:`azangulov2024b`
    for the hyperbolic space of dimension `n`.

    :param n:
        Dimension of the hyperbolic space, n >= 2.

    :return:
        Array of alphas.

    .. todo::
        Precompute these, rather than computing at runtime.

    .. todo::
        Update proposition numbers when the paper gets published.
    """
    if n < 2:
        raise ValueError("Dimension of the hyperbolic space `n` must be >= 2.")
    x, j = symbols("x, j")
    if (n % 2) == 0:
        m = n // 2
        prod = x * Product(x**2 + (2 * j - 3) ** 2 / 4, (j, 2, m)).doit()
    else:
        m = (n - 1) // 2
        prod = Product(x**2 + j**2, (j, 0, m - 1)).doit()
    return np.array(Poly(prod, x).all_coeffs()).astype(np.float64)[::-1]
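
# Worked examples (values follow directly from the sympy expressions above;
# coefficients are returned lowest-degree first):
#
#     _alphas(2)  # array([0., 1.])            -- the polynomial x
#     _alphas(3)  # array([0., 0., 1.])        -- the polynomial x**2
#     _alphas(4)  # array([0., 0.25, 0., 1.])  -- the polynomial x**3 + x/4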


def _sample_mixture_heat(
    key: B.RandomState, alpha: B.Numeric, lengthscale: B.Numeric
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Sample from the mixture distribution from Prop. 16 of
    :cite:t:`azangulov2024b` for specific alphas `alpha` and length scale
    ($\kappa$) `lengthscale`, using the `key` random state.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param alpha:
        Unnormalized coefficients of the mixture.
    :param lengthscale:
        Length scale ($\kappa$).

    :return:
        `Tuple(key, sample)` where `sample` is a single sample from the
        distribution, and `key` is the updated random key for `jax`, or the
        similar random state (generator) for any other backend.

    .. todo::
        Do we need the reparameterization trick for hyperparameter
        optimization?

    .. todo::
        Update proposition numbers when the paper gets published.
    """
    _check_rank_1_array(alpha, "alpha")
    m = B.shape(alpha)[0] - 1
    if m < 0:
        raise ValueError("The mixture must contain at least 1 component.")
    dtype = B.dtype(lengthscale)
    js = B.range(dtype, 0, m + 1)

    # Gamma((js+1)/2) should be a positive real, so G = exp(log(abs(G))).
    beta = 2 ** ((1 - js) / 2) / B.exp(B.loggamma((js + 1) / 2)) * lengthscale

    alpha = B.cast(dtype, from_numpy(beta, alpha))
    cs_unnorm = alpha / beta
    cs = cs_unnorm / B.sum(cs_unnorm)
    key, ind = _randcat_fix(key, dtype, 1, cs)

    # The Gamma(nu/2, 2) distribution is the same as the chi2(nu) distribution.
    key, s = B.randgamma(key, dtype, 1, alpha=(ind + 1) / 2, scale=2)
    s = B.sqrt(s) / lengthscale
    return key, s
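
# A minimal usage sketch (an assumption): sample once from the heat-kernel
# mixture on the 3-dimensional hyperbolic space. `_sample_mixture_matern`
# below follows the same pattern, with the extra `nu`, `dim` and
# `shifted_laplacian` arguments.
#
#     key = np.random.RandomState(0)
#     key, s = _sample_mixture_heat(key, _alphas(3), np.array([1.0]))
#     # `s` is a (1,)-shaped positive sample.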


def _sample_mixture_matern(
    key: B.RandomState,
    alpha: B.Numeric,
    lengthscale: B.Numeric,
    nu: B.Numeric,
    dim: int,
    shifted_laplacian: bool = True,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Sample from the mixture distribution from Prop. 17 of
    :cite:t:`azangulov2024b` for specific alphas `alpha`, length
    scale ($\kappa$) `lengthscale`, smoothness `nu` and dimension `dim`,
    using the `key` random state.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param alpha:
        Unnormalized coefficients of the mixture.
    :param lengthscale:
        Length scale ($\kappa$).
    :param nu:
        Smoothness parameter of Matérn kernels ($\nu$).
    :param dim:
        Dimension of the hyperbolic space.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.

    :return:
        `Tuple(key, sample)` where `sample` is a single sample from the
        distribution, and `key` is the updated random key for `jax`, or the
        similar random state (generator) for any other backend.

    .. todo::
        Do we need the reparameterization trick for hyperparameter
        optimization?

    .. todo::
        Update proposition numbers when the paper gets published.
    """
    _check_rank_1_array(alpha, "alpha")
    m = B.shape(alpha)[0] - 1
    if m < 0:
        raise ValueError("The mixture must contain at least 1 component.")
    dtype = B.dtype(lengthscale)
    js = B.range(dtype, 0, m + 1)
    if shifted_laplacian:
        gamma = 2 * nu / lengthscale**2
    else:
        gamma = 2 * nu / lengthscale**2 + ((dim - 1) / 2) ** 2

    # B(x, y) = Gamma(x) Gamma(y) / Gamma(x+y)
    beta = B.exp(
        B.loggamma((js + 1) / 2)
        + B.loggamma(nu + (dim - js - 1) / 2)
        - B.loggamma(nu + dim / 2)
    )
    beta = 2.0 / beta
    beta = beta * B.power(gamma, nu + (dim - js - 1) / 2)

    alpha = B.cast(dtype, from_numpy(beta, alpha))
    cs_unnorm = alpha / beta
    cs = cs_unnorm / B.sum(cs_unnorm)

    key, ind = _randcat_fix(key, dtype, 1, cs)
    key, p = B.randbeta(
        key, dtype, 1, alpha=(ind + 1) / 2, beta=nu + (dim - ind - 1) / 2
    )
    p = p / (1 - p)  # beta prime distribution
    s = B.sqrt(gamma * p)
    return key, s


def hyperbolic_density_sample(
    key: B.RandomState,
    size: Tuple[int, ...],
    params: Dict[str, B.Numeric],
    dim: int,
    shifted_laplacian: bool = True,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    This function samples from the full (i.e., including the $c(\lambda)^{-2}$
    factor, where $c$ is the Harish-Chandra $c$-function) spectral density of
    the heat/Matérn kernel on the hyperbolic space, using rejection sampling.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param size:
        Shape of the returned sample.
    :param params:
        Params of the kernel.
    :param dim:
        Dimensionality of the space the kernel is defined on.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.

    :return:
        `Tuple(key, samples)` where `samples` is a `size`-shaped array of
        samples, and `key` is the updated random key for `jax`, or the similar
        random state (generator) for any other backend.
    """
    _check_field_in_params(params, "lengthscale")
    _check_1_vector(params["lengthscale"], 'params["lengthscale"]')
    _check_field_in_params(params, "nu")
    _check_1_vector(params["nu"], 'params["nu"]')

    nu = params["nu"]
    L = params["lengthscale"]

    alpha = _alphas(dim)

    def base_sampler(key):
        # Note: 1.0 in safe_nu can be replaced by any finite positive value.
        safe_nu = B.where(nu == np.inf, B.cast(B.dtype(L), np.r_[1.0]), nu)

        # For nu == np.inf.
        key, sample_mixture_nu_infinite = _sample_mixture_heat(key, alpha, L)
        # For nu < np.inf.
        key, sample_mixture_nu_finite = _sample_mixture_matern(
            key, alpha, L, safe_nu, dim, shifted_laplacian
        )

        return key, B.where(
            nu == np.inf, sample_mixture_nu_infinite, sample_mixture_nu_finite
        )

    samples: List[B.Numeric] = []
    while len(samples) < reduce(operator.mul, size, 1):
        key, proposal = base_sampler(key)

        # Accept with probability tanh(pi*proposal).
        key, u = B.rand(key, B.dtype(L), 1)
        acceptance = B.all(u < B.tanh(B.pi * proposal))

        key, sign_z = B.rand(key, B.dtype(L), 1)
        sign = B.sign(sign_z - 0.5)  # +1 or -1 with probability 0.5
        if ((dim % 2) == 1) or acceptance:
            samples.append(sign * proposal)

    samples = B.reshape(B.concat(*samples), *size)
    return key, B.cast(B.dtype(L), samples)
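
# A minimal usage sketch (an assumption): draw 10 samples from the spectral
# density of the heat kernel (nu = inf) on the 2-dimensional hyperbolic
# space; `nu` and `lengthscale` are (1,)-arrays, as required above.
#
#     key = np.random.RandomState(0)
#     params = {"nu": np.array([np.inf]), "lengthscale": np.array([1.0])}
#     key, samples = hyperbolic_density_sample(key, (10,), params, dim=2)
#     assert samples.shape == (10,)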


def spd_density_sample(
    key: B.RandomState,
    size: Tuple[int, ...],
    params: Dict[str, B.Numeric],
    degree: int,
    rho: B.Numeric,
    shifted_laplacian: bool = True,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    This function samples from the full (i.e., including the $c(\lambda)^{-2}$
    factor, where $c$ is the Harish-Chandra $c$-function) spectral density
    of the heat/Matérn kernel on the manifold of symmetric positive definite
    matrices, using rejection sampling.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param size:
        The returned array `samples` will have the shape `(*size, D)`.
    :param params:
        Params of the kernel.
    :param degree:
        The degree D of the SPD(D) space.
    :param rho:
        $\rho$ vector of the space, D-dimensional.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.

    :return:
        `Tuple(key, samples)` where `samples` is a `(*size, D)`-shaped array of
        samples, and `key` is the updated random key for `jax`, or the similar
        random state (generator) for any other backend.
    """
    _check_field_in_params(params, "nu")
    _check_1_vector(params["nu"], 'params["nu"]')
    _check_field_in_params(params, "lengthscale")
    _check_1_vector(params["lengthscale"], 'params["lengthscale"]')

    nu = params["nu"]
    L = params["lengthscale"]

    samples: List[B.Numeric] = []
    while len(samples) < reduce(operator.mul, size, 1):
        key, X = B.randn(key, B.dtype(L), degree, degree)
        M = (X + B.transpose(X)) / 2
        eigv = eigvalsh(M)  # [D]

        # Note: 1.0 in safe_nu can be replaced by any finite positive value.
        safe_nu = B.where(nu == np.inf, B.cast(B.dtype(L), np.r_[1.0]), nu)

        # For nu == np.inf.
        proposal_nu_infinite = eigv / L

        # For nu < np.inf.
        if shifted_laplacian:
            eigv_nu_finite = eigv * B.sqrt(2 * safe_nu) / L
        else:
            eigv_nu_finite = eigv * B.sqrt(2 * safe_nu / L**2 + B.sum(rho**2))
        # The Gamma(nu, 2) distribution is the same as the chi2(2*nu) distribution.
        key, chi2_sample = B.randgamma(key, B.dtype(L), 1, alpha=safe_nu, scale=2)
        chi_sample = B.sqrt(chi2_sample)
        proposal_nu_finite = eigv_nu_finite / chi_sample  # [D]

        proposal = B.where(nu == np.inf, proposal_nu_infinite, proposal_nu_finite)

        diffp = ordered_pairwise_differences(proposal)
        diffp = B.pi * B.abs(diffp)
        logprod = B.sum(B.log(B.tanh(diffp)), axis=-1)
        prod = B.exp(logprod)

        # Accept with probability `prod`.
        key, u = B.rand(key, B.dtype(L), 1)
        acceptance = B.all(u < prod)
        if acceptance:
            samples.append(proposal)

    samples = B.reshape(B.concat(*samples), *size, degree)
    return key, samples
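
# A minimal usage sketch (an assumption): draw 10 samples for SPD(2) with a
# Matérn-5/2 kernel. The `rho` value here is purely illustrative; in the
# library it comes from the `SymmetricPositiveDefiniteMatrices` space (and
# is only used when `shifted_laplacian=False`).
#
#     key = np.random.RandomState(0)
#     params = {"nu": np.array([2.5]), "lengthscale": np.array([1.0])}
#     key, samples = spd_density_sample(
#         key, (10,), params, degree=2, rho=np.array([0.5, -0.5])
#     )
#     assert samples.shape == (10, 2)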