# Source code for geometric_kernels.utils.manifold_utils

""" Utilities for dealing with manifolds.  """

import lab as B
import numpy as np
from beartype.typing import Optional

from geometric_kernels.lab_extras import from_numpy

def minkowski_inner_product(vector_a: B.Numeric, vector_b: B.Numeric) -> B.Numeric:
    r"""
    Computes the Minkowski inner product of vectors.

    .. math:: \langle a, b \rangle = - a_0 b_0 + a_1 b_1 + \ldots + a_n b_n.

    That is, the metric signature is $(-, +, \ldots, +)$: only the 0-th
    coordinate enters with a minus sign.

    :param vector_a:
        An [..., n+1]-shaped array of points in the hyperbolic space
        $\mathbb{H}_n$.
    :param vector_b:
        An [..., n+1]-shaped array of points in the hyperbolic space
        $\mathbb{H}_n$.

    :return:
        An [...,]-shaped array of inner products.
    """
    assert vector_a.shape == vector_b.shape, "inputs must have equal shapes"
    n = vector_a.shape[-1] - 1
    assert n > 0, "vectors must have at least 2 coordinates"

    # Diagonal of the Minkowski metric, signature (-, +, ..., +).
    diagonal = from_numpy(vector_a, [-1.0] + [1.0] * n)  # (n+1)
    diagonal = B.cast(B.dtype(vector_a), diagonal)

    # Weight the coordinates of `vector_a`, then contract the last axis.
    return B.einsum("...i,...i->...", diagonal * vector_a, vector_b)
def hyperbolic_distance(
    x1: B.Numeric, x2: B.Numeric, diag: Optional[bool] = False
) -> B.Numeric:
    """
    Compute the hyperbolic distance between `x1` and `x2`.

    The code is a reimplementation of
    `geomstats.geometry.hyperboloid.HyperbolicMetric` for `lab`.

    :param x1:
        An [N, n+1]-shaped array of points in the hyperbolic space.
    :param x2:
        An [M, n+1]-shaped array of points in the hyperbolic space.
    :param diag:
        If True, compute elementwise distance. Requires N = M.

        Default False.

    :return:
        An [N, M]-shaped array if diag=False or [N,]-shaped array
        if diag=True.
    """
    if diag:
        # Pointwise distance between corresponding rows of `x1` and `x2`.
        x1_ = x1
        x2_ = x2
    else:
        # Promote single points of shape [n+1] to [1, n+1].
        if B.rank(x1) == 1:
            x1 = B.expand_dims(x1)
        if B.rank(x2) == 1:
            x2 = B.expand_dims(x2)

        # Pairwise distance: tile both inputs to a common shape because
        # `minkowski_inner_product` requires equally shaped arguments.
        # `x1` (N, n+1)
        # `x2` (M, n+1)
        x1_ = B.tile(x1[..., None, :], 1, x2.shape[0], 1)  # (N, M, n+1)
        x2_ = B.tile(x2[None], x1.shape[0], 1, 1)  # (N, M, n+1)

    sq_norm_1 = minkowski_inner_product(x1_, x1_)
    sq_norm_2 = minkowski_inner_product(x2_, x2_)
    inner_prod = minkowski_inner_product(x1_, x2_)
    cosh_angle = -inner_prod / B.sqrt(sq_norm_1 * sq_norm_2)

    one = B.cast(B.dtype(cosh_angle), from_numpy(cosh_angle, [1.0]))
    large_constant = B.cast(B.dtype(cosh_angle), from_numpy(cosh_angle, [1e24]))

    # clip values into [1.0, 1e24]
    cosh_angle = B.where(cosh_angle < one, one, cosh_angle)
    cosh_angle = B.where(cosh_angle > large_constant, large_constant, cosh_angle)

    # arccosh(c) = log(c + sqrt(c^2 - 1)), with c^2 - 1 written as
    # (c - 1)(c + 1) and the square roots taken separately:
    #  - c**2 would overflow float32 for c above ~1.8e19, which is well inside
    #    the clipping range, turning large distances into inf;
    #  - c - 1 is computed exactly for c in [1, 2], avoiding the cancellation
    #    of c**2 - 1 for small distances.
    dist = B.log(cosh_angle + B.sqrt(cosh_angle - 1) * B.sqrt(cosh_angle + 1))
    dist = B.cast(B.dtype(x1_), dist)

    return dist
