Coverage for geometric_kernels/spaces/spd.py: 85%

61 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2This module provides the :class:`SymmetricPositiveDefiniteMatrices` space. 

3""" 

4 

5import geomstats as gs 

6import lab as B 

7 

8from geometric_kernels.lab_extras import ( 

9 complex_like, 

10 create_complex, 

11 dtype_double, 

12 from_numpy, 

13 qr, 

14 slogdet, 

15) 

16from geometric_kernels.spaces.base import NoncompactSymmetricSpace 

17from geometric_kernels.utils.utils import ordered_pairwise_differences 

18 

19 

class SymmetricPositiveDefiniteMatrices(
    NoncompactSymmetricSpace, gs.geometry.spd_matrices.SPDMatrices
):
    r"""
    The GeometricKernels space representing the manifold of symmetric positive
    definite matrices $SPD(n)$ with the affine-invariant Riemannian metric.

    The elements of this space are represented by positive definite matrices of
    size n x n. Positive definite means *strictly* positive definite here, not
    positive semi-definite.

    The class inherits the interface of geomstats's `SPDMatrices`.

    .. note::
        A tutorial on how to use this space is available in the
        :doc:`SPD.ipynb </examples/SPD>` notebook.

    :param n:
        Size of the matrices, the $n$ in $SPD(n)$.

    .. note::
        As mentioned in :ref:`this note <quotient note>`, any symmetric space
        is a quotient G/H. For the manifold of symmetric positive definite
        matrices $SPD(n)$, the group of symmetries $G$ is the identity component
        $GL(n)_+$ of the general linear group $GL(n)$, while the isotropy
        subgroup $H$ is the special orthogonal group $SO(n)$. See the
        mathematical details in :cite:t:`azangulov2024b`.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please consider
        citing :cite:t:`azangulov2024b`.
    """

    def __init__(self, n):
        super().__init__(n)

    def __str__(self):
        return f"SymmetricPositiveDefiniteMatrices({self.n})"

    @property
    def dimension(self) -> int:
        """
        Returns n(n+1)/2 where `n` was passed down to `__init__`.

        This is the number of independent entries of a symmetric n x n matrix.
        """
        # n(n+1) is always even, so floor division is exact. It returns an
        # `int` as annotated, whereas true division would return a float.
        return self.n * (self.n + 1) // 2

    @property
    def degree(self) -> int:
        """
        Size `n` of the matrices; used as the matrix size `D` below.
        """
        return self.n

    @property
    def rho(self):
        r"""
        The $\rho$ vector used in the exponent of :meth:`power_function`.

        :return:
            A length-`n` array whose i-th entry (1-based) is i - (n+1)/2.
        """
        # B.range(n) + 1 enumerates 1, ..., n.
        return (B.range(self.degree) + 1) - (self.degree + 1) / 2

    @property
    def num_axes(self):
        """
        Number of axes in an array representing a point in the space.

        :return:
            2.
        """
        return 2

    def random_phases(self, key, num):
        """
        Sample random orthogonal matrices with determinant +1 (elements of
        $SO(n)$).

        The method proceeds in three steps:

        1. It QR-decomposes a matrix drawn with `B.randn`.
        2. It multiplies the rows of `Q` by the signs of `diag(R)`, which
           removes the sign ambiguity of the decomposition.
        3. It flips the first column wherever the determinant is negative.

        :param key:
            Random state of the backend, as accepted by lab's random functions
            (see :meth:`random`).
        :param num:
            Number of samples: an int, or a tuple giving the batch shape.

        :return:
            Tuple `(key, Q)`, where `key` is the updated random state and `Q`
            has shape `[*num, n, n]`.
        """
        if not isinstance(num, tuple):
            num = (num,)
        key, x = B.randn(key, dtype_double(key), *num, self.degree, self.degree)
        Q, R = qr(x)
        # expand_dims on the last axis makes the signs broadcast along the
        # columns, so row i of Q is multiplied by sign(R_ii). Q stays orthogonal.
        r_diag_sign = B.sign(B.diag_extract(R))  # [..., D]
        Q *= B.expand_dims(r_diag_sign, -1)  # [..., D, D]
        sign_det, _ = slogdet(Q)  # [...]

        # Equivalent to Q[..., 0] *= B.expand_dims(sign_det, -1), i.e. flipping
        # the first column when det(Q) = -1. It is presumably written
        # out-of-place because some backends (e.g. JAX) lack item assignment.
        Q0 = Q[..., 0] * B.expand_dims(sign_det, -1)  # [..., D]
        Q = B.concat(B.expand_dims(Q0, -1), Q[..., 1:], axis=-1)  # [..., D, D]
        return key, Q

    def inv_harish_chandra(self, lam):
        r"""
        Compute $\sqrt{\prod_d \pi |d| \tanh(\pi |d|)}$, where $d$ ranges over
        the differences returned by `ordered_pairwise_differences(lam)`.

        The product runs over the last axis of those differences. A zero
        difference gives log(0) = -inf and therefore a result of 0.

        :param lam:
            Spectral parameters; presumably the last axis has size `n`
            (verify against `ordered_pairwise_differences`).

        :return:
            Array with the last axis reduced away.
        """
        diffs = ordered_pairwise_differences(lam)
        diffs = B.abs(diffs)
        # Accumulate the product in log space, then take the square root via
        # the factor 0.5 inside the exponential.
        logprod = B.sum(
            B.log(B.pi * diffs) + B.log(B.tanh(B.pi * diffs)), axis=-1
        )  # [...]
        return B.exp(0.5 * logprod)

    def power_function(self, lam, g, h):
        r"""
        Evaluate $\prod_i |R_{ii}|^{\rho_i + i \lambda_i}$.

        Here $R$ is the triangular factor of the QR decomposition of $L h$, and
        $L$ is the Cholesky factor of `g`.

        :param lam:
            Spectral parameters; the last axis must broadcast against
            `self.rho` (size `n`).
        :param g:
            Points of the space (SPD matrices).
        :param h:
            Phase matrices, e.g. as produced by :meth:`random_phases`.

        :return:
            Complex array; the product is taken over the last axis.
        """
        g = B.cholesky(g)
        gh = B.matmul(g, h)
        Q, R = qr(gh)

        u = B.abs(B.diag_extract(R))
        logu = B.cast(complex_like(R), B.log(u))
        # Assumes create_complex(real, imag): the real part is rho and the
        # imaginary part is lam.
        exponent = create_complex(from_numpy(lam, self.rho), lam)  # [..., D]
        logpower = logu * exponent  # [..., D]
        logproduct = B.sum(logpower, axis=-1)  # [...,]
        logproduct = B.cast(complex_like(lam), logproduct)
        return B.exp(logproduct)

    def random(self, key, number):
        """
        Non-uniform random sampling, reimplements the algorithm from geomstats.

        The method draws `[number, n, n]` entries with lab's `B.rand`, maps them
        via 2x - 1 (uniform on [-1, 1)), and symmetrizes the result. The
        matrix exponential then maps each sample to an SPD matrix.

        Always produces an [N, n, n] float64 array of the `key`'s backend.

        :param key:
            Either `np.random.RandomState`, `tf.random.Generator`,
            `torch.Generator` or `jax.tensor` (representing random state).
        :param number:
            Number of samples to draw.

        :return:
            Tuple `(key, samples)`: the updated random state and an array of
            `number` random samples on the space.
        """

        key, mat = B.rand(key, dtype_double(key), number, self.n, self.n)
        mat = 2 * mat - 1
        mat_symm = 0.5 * (mat + B.transpose(mat, (0, 2, 1)))

        return key, B.expm(mat_symm)

    @property
    def element_shape(self):
        """
        :return:
            [n, n].
        """
        return [self.n, self.n]

    @property
    def element_dtype(self):
        """
        :return:
            B.Float.
        """
        return B.Float