Source code for geometric_kernels.spaces.hyperbolic

"""
This module provides the :class:`Hyperbolic` space.
"""

import geomstats as gs
import lab as B
from beartype.typing import Optional

from geometric_kernels.lab_extras import complex_like, create_complex, dtype_double
from geometric_kernels.spaces.base import NoncompactSymmetricSpace
from geometric_kernels.utils.manifold_utils import (
    hyperbolic_distance,
    minkowski_inner_product,
)


[docs] class Hyperbolic(NoncompactSymmetricSpace, gs.geometry.hyperboloid.Hyperboloid): r""" The GeometricKernels space representing the n-dimensional hyperbolic space $\mathbb{H}_n$. We use the hyperboloid model of the hyperbolic space. The elements of this space are represented by (n+1)-dimensional vectors satisfying .. math:: x_0^2 - x_1^2 - \ldots - x_n^2 = 1, i.e. lying on the hyperboloid. The class inherits the interface of geomstats's `Hyperbolic` with `point_type=extrinsic`. .. note:: A tutorial on how to use this space is available in the :doc:`Hyperbolic.ipynb </examples/Hyperbolic>` notebook. :param dim: Dimension of the hyperbolic space, denoted by n in docstrings. .. note:: As mentioned in :ref:`this note <quotient note>`, any symmetric space is a quotient G/H. For the hyperbolic space $\mathbb{H}_n$, the group of symmetries $G$ is the proper Lorentz group $SO(1, n)$, while the isotropy subgroup $H$ is the special orthogonal group $SO(n)$. See the mathematical details in :cite:t:`azangulov2024b`. .. admonition:: Citation If you use this GeometricKernels space in your research, please consider citing :cite:t:`azangulov2024b`. """ def __init__(self, dim=2): super().__init__(dim=dim) def __str__(self): return f"Hyperbolic({self.dimension})" @property def dimension(self) -> int: """ Returns n, the `dim` parameter that was passed down to `__init__`. """ return self.dim
[docs] def distance( self, x1: B.Numeric, x2: B.Numeric, diag: Optional[bool] = False ) -> B.Numeric: """ Calls :func:`~.hyperbolic_distance` on the same inputs. """ return hyperbolic_distance(x1, x2, diag)
[docs] def inner_product(self, vector_a, vector_b): r""" Calls :func:`~.minkowski_inner_product` on `vector_a` and `vector_b`. """ return minkowski_inner_product(vector_a, vector_b)
[docs] def inv_harish_chandra(self, lam: B.Numeric) -> B.Numeric: lam = B.squeeze(lam, -1) if self.dimension == 2: c = B.abs(lam) * B.tanh(B.pi * B.abs(lam)) return B.sqrt(c) if self.dimension % 2 == 0: m = self.dimension // 2 js = B.range(B.dtype(lam), 0, m - 1) addenda = ((js * 2 + 1.0) ** 2) / 4 # [M] elif self.dimension % 2 == 1: m = self.dimension // 2 js = B.range(B.dtype(lam), 0, m) addenda = js**2 # [M] log_c = B.sum(B.log(lam[..., None] ** 2 + addenda), axis=-1) # [N, M] --> [N, ] if self.dimension % 2 == 0: log_c += B.log(B.abs(lam)) + B.log(B.tanh(B.pi * B.abs(lam))) return B.exp(0.5 * log_c)
[docs] def power_function(self, lam: B.Numeric, g: B.Numeric, h: B.Numeric) -> B.Numeric: lam = B.squeeze(lam, -1) g_poincare = self.convert_to_ball(g) # [..., n] gh_norm = B.sum(B.power(g_poincare - h, 2), axis=-1) # [N1, ..., Nk] denominator = B.log(gh_norm) numerator = B.log(1.0 - B.sum(g_poincare**2, axis=-1)) exponent = create_complex(self.rho[0], -1 * B.abs(lam)) # rho is 1-d log_out = ( B.cast(complex_like(lam), (numerator - denominator)) * exponent ) # [N1, ..., Nk] out = B.exp(log_out) return out
[docs] def convert_to_ball(self, point): """ Converts the `point` from the hyperboloid model to the Poincare model. This corresponds to a stereographic projection onto the ball. :param: An [..., n+1]-shaped array of points on the hyperboloid. :return: An [..., n]-shaped array of points in the Poincare ball. """ # point [N1, ..., Nk, n] return point[..., 1:] / (1 + point[..., :1])
@property def rho(self): return B.ones(1) * (self.dimension - 1) / 2 @property def num_axes(self): """ Number of axes in an array representing a point in the space. :return: 1. """ return 1
[docs] def random_phases(self, key, num): if not isinstance(num, tuple): num = (num,) key, x = B.randn(key, dtype_double(key), *num, self.dimension) x = x / B.sqrt(B.sum(x**2, axis=-1, squeeze=False)) return key, x
[docs] def random(self, key, number): """ Non-uniform random sampling, reimplements the algorithm from geomstats. Always returns [N, n+1] float64 array of the `key`'s backend. :param key: Either `np.random.RandomState`, `tf.random.Generator`, `torch.Generator` or `jax.tensor` (representing random state). :param number: Number of samples to draw. :return: An array of `number` random samples on the space. """ key, samples = B.rand(key, dtype_double(key), number, self.dim) samples = 2.0 * (samples - 0.5) coord_0 = B.sqrt(1.0 + B.sum(samples**2, axis=-1)) return key, B.concat(coord_0[..., None], samples, axis=-1)
[docs] def element_shape(self): """ :return: [n+1]. """ return [self.dimension + 1]
@property def element_dtype(self): """ :return: B.Float. """ return B.Float