def manifold_laplacian(x: B.Numeric, manifold, egrad, ehess):
    r"""
    Computes the manifold Laplacian of a given function at a given point x.

    The manifold Laplacian equals the trace of the manifold Hessian, i.e.,
    $\Delta_M f(x) = \sum_{i=1}^{d} \nabla^2 f(x)(x_i, x_i)$, where
    $[x_i]_{i=1}^{d}$ is an orthonormal basis of the tangent space at x.

    .. warning::
        This function only works for hyperspheres out of the box. We will
        need to change that in the future.

    .. todo::
        See warning above.

    :param x:
        A point on the manifold at which to compute the Laplacian.
    :param manifold:
        A geomstats manifold.
    :param egrad:
        A callable; `egrad(x)` must return the Euclidean gradient of the
        given function at x.
    :param ehess:
        A callable; it is called as `ehess(x, v)` for each tangent basis
        vector `v`, and its output is passed to `manifold.ehess2rhess` as
        the Euclidean Hessian of the given function at x along `v`.

    :return:
        Manifold Laplacian of the given function at x.

    See :cite:t:`jost2011` (Chapter 3.1) for mathematical details.
    """
    dim = manifold.dim
    x_np = B.to_numpy(x)
    onb = tangent_onb(manifold, x_np)  # columns are tangent basis vectors

    # The Euclidean gradient does not depend on the basis vector, so it is
    # evaluated once rather than once per basis vector.
    egrad_x = B.to_numpy(egrad(x))

    result = 0.0
    for j in range(dim):
        cur_vec = onb[:, j]
        ehess_x = B.to_numpy(ehess(x, from_numpy(x, cur_vec)))
        hess_vec_prod = manifold.ehess2rhess(x_np, egrad_x, ehess_x, cur_vec)
        # <Hess f(x)[v_j], v_j> is the j-th diagonal entry of the Riemannian
        # Hessian in the orthonormal basis; summing them gives the trace.
        result += manifold.metric.inner_product(
            hess_vec_prod, cur_vec, base_point=x_np
        )

    return result
def tangent_onb(manifold, x):
    r"""
    Computes an orthonormal basis on the tangent space at x.

    .. warning::
        This function only works for hyperspheres out of the box. We will
        need to change that in the future.

    .. todo::
        See warning above.

    :param manifold:
        A geomstats manifold.
    :param x:
        A point on the manifold.

    :return:
        An [d+1, d]-shaped array, where d is `manifold.dim`, whose columns
        form an orthonormal basis of the tangent space of `manifold` at `x`,
        expressed in ambient (d+1)-dimensional coordinates.
    """
    ambient_dim = manifold.dim + 1
    manifold_dim = manifold.dim

    # Project the ambient standard basis onto the tangent space. For the
    # hypersphere this yields the symmetric orthogonal projector onto the
    # tangent space: eigenvalue 1 on tangent directions, 0 on normal ones.
    ambient_onb = np.eye(ambient_dim)
    projected_onb = manifold.to_tangent(ambient_onb, base_point=x)

    projected_onb_eigvals, projected_onb_eigvecs = np.linalg.eigh(projected_onb)

    # `eigh` sorts eigenvalues in ascending order, so the zero eigenvalues
    # (one per normal direction) come first; drop them.
    codim = ambient_dim - manifold_dim
    projected_onb_eigvals = projected_onb_eigvals[codim:]
    projected_onb_eigvecs = projected_onb_eigvecs[:, codim:]

    assert np.all(
        np.isclose(projected_onb_eigvals, 1.0)
    ), "tangent projection is not an orthogonal projector at x"

    return projected_onb_eigvecs