Coverage for geometric_kernels/spaces/spd.py: 85%
61 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1"""
2This module provides the :class:`SymmetricPositiveDefiniteMatrices` space.
3"""
5import geomstats as gs
6import lab as B
8from geometric_kernels.lab_extras import (
9 complex_like,
10 create_complex,
11 dtype_double,
12 from_numpy,
13 qr,
14 slogdet,
15)
16from geometric_kernels.spaces.base import NoncompactSymmetricSpace
17from geometric_kernels.utils.utils import ordered_pairwise_differences
class SymmetricPositiveDefiniteMatrices(
    NoncompactSymmetricSpace, gs.geometry.spd_matrices.SPDMatrices
):
    r"""
    The GeometricKernels space representing the manifold of symmetric positive
    definite matrices $SPD(n)$ with the affine-invariant Riemannian metric.

    The elements of this space are represented by positive definite matrices of
    size n x n. Positive definite means _strictly_ positive definite here, not
    positive semi-definite.

    The class inherits the interface of geomstats's `SPDMatrices`.

    .. note::
        A tutorial on how to use this space is available in the
        :doc:`SPD.ipynb </examples/SPD>` notebook.

    :param n:
        Size of the matrices, the $n$ in $SPD(n)$.

    .. note::
        As mentioned in :ref:`this note <quotient note>`, any symmetric space
        is a quotient G/H. For the manifold of symmetric positive definite
        matrices $SPD(n)$, the group of symmetries $G$ is the identity component
        $GL(n)_+$ of the general linear group $GL(n)$, while the isotropy
        subgroup $H$ is the special orthogonal group $SO(n)$. See the
        mathematical details in :cite:t:`azangulov2024b`.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please consider
        citing :cite:t:`azangulov2024b`.
    """

    def __init__(self, n):
        """
        Forward `n` to the parent classes' initializers.

        :param n:
            Size of the matrices, the $n$ in $SPD(n)$.
        """
        super().__init__(n)

    def __str__(self):
        """
        :return:
            A string of the form `SymmetricPositiveDefiniteMatrices(n)`.
        """
        return f"SymmetricPositiveDefiniteMatrices({self.n})"

    @property
    def dimension(self) -> int:
        """
        Returns n(n+1)/2 where `n` was passed down to `__init__`.

        :return:
            The dimension as a Python `int`.
        """
        # n * (n + 1) is a product of two consecutive integers and therefore
        # always even, so floor division is exact. True division (`/`) would
        # return a float and contradict the `-> int` annotation.
        return self.n * (self.n + 1) // 2

    @property
    def degree(self) -> int:
        """
        :return:
            `n`, the size of the matrices.
        """
        return self.n

    @property
    def rho(self):
        """
        :return:
            A vector of length `degree` whose entries are
            `i - (degree + 1) / 2` for `i = 1, ..., degree`.
        """
        return (B.range(self.degree) + 1) - (self.degree + 1) / 2

    @property
    def num_axes(self):
        """
        Number of axes in an array representing a point in the space.

        :return:
            2.
        """
        return 2

    def random_phases(self, key, num):
        """
        Sample random `degree` x `degree` orthogonal matrices whose
        determinant sign is normalized to +1.

        Standard normal matrices are QR-decomposed, the columns of `Q` are
        rescaled by the signs of the diagonal of `R`, and the first column of
        `Q` is finally multiplied by the sign of `det(Q)`.

        :param key:
            Random state, in the format expected by `B.randn`.
        :param num:
            Number of samples to draw; either a single value or a tuple giving
            the batch shape.

        :return:
            A tuple `(key, Q)` of the updated random state and the sampled
            matrices, of shape `[*num, degree, degree]`.
        """
        if not isinstance(num, tuple):
            num = (num,)
        key, x = B.randn(key, dtype_double(key), *num, self.degree, self.degree)
        Q, R = qr(x)
        r_diag_sign = B.sign(B.diag_extract(R))  # [B, N]
        Q *= B.expand_dims(r_diag_sign, -1)  # [B, D, D]
        sign_det, _ = slogdet(Q)  # [B, ]

        # equivalent to Q[..., 0] *= B.expand_dims(sign_det, -1)
        Q0 = Q[..., 0] * B.expand_dims(sign_det, -1)  # [B, D]
        Q = B.concat(B.expand_dims(Q0, -1), Q[..., 1:], axis=-1)  # [B, D, D]
        return key, Q

    def inv_harish_chandra(self, lam):
        """
        Evaluate the square root of the product, over the ordered pairwise
        differences `d` of the entries of `lam`, of
        `pi * |d| * tanh(pi * |d|)`.

        The product is accumulated in the log domain and exponentiated at the
        end.

        :param lam:
            Array passed to `ordered_pairwise_differences`; presumably of
            shape [B, D] with D equal to `degree` (to be confirmed against the
            callers).

        :return:
            One value per batch entry, [B, ].
        """
        diffs = ordered_pairwise_differences(lam)
        diffs = B.abs(diffs)
        logprod = B.sum(
            B.log(B.pi * diffs) + B.log(B.tanh(B.pi * diffs)), axis=-1
        )  # [B, ]
        return B.exp(0.5 * logprod)

    def power_function(self, lam, g, h):
        """
        Evaluate `exp(sum_j log(u_j) * e_j)`, where `u` is the absolute value
        of the diagonal of the `R` factor in the QR decomposition of
        `cholesky(g) @ h`, and `e` is the complex exponent built from `rho`
        and `lam`.

        :param lam:
            Array combined with `rho` via `create_complex` to form the
            exponent (presumably `rho` is the real part and `lam` the
            imaginary part; confirm against `create_complex`), [..., D].
        :param g:
            Matrices to be Cholesky-decomposed; presumably points of the
            space, [..., D, D].
        :param h:
            Matrices multiplied on the right of the Cholesky factor of `g`;
            presumably the output of `random_phases`, [..., D, D].

        :return:
            Complex-valued array, [...,].
        """
        g = B.cholesky(g)
        gh = B.matmul(g, h)
        Q, R = qr(gh)

        u = B.abs(B.diag_extract(R))
        logu = B.cast(complex_like(R), B.log(u))
        exponent = create_complex(from_numpy(lam, self.rho), lam)  # [..., D]
        logpower = logu * exponent  # [..., D]
        logproduct = B.sum(logpower, axis=-1)  # [...,]
        logproduct = B.cast(complex_like(lam), logproduct)
        return B.exp(logproduct)

    def random(self, key, number):
        """
        Non-uniform random sampling, reimplements the algorithm from geomstats.

        Always returns [N, n, n] float64 array of the `key`'s backend.

        :param key:
            Either `np.random.RandomState`, `tf.random.Generator`,
            `torch.Generator` or `jax.tensor` (representing random state).
        :param number:
            Number of samples to draw.

        :return:
            An array of `number` random samples on the space.
        """

        key, mat = B.rand(key, dtype_double(key), number, self.n, self.n)
        # Rescale the samples from [0, 1) to [-1, 1).
        mat = 2 * mat - 1
        # Symmetrize, so that the matrix exponential below is symmetric
        # positive definite.
        mat_symm = 0.5 * (mat + B.transpose(mat, (0, 2, 1)))

        return key, B.expm(mat_symm)

    @property
    def element_shape(self):
        """
        :return:
            [n, n].
        """
        return [self.n, self.n]

    @property
    def element_dtype(self):
        """
        :return:
            B.Float.
        """
        return B.Float