Coverage for geometric_kernels/spaces/hyperbolic.py: 70%
69 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1"""
2This module provides the :class:`Hyperbolic` space.
3"""
5import geomstats as gs
6import lab as B
7from beartype.typing import Optional
9from geometric_kernels.lab_extras import complex_like, create_complex, dtype_double
10from geometric_kernels.spaces.base import NoncompactSymmetricSpace
11from geometric_kernels.utils.manifold_utils import (
12 hyperbolic_distance,
13 minkowski_inner_product,
14)
class Hyperbolic(NoncompactSymmetricSpace, gs.geometry.hyperboloid.Hyperboloid):
    r"""
    The GeometricKernels space representing the n-dimensional hyperbolic space
    $\mathbb{H}_n$. We use the hyperboloid model of the hyperbolic space.

    The elements of this space are represented by (n+1)-dimensional vectors
    satisfying

    .. math:: x_0^2 - x_1^2 - \ldots - x_n^2 = 1,

    i.e. lying on the hyperboloid.

    The class inherits the interface of geomstats's `Hyperboloid` with
    `point_type=extrinsic`.

    .. note::
        A tutorial on how to use this space is available in the
        :doc:`Hyperbolic.ipynb </examples/Hyperbolic>` notebook.

    :param dim:
        Dimension of the hyperbolic space, denoted by n in docstrings.

    .. note::
        As mentioned in :ref:`this note <quotient note>`, any symmetric space
        is a quotient G/H. For the hyperbolic space $\mathbb{H}_n$, the group
        of symmetries $G$ is the proper Lorentz group $SO(1, n)$, while the
        isotropy subgroup $H$ is the special orthogonal group $SO(n)$. See the
        mathematical details in :cite:t:`azangulov2024b`.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please consider
        citing :cite:t:`azangulov2024b`.
    """

    def __init__(self, dim=2):
        super().__init__(dim=dim)

    def __str__(self):
        return f"Hyperbolic({self.dimension})"

    @property
    def dimension(self) -> int:
        """
        Returns n, the `dim` parameter that was passed down to `__init__`.
        """
        return self.dim

    def distance(
        self, x1: B.Numeric, x2: B.Numeric, diag: Optional[bool] = False
    ) -> B.Numeric:
        """
        Calls :func:`~.hyperbolic_distance` on the same inputs.
        """
        return hyperbolic_distance(x1, x2, diag)

    def inner_product(self, vector_a, vector_b):
        r"""
        Calls :func:`~.minkowski_inner_product` on `vector_a` and `vector_b`.
        """
        return minkowski_inner_product(vector_a, vector_b)

    def inv_harish_chandra(self, lam: B.Numeric) -> B.Numeric:
        r"""
        Evaluates the square root of the following expression in $\lambda$:

        - n = 2: $|\lambda| \tanh(\pi |\lambda|)$;
        - n even: $|\lambda| \tanh(\pi |\lambda|)
          \prod_{j=0}^{n/2 - 2} \left(\lambda^2 + (2j + 1)^2 / 4\right)$;
        - n odd: $\prod_{j=0}^{(n - 3)/2} \left(\lambda^2 + j^2\right)$.

        Judging by the method name, this is the inverse Harish-Chandra
        function $|c(\lambda)|^{-1}$. NOTE(review): any constant normalization
        factor is not visible here.

        :param lam:
            Array whose last axis is a singleton; it is squeezed away, e.g.
            [N, 1] -> [N]. (Shape convention presumably matches the caller;
            confirm.)

        :return:
            Array of the squeezed shape, e.g. [N].
        """
        lam = B.squeeze(lam, -1)
        n = self.dimension

        if n == 2:
            # Closed form. The general even-n product below would be empty
            # here, because range(0, n/2 - 1) is range(0, 0).
            c = B.abs(lam) * B.tanh(B.pi * B.abs(lam))
            return B.sqrt(c)

        m = n // 2
        if n % 2 == 0:
            js = B.range(B.dtype(lam), 0, m - 1)  # j = 0, ..., n/2 - 2
            addenda = ((js * 2 + 1.0) ** 2) / 4  # [M]
        else:
            js = B.range(B.dtype(lam), 0, m)  # j = 0, ..., (n - 3)/2
            addenda = js**2  # [M]

        # The product over j is accumulated as a sum of logarithms.
        log_c = B.sum(B.log(lam[..., None] ** 2 + addenda), axis=-1)  # [N, M] --> [N, ]
        if n % 2 == 0:
            log_c += B.log(B.abs(lam)) + B.log(B.tanh(B.pi * B.abs(lam)))

        # exp(0.5 * log x) == sqrt(x)
        return B.exp(0.5 * log_c)

    def power_function(self, lam: B.Numeric, g: B.Numeric, h: B.Numeric) -> B.Numeric:
        r"""
        Evaluates the complex power

        .. math:: \left(\frac{1 - |g_B|^2}{|g_B - h|^2}\right)^{\rho - i |\lambda|},

        where $g_B$ is `g` mapped to the Poincare ball by
        :meth:`convert_to_ball`, and $\rho$ is :attr:`rho`.

        :param lam:
            Array with a trailing singleton axis, which is squeezed away.
        :param g:
            [..., n+1]-shaped array of points on the hyperboloid.
        :param h:
            [..., n]-shaped array. Presumably points on the unit sphere
            $S^{n-1}$, e.g. as produced by :meth:`random_phases`; confirm
            against callers.

        :return:
            Complex array of shape [N1, ..., Nk].
        """
        lam = B.squeeze(lam, -1)
        g_poincare = self.convert_to_ball(g)  # [..., n]
        gh_norm = B.sum(B.power(g_poincare - h, 2), axis=-1)  # [N1, ..., Nk]
        denominator = B.log(gh_norm)
        numerator = B.log(1.0 - B.sum(g_poincare**2, axis=-1))
        exponent = create_complex(self.rho[0], -1 * B.abs(lam))  # rho is 1-d
        # a^z is computed as exp(z * log a) to obtain a complex power.
        log_out = (
            B.cast(complex_like(lam), (numerator - denominator)) * exponent
        )  # [N1, ..., Nk]
        out = B.exp(log_out)
        return out

    def convert_to_ball(self, point):
        """
        Converts the `point` from the hyperboloid model to the Poincare model.
        This corresponds to a stereographic projection onto the ball.

        :param point:
            An [..., n+1]-shaped array of points on the hyperboloid.

        :return:
            An [..., n]-shaped array of points in the Poincare ball.
        """
        # point [N1, ..., Nk, n+1]; the result drops the 0-th coordinate.
        return point[..., 1:] / (1 + point[..., :1])

    @property
    def rho(self):
        """
        :return:
            A length-1 array holding (n - 1) / 2.
        """
        return B.ones(1) * (self.dimension - 1) / 2

    @property
    def num_axes(self):
        """
        Number of axes in an array representing a point in the space.

        :return:
            1.
        """
        return 1

    def random_phases(self, key, num):
        """
        Draws random unit vectors in R^n by normalizing standard Gaussian
        samples, which makes them uniform on the sphere S^{n-1}.

        :param key:
            Random state of the desired backend.
        :param num:
            Either an int or a tuple giving the leading (batch) shape.

        :return:
            A tuple (key, x), where x has shape [*num, n] and unit norm along
            the last axis.
        """
        if not isinstance(num, tuple):
            num = (num,)
        key, x = B.randn(key, dtype_double(key), *num, self.dimension)
        x = x / B.sqrt(B.sum(x**2, axis=-1, squeeze=False))
        return key, x

    def random(self, key, number):
        """
        Non-uniform random sampling, reimplements the algorithm from geomstats.

        The n spatial coordinates are drawn uniformly from [-1, 1]^n. The 0-th
        coordinate is then set to sqrt(1 + |x|^2), which places the point on
        the hyperboloid.

        :param key:
            Either `np.random.RandomState`, `tf.random.Generator`,
            `torch.Generator` or `jax.tensor` (representing random state).
        :param number:
            Number of samples to draw.

        :return:
            A tuple (key, samples), where samples is a [number, n+1] float64
            array of the `key`'s backend.
        """
        key, samples = B.rand(key, dtype_double(key), number, self.dimension)

        # Map U[0, 1) to U[-1, 1).
        samples = 2.0 * (samples - 0.5)

        coord_0 = B.sqrt(1.0 + B.sum(samples**2, axis=-1))
        return key, B.concat(coord_0[..., None], samples, axis=-1)

    @property
    def element_shape(self):
        """
        :return:
            [n+1].
        """
        # Exposed as a property, like `element_dtype`, so `space.element_shape`
        # yields the shape list itself rather than a bound method.
        return [self.dimension + 1]

    @property
    def element_dtype(self):
        """
        :return:
            B.Float.
        """
        return B.Float