Source code for geometric_kernels.spaces.spd

"""
This module provides the :class:`SymmetricPositiveDefiniteMatrices` space.
"""

import geomstats as gs
import lab as B

from geometric_kernels.lab_extras import (
    complex_like,
    create_complex,
    dtype_double,
    from_numpy,
    qr,
    slogdet,
)
from geometric_kernels.spaces.base import NoncompactSymmetricSpace
from geometric_kernels.utils.utils import ordered_pairwise_differences


[docs] class SymmetricPositiveDefiniteMatrices( NoncompactSymmetricSpace, gs.geometry.spd_matrices.SPDMatrices ): r""" The GeometricKernels space representing the manifold of symmetric positive definite matrices $SPD(n)$ with the affine-invariant Riemannian metric. The elements of this space are represented by positive definite matrices of size n x n. Positive definite means _strictly_ positive definite here, not positive semi-definite. The class inherits the interface of geomstats's `SPDMatrices`. .. note:: A tutorial on how to use this space is available in the :doc:`SPD.ipynb </examples/SPD>` notebook. :param n: Size of the matrices, the $n$ in $SPD(n)$. .. note:: As mentioned in :ref:`this note <quotient note>`, any symmetric space is a quotient G/H. For the manifold of symmetric positive definite matrices $SPD(n)$, the group of symmetries $G$ is the identity component $GL(n)_+$ of the general linear group $GL(n)$, while the isotropy subgroup $H$ is the special orthogonal group $SO(n)$. See the mathematical details in :cite:t:`azangulov2024b`. .. admonition:: Citation If you use this GeometricKernels space in your research, please consider citing :cite:t:`azangulov2024b`. """ def __init__(self, n): super().__init__(n) def __str__(self): return f"SymmetricPositiveDefiniteMatrices({self.n})" @property def dimension(self) -> int: """ Returns n(n+1)/2 where `n` was passed down to `__init__`. """ dim = self.n * (self.n + 1) / 2 return dim @property def degree(self) -> int: return self.n @property def rho(self): return (B.range(self.degree) + 1) - (self.degree + 1) / 2 @property def num_axes(self): """ Number of axes in an array representing a point in the space. :return: 2. """ return 2
[docs] def random_phases(self, key, num): if not isinstance(num, tuple): num = (num,) key, x = B.randn(key, dtype_double(key), *num, self.degree, self.degree) Q, R = qr(x) r_diag_sign = B.sign(B.diag_extract(R)) # [B, N] Q *= B.expand_dims(r_diag_sign, -1) # [B, D, D] sign_det, _ = slogdet(Q) # [B, ] # equivalent to Q[..., 0] *= B.expand_dims(sign_det, -1) Q0 = Q[..., 0] * B.expand_dims(sign_det, -1) # [B, D] Q = B.concat(B.expand_dims(Q0, -1), Q[..., 1:], axis=-1) # [B, D, D] return key, Q
[docs] def inv_harish_chandra(self, lam): diffs = ordered_pairwise_differences(lam) diffs = B.abs(diffs) logprod = B.sum( B.log(B.pi * diffs) + B.log(B.tanh(B.pi * diffs)), axis=-1 ) # [B, ] return B.exp(0.5 * logprod)
[docs] def power_function(self, lam, g, h): g = B.cholesky(g) gh = B.matmul(g, h) Q, R = qr(gh) u = B.abs(B.diag_extract(R)) logu = B.cast(complex_like(R), B.log(u)) exponent = create_complex(from_numpy(lam, self.rho), lam) # [..., D] logpower = logu * exponent # [..., D] logproduct = B.sum(logpower, axis=-1) # [...,] logproduct = B.cast(complex_like(lam), logproduct) return B.exp(logproduct)
[docs] def random(self, key, number): """ Non-uniform random sampling, reimplements the algorithm from geomstats. Always returns [N, n, n] float64 array of the `key`'s backend. :param key: Either `np.random.RandomState`, `tf.random.Generator`, `torch.Generator` or `jax.tensor` (representing random state). :param number: Number of samples to draw. :return: An array of `number` random samples on the space. """ key, mat = B.rand(key, dtype_double(key), number, self.n, self.n) mat = 2 * mat - 1 mat_symm = 0.5 * (mat + B.transpose(mat, (0, 2, 1))) return key, B.expm(mat_symm)
@property def element_shape(self): """ :return: [n, n]. """ return [self.n, self.n] @property def element_dtype(self): """ :return: B.Float. """ return B.Float