Coverage for geometric_kernels/spaces/hyperbolic.py: 70%

69 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2This module provides the :class:`Hyperbolic` space. 

3""" 

4 

5import geomstats as gs 

6import lab as B 

7from beartype.typing import Optional 

8 

9from geometric_kernels.lab_extras import complex_like, create_complex, dtype_double 

10from geometric_kernels.spaces.base import NoncompactSymmetricSpace 

11from geometric_kernels.utils.manifold_utils import ( 

12 hyperbolic_distance, 

13 minkowski_inner_product, 

14) 

15 

16 

class Hyperbolic(NoncompactSymmetricSpace, gs.geometry.hyperboloid.Hyperboloid):
    r"""
    The GeometricKernels space representing the n-dimensional hyperbolic space
    $\mathbb{H}_n$. We use the hyperboloid model of the hyperbolic space.

    The elements of this space are represented by (n+1)-dimensional vectors
    satisfying

    .. math:: x_0^2 - x_1^2 - \ldots - x_n^2 = 1,

    i.e. lying on the (upper sheet of the) hyperboloid.

    The class inherits the interface of geomstats's `Hyperboloid`, so points
    are given in extrinsic coordinates, i.e. as vectors in the ambient
    (n+1)-dimensional space.

    .. note::
        A tutorial on how to use this space is available in the
        :doc:`Hyperbolic.ipynb </examples/Hyperbolic>` notebook.

    :param dim:
        Dimension of the hyperbolic space, denoted by n in docstrings.

    .. note::
        As mentioned in :ref:`this note <quotient note>`, any symmetric space
        is a quotient G/H. For the hyperbolic space $\mathbb{H}_n$, the group
        of symmetries $G$ is the proper Lorentz group $SO(1, n)$, while the
        isotropy subgroup $H$ is the special orthogonal group $SO(n)$. See the
        mathematical details in :cite:t:`azangulov2024b`.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please consider
        citing :cite:t:`azangulov2024b`.
    """

    def __init__(self, dim=2):
        super().__init__(dim=dim)

    def __str__(self):
        return f"Hyperbolic({self.dimension})"

    @property
    def dimension(self) -> int:
        """
        Returns n, the `dim` parameter that was passed down to `__init__`.
        """
        return self.dim

    def distance(
        self, x1: B.Numeric, x2: B.Numeric, diag: Optional[bool] = False
    ) -> B.Numeric:
        """
        Calls :func:`~.hyperbolic_distance` on the same inputs.
        """
        return hyperbolic_distance(x1, x2, diag)

    def inner_product(self, vector_a, vector_b):
        r"""
        Calls :func:`~.minkowski_inner_product` on `vector_a` and `vector_b`.
        """
        return minkowski_inner_product(vector_a, vector_b)

    def inv_harish_chandra(self, lam: B.Numeric) -> B.Numeric:
        r"""
        Evaluate the inverse Harish-Chandra function for $\mathbb{H}_n$.

        With $m = \lfloor n / 2 \rfloor$, the value computed is

        * $n = 2$: $\sqrt{|\lambda| \tanh(\pi |\lambda|)}$;
        * odd $n$: $\sqrt{\prod_{j=0}^{m-1} (\lambda^2 + j^2)}$;
        * even $n > 2$: $\sqrt{|\lambda| \tanh(\pi |\lambda|)
          \prod_{j=0}^{m-2} (\lambda^2 + (j + 1/2)^2)}$.

        :param lam:
            Spectral parameter. Its trailing axis is squeezed away
            (presumably shape [N, 1] -- confirm against callers).

        :return:
            Array with the trailing axis of `lam` removed.
        """
        lam = B.squeeze(lam, -1)
        if self.dimension == 2:
            # Closed form for n = 2; skips the log-space product below.
            c = B.abs(lam) * B.tanh(B.pi * B.abs(lam))
            return B.sqrt(c)

        m = self.dimension // 2
        if self.dimension % 2 == 0:
            # Even n: shifts (j + 1/2)^2 = (2j + 1)^2 / 4 for j = 0, ..., m - 2.
            js = B.range(B.dtype(lam), 0, m - 1)
            addenda = ((js * 2 + 1.0) ** 2) / 4  # [M]
        else:
            # Odd n: shifts j^2 for j = 0, ..., m - 1.
            js = B.range(B.dtype(lam), 0, m)
            addenda = js**2  # [M]
        # The product over j is accumulated as a sum of logs.
        log_c = B.sum(B.log(lam[..., None] ** 2 + addenda), axis=-1)  # [N, M] --> [N]
        if self.dimension % 2 == 0:
            # Extra |lam| * tanh(pi |lam|) factor present for every even n.
            log_c += B.log(B.abs(lam)) + B.log(B.tanh(B.pi * B.abs(lam)))

        # exp(0.5 * log(.)) is the square root of the accumulated product.
        return B.exp(0.5 * log_c)

    def power_function(self, lam: B.Numeric, g: B.Numeric, h: B.Numeric) -> B.Numeric:
        r"""
        Evaluate
        $\left(\frac{1 - \|b\|^2}{\|b - h\|^2}\right)^{\rho - i |\lambda|}$,
        where $b$ is `g` mapped to the Poincare ball by :meth:`convert_to_ball`
        and $\rho$ is :attr:`rho`.

        :param lam:
            Spectral parameter. Its trailing axis is squeezed away.
        :param g:
            [..., n+1]-shaped array of points on the hyperboloid.
        :param h:
            [..., n]-shaped array broadcast against the ball image of `g`
            (presumably unit vectors from :meth:`random_phases` -- confirm).

        :return:
            Complex-valued array; the last (coordinate) axis is reduced.
        """
        lam = B.squeeze(lam, -1)
        g_poincare = self.convert_to_ball(g)  # [..., n]
        gh_norm = B.sum(B.power(g_poincare - h, 2), axis=-1)  # [N1, ..., Nk]
        denominator = B.log(gh_norm)
        numerator = B.log(1.0 - B.sum(g_poincare**2, axis=-1))
        exponent = create_complex(self.rho[0], -1 * B.abs(lam))  # rho is 1-d
        # Raise the real ratio to a complex power via exp(exponent * log(ratio)).
        log_out = (
            B.cast(complex_like(lam), (numerator - denominator)) * exponent
        )  # [N1, ..., Nk]
        return B.exp(log_out)

    def convert_to_ball(self, point):
        """
        Converts the `point` from the hyperboloid model to the Poincare model.
        This corresponds to a stereographic projection onto the ball.

        :param point:
            An [..., n+1]-shaped array of points on the hyperboloid.

        :return:
            An [..., n]-shaped array of points in the Poincare ball.
        """
        # point [N1, ..., Nk, n+1]; divide the spatial part by (1 + x_0).
        return point[..., 1:] / (1 + point[..., :1])

    @property
    def rho(self):
        """
        :return:
            A 1-element array holding (n - 1) / 2.
        """
        return B.ones(1) * (self.dimension - 1) / 2

    @property
    def num_axes(self):
        """
        Number of axes in an array representing a point in the space.

        :return:
            1.
        """
        return 1

    def random_phases(self, key, num):
        """
        Sample random unit vectors in R^n.

        Standard normal vectors are drawn and normalized to unit length along
        the last axis, which makes them uniform on the unit sphere.

        :param key:
            Random state of the backend.
        :param num:
            An int or a tuple giving the batch shape of the sample.

        :return:
            Tuple `(key, x)`, where `x` has shape [*num, n].
        """
        if not isinstance(num, tuple):
            num = (num,)
        key, x = B.randn(key, dtype_double(key), *num, self.dimension)
        x = x / B.sqrt(B.sum(x**2, axis=-1, squeeze=False))
        return key, x

    def random(self, key, number):
        """
        Non-uniform random sampling, reimplements the algorithm from geomstats.

        The n spatial coordinates are drawn uniformly from [-1, 1], and x_0 is
        set to sqrt(1 + x_1^2 + ... + x_n^2) so that every sample lies on the
        upper sheet of the hyperboloid.

        :param key:
            Either `np.random.RandomState`, `tf.random.Generator`,
            `torch.Generator` or `jax.tensor` (representing random state).
        :param number:
            Number of samples to draw.

        :return:
            Tuple `(key, samples)`, where `samples` is an [number, n+1]-shaped
            array of dtype `dtype_double(key)` in the `key`'s backend.
        """
        key, samples = B.rand(key, dtype_double(key), number, self.dimension)

        # Map uniform [0, 1) samples to [-1, 1).
        samples = 2.0 * (samples - 0.5)

        coord_0 = B.sqrt(1.0 + B.sum(samples**2, axis=-1))
        return key, B.concat(coord_0[..., None], samples, axis=-1)

    @property
    def element_shape(self):
        """
        :return:
            [n+1].
        """
        # A property, like `element_dtype`, so `space.element_shape` is the
        # shape itself rather than a bound method.
        return [self.dimension + 1]

    @property
    def element_dtype(self):
        """
        :return:
            B.Float.
        """
        return B.Float