Source code for geometric_kernels.spaces.hyperbolic

"""
This module provides the :class:`Hyperbolic` space.
"""

import geomstats as gs
import lab as B
from beartype.typing import Optional

from geometric_kernels.lab_extras import (
    complex_like,
    create_complex,
    dtype_double,
    from_numpy,
)
from geometric_kernels.spaces.base import NoncompactSymmetricSpace


class Hyperbolic(NoncompactSymmetricSpace, gs.geometry.hyperboloid.Hyperboloid):
    r"""
    The GeometricKernels space representing the n-dimensional hyperbolic space
    $\mathbb{H}_n$. We use the hyperboloid model of the hyperbolic space.

    The elements of this space are represented by (n+1)-dimensional vectors
    satisfying

    .. math:: x_0^2 - x_1^2 - \ldots - x_n^2 = 1,

    i.e. lying on the hyperboloid.

    The class inherits the interface of geomstats's `Hyperbolic` with
    `point_type=extrinsic`.

    .. note::
        A tutorial on how to use this space is available in the
        :doc:`Hyperbolic.ipynb </examples/Hyperbolic>` notebook.

    :param dim:
        Dimension of the hyperbolic space, denoted by n in docstrings.

    .. note::
        As mentioned in :ref:`this note <quotient note>`, any symmetric space
        is a quotient G/H. For the hyperbolic space $\mathbb{H}_n$, the group
        of symmetries $G$ is the proper Lorentz group $SO(1, n)$, while the
        isotropy subgroup $H$ is the special orthogonal group $SO(n)$. See the
        mathematical details in :cite:t:`azangulov2023`.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please
        consider citing :cite:t:`azangulov2023`.
    """

    def __init__(self, dim=2):
        """
        Forward `dim` to the parent initializers.

        :param dim:
            Dimension n of the hyperbolic space. Defaults to 2.
        """
        super().__init__(dim=dim)

    @property
    def dimension(self) -> int:
        """
        Returns n, the `dim` parameter that was passed down to `__init__`.
        """
        return self.dim

    def distance(
        self, x1: B.Numeric, x2: B.Numeric, diag: Optional[bool] = False
    ) -> B.Numeric:
        """
        Compute the hyperbolic distance between `x1` and `x2`.

        The code is a reimplementation of
        `geomstats.geometry.hyperboloid.HyperbolicMetric` for `lab`.

        :param x1:
            An [N, n+1]-shaped array of points in the hyperbolic space.
        :param x2:
            An [M, n+1]-shaped array of points in the hyperbolic space.
        :param diag:
            If True, compute elementwise distance. Requires N = M.

            Default False.

        :return:
            An [N, M]-shaped array if diag=False or [N,]-shaped array
            if diag=True.
        """
        if diag:
            # Pointwise distance: the inputs are paired up row by row.
            x1_, x2_ = x1, x2
        else:
            # A single point is promoted to a batch of one.
            if B.rank(x1) == 1:
                x1 = B.expand_dims(x1)
            if B.rank(x2) == 1:
                x2 = B.expand_dims(x2)

            # Pairwise distance. Rather than tiling both inputs up to
            # (N, M, n+1), keep singleton axes and let the einsum inside
            # `inner_product` and the elementwise ops below broadcast them.
            x1_ = x1[..., None, :]  # (N, 1, n+1)
            x2_ = x2[None]  # (1, M, n+1)

        sq_norm_1 = self.inner_product(x1_, x1_)
        sq_norm_2 = self.inner_product(x2_, x2_)
        inner_prod = self.inner_product(x1_, x2_)

        # `inner_product` uses the (-, +, ..., +) signature, hence the minus.
        cosh_angle = -inner_prod / B.sqrt(sq_norm_1 * sq_norm_2)

        one = B.cast(B.dtype(cosh_angle), from_numpy(cosh_angle, [1.0]))
        large_constant = B.cast(B.dtype(cosh_angle), from_numpy(cosh_angle, [1e24]))

        # Clip values into [1.0, 1e24]: below 1 the arccosh is undefined
        # (round-off can push the ratio there), and the upper bound keeps
        # the square below finite.
        cosh_angle = B.where(cosh_angle < one, one, cosh_angle)
        cosh_angle = B.where(cosh_angle > large_constant, large_constant, cosh_angle)

        dist = B.log(cosh_angle + B.sqrt(cosh_angle**2 - 1))  # arccosh
        return B.cast(B.dtype(x1_), dist)

    def inner_product(self, vector_a, vector_b):
        r"""
        Computes the Minkowski inner product of vectors, using the
        $(-, +, \ldots, +)$ signature:

        .. math:: \langle a, b \rangle = -a_0 b_0 + a_1 b_1 + \ldots + a_n b_n.

        With this sign convention, a point $x$ satisfying the hyperboloid
        equation $x_0^2 - x_1^2 - \ldots - x_n^2 = 1$ has
        $\langle x, x \rangle = -1$.

        :param vector_a:
            An [..., n+1]-shaped array of points in the hyperbolic space.
        :param vector_b:
            An [..., n+1]-shaped array of points in the hyperbolic space.

        :return:
            An [...,]-shaped array of inner products.
        """
        # Diagonal of the metric: one negative entry, then n positive ones.
        signature = from_numpy(vector_a, [-1.0] + [1.0] * self.dimension)  # (n+1)
        signature = B.cast(B.dtype(vector_a), signature)
        return B.einsum("...i,...i->...", signature * vector_a, vector_b)

    def inv_harish_chandra(self, lam: B.Numeric) -> B.Numeric:
        r"""
        Evaluates the following function of $\lambda$ (by its name, the
        inverse of Harish-Chandra's $c$-function; any normalizing constant
        is not included here):

        * for even $n = 2m$:
          $\sqrt{|\lambda| \tanh(\pi |\lambda|)
          \prod_{j=0}^{m-2} (\lambda^2 + (2j+1)^2 / 4)}$,
        * for odd $n = 2m + 1$:
          $\sqrt{\prod_{j=0}^{m-1} (\lambda^2 + j^2)}$.

        The products are accumulated as sums of logarithms.

        :param lam:
            An array whose last axis has size 1 (it is squeezed away),
            e.g. [N, 1]-shaped.

        :return:
            An array with the shape of `lam` minus its last axis,
            e.g. [N,]-shaped.
        """
        lam = B.squeeze(lam, -1)
        abs_lam = B.abs(lam)
        n = self.dimension
        is_even = n % 2 == 0

        if n == 2:
            # Empty product: only the |lam| * tanh(pi * |lam|) factor remains.
            return B.sqrt(abs_lam * B.tanh(B.pi * abs_lam))

        m = n // 2
        if is_even:
            js = B.range(B.dtype(lam), 0, m - 1)
            addenda = ((js * 2 + 1.0) ** 2) / 4  # [M]
        else:
            js = B.range(B.dtype(lam), 0, m)
            addenda = js**2  # [M]

        log_c = B.sum(B.log(lam[..., None] ** 2 + addenda), axis=-1)  # [N, M] --> [N, ]
        if is_even:
            log_c = log_c + (B.log(abs_lam) + B.log(B.tanh(B.pi * abs_lam)))

        return B.exp(0.5 * log_c)

    def power_function(self, lam: B.Numeric, g: B.Numeric, h: B.Numeric) -> B.Numeric:
        r"""
        Evaluates

        .. math:: \left( \frac{1 - \|g_B\|^2}{\|g_B - h\|^2}
            \right)^{\rho - i |\lambda|},

        where $g_B$ is `g` mapped to the Poincare ball by
        :meth:`convert_to_ball` and $\rho$ is `self.rho[0]`. The power is
        computed as the exponential of a complex multiple of the logarithm.

        :param lam:
            An array whose last axis has size 1 (it is squeezed away).
        :param g:
            An [..., n+1]-shaped array of points on the hyperboloid.
        :param h:
            An array that broadcasts against the [..., n]-shaped ball
            representation of `g` (presumably points on the boundary of the
            ball, such as those returned by :meth:`random_phases`; confirm
            against the caller).

        :return:
            A complex-valued array of shape [N1, ..., Nk].
        """
        lam = B.squeeze(lam, -1)

        g_poincare = self.convert_to_ball(g)  # [..., n]
        gh_norm = B.sum(B.power(g_poincare - h, 2), axis=-1)  # [N1, ..., Nk]
        denominator = B.log(gh_norm)
        numerator = B.log(1.0 - B.sum(g_poincare**2, axis=-1))

        # NOTE(review): assumes `create_complex(real, imag)`, i.e. the
        # exponent is rho - i * |lam|. `rho` is 1-d, hence the `[0]`.
        exponent = create_complex(self.rho[0], -1 * B.abs(lam))
        log_out = (
            B.cast(complex_like(lam), (numerator - denominator)) * exponent
        )  # [N1, ..., Nk]
        return B.exp(log_out)

    def convert_to_ball(self, point):
        """
        Converts the `point` from the hyperboloid model to the Poincare model.
        This corresponds to a stereographic projection onto the ball.

        :param point:
            An [..., n+1]-shaped array of points on the hyperboloid.

        :return:
            An [..., n]-shaped array of points in the Poincare ball.
        """
        # `[..., :1]` keeps the last axis so the division broadcasts over
        # the n remaining coordinates.
        return point[..., 1:] / (1 + point[..., :1])

    @property
    def rho(self):
        """
        :return:
            A 1-element array holding (n - 1) / 2.
        """
        return B.ones(1) * (self.dimension - 1) / 2

    @property
    def num_axes(self):
        """
        Number of axes in an array representing a point in the space.

        :return:
            1.
        """
        return 1

    def random_phases(self, key, num):
        """
        Draws standard normal vectors of length n and rescales each of them
        to unit Euclidean norm.

        :param key:
            Random state, passed to `B.randn`.
        :param num:
            Number of samples, either an int or a tuple of ints giving the
            leading (batch) shape.

        :return:
            A tuple of the updated `key` and a [*num, n]-shaped array of
            unit-norm vectors.
        """
        if not isinstance(num, tuple):
            num = (num,)

        key, x = B.randn(key, dtype_double(key), *num, self.dimension)
        # `squeeze=False` keeps the last axis so the norm broadcasts.
        x = x / B.sqrt(B.sum(x**2, axis=-1, squeeze=False))
        return key, x

    def random(self, key, number):
        """
        Geomstats-based non-uniform random sampling.

        Always returns [N, n+1] float64 array of the `key`'s backend.

        .. note::
            The samples are produced by `self.random_point`, which does not
            receive `key`; `key` only determines the output array type and
            is returned unchanged.

        :param key:
            Either `np.random.RandomState`, `tf.random.Generator`,
            `torch.Generator` or `jax.tensor` (representing random state).
        :param number:
            Number of samples to draw.

        :return:
            A tuple of `key` and an array of `number` random samples on
            the space.
        """
        return key, B.cast(dtype_double(key), self.random_point(number))

    def element_shape(self):
        """
        :return:
            [n+1].
        """
        return [self.dimension + 1]