Coverage for geometric_kernels/utils/manifold_utils.py: 91%
56 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1"""Utilities for dealing with manifolds."""
3import lab as B
4import numpy as np
5from beartype.typing import Optional
7from geometric_kernels.lab_extras import from_numpy
def minkowski_inner_product(vector_a: B.Numeric, vector_b: B.Numeric) -> B.Numeric:
    r"""
    Computes the Minkowski inner product of vectors.

    .. math:: \langle a, b \rangle = -a_0 b_0 + a_1 b_1 + \ldots + a_n b_n.

    The first coordinate carries the negative sign, i.e. the bilinear form
    uses the diagonal matrix $\mathrm{diag}(-1, 1, \ldots, 1)$.

    :param vector_a:
        An [..., n+1]-shaped array of points in the hyperbolic space $\mathbb{H}_n$.
    :param vector_b:
        An [..., n+1]-shaped array of points in the hyperbolic space $\mathbb{H}_n$.

    :return:
        An [...,]-shaped array of inner products.

    :raises ValueError:
        If `vector_a` and `vector_b` have different shapes, or if their last
        axis has fewer than 2 entries (i.e. n < 1).
    """
    if B.shape(vector_a) != B.shape(vector_b):
        raise ValueError("`vector_a` and `vector_b` must have the same shapes.")
    n = vector_a.shape[-1] - 1
    # `n < 1` (not `n == 0`): an empty last axis (n == -1) would otherwise
    # broadcast against the 1-element diagonal below and silently return zeros.
    if n < 1:
        raise ValueError("Vectors must have at least 2 coordinates (n >= 1).")
    # diag(-1, 1, ..., 1) of shape (n+1,), in the backend and dtype of the inputs.
    diagonal = from_numpy(vector_a, [-1.0] + [1.0] * n)
    diagonal = B.cast(B.dtype(vector_a), diagonal)
    return B.einsum("...i,...i->...", diagonal * vector_a, vector_b)
def hyperbolic_distance(
    x1: B.Numeric, x2: B.Numeric, diag: Optional[bool] = False
) -> B.Numeric:
    """
    Compute the hyperbolic distance between `x1` and `x2`.

    This mirrors `geomstats.geometry.hyperboloid.HyperbolicMetric`, rewritten
    on top of `lab`.

    :param x1:
        An [N, n+1]-shaped array of points in the hyperbolic space.
    :param x2:
        An [M, n+1]-shaped array of points in the hyperbolic space.
    :param diag:
        If True, compute elementwise distance. Requires N = M.

        Default False.

    :return:
        An [N, M]-shaped array if diag=False or [N,]-shaped array
        if diag=True.
    """
    if diag:
        # Elementwise mode: the i-th point of `x1` is paired with the i-th of `x2`.
        lhs, rhs = x1, x2
    else:
        # A lone point (rank 1) is promoted to a batch holding one point.
        if B.rank(x1) == 1:
            x1 = B.expand_dims(x1)
        if B.rank(x2) == 1:
            x2 = B.expand_dims(x2)
        n_rows, n_cols = x1.shape[0], x2.shape[0]
        # Replicate both point sets onto a shared (N, M, n+1) grid.
        lhs = B.tile(x1[..., None, :], 1, n_cols, 1)
        rhs = B.tile(x2[None], n_rows, 1, 1)

    norm_product = minkowski_inner_product(lhs, lhs) * minkowski_inner_product(
        rhs, rhs
    )
    cosh_dist = -minkowski_inner_product(lhs, rhs) / B.sqrt(norm_product)

    # Clamp into [1, 1e24]; arccosh below needs an argument >= 1.
    dtype = B.dtype(cosh_dist)
    lower = B.cast(dtype, from_numpy(cosh_dist, [1.0]))
    upper = B.cast(dtype, from_numpy(cosh_dist, [1e24]))
    cosh_dist = B.where(cosh_dist < lower, lower, cosh_dist)
    cosh_dist = B.where(cosh_dist > upper, upper, cosh_dist)

    # arccosh(c) = log(c + sqrt(c^2 - 1))
    distance = B.log(cosh_dist + B.sqrt(cosh_dist**2 - 1))
    return B.cast(B.dtype(lhs), distance)
def manifold_laplacian(x: B.Numeric, manifold, egrad, ehess):
    r"""
    Computes the manifold Laplacian of a given function at a given point x.

    The manifold Laplacian equals the trace of the manifold Hessian, i.e.,
    $\Delta_M f(x) = \sum_{i=1}^{d} \nabla^2 f(x)[e_i, e_i]$, where
    $[e_i]_{i=1}^{d}$ is an orthonormal basis of the tangent space at x.

    .. warning::
        This function only works for hyperspheres out of the box. We will
        need to change that in the future.

    .. todo::
        See warning above.

    :param x:
        A point on the manifold at which to compute the Laplacian.
    :param manifold:
        A geomstats manifold.
    :param egrad:
        A callable; `egrad(x)` should return the Euclidean gradient of the
        given function at x.
    :param ehess:
        A callable, invoked as `ehess(x, v)` for each tangent basis vector v;
        presumably it returns the Euclidean Hessian of the given function at
        x applied to v.

    :return:
        Manifold Laplacian of the given function at x.

    See :cite:t:`jost2011` (Chapter 3.1) for mathematical details.
    """
    x_np = B.to_numpy(x)
    # Columns of `onb` are an orthonormal basis of the tangent space at x.
    onb = tangent_onb(manifold, x_np)

    # The Euclidean gradient does not depend on the basis vector, so it is
    # evaluated once rather than on every loop iteration.
    egrad_x = B.to_numpy(egrad(x))

    result = 0.0
    for j in range(manifold.dim):
        cur_vec = onb[:, j]
        ehess_x = B.to_numpy(ehess(x, from_numpy(x, cur_vec)))
        # Convert the Euclidean Hessian-vector product to the Riemannian one.
        hess_vec_prod = manifold.ehess2rhess(x_np, egrad_x, ehess_x, cur_vec)
        # Accumulate the diagonal entry <Hess f(x)[e_j], e_j> of the trace.
        result += manifold.metric.inner_product(
            hess_vec_prod, cur_vec, base_point=x_np
        )

    return result
def tangent_onb(manifold, x):
    r"""
    Computes an orthonormal basis on the tangent space at x.

    The ambient standard basis is projected onto the tangent space at x. The
    resulting projection matrix (assumed symmetric, as `np.linalg.eigh`
    requires) has eigenvalue 1 on the tangent directions and 0 on the normal
    direction. Its unit-eigenvalue eigenvectors therefore form the basis.

    .. warning::
        This function only works for hyperspheres out of the box. We will
        need to change that in the future.

    .. todo::
        See warning above.

    :param manifold:
        A geomstats manifold.
    :param x:
        A point on the manifold.

    :return:
        An [d+1, d]-shaped array, where d = `manifold.dim`, whose columns
        form an orthonormal basis of the tangent space of `manifold` at `x`,
        expressed in ambient coordinates.

    :raises RuntimeError:
        If the retained eigenvalues of the projection are not all close to 1.
    """
    manifold_dim = manifold.dim
    ambient_dim = manifold_dim + 1
    # Number of normal directions, i.e. (near-)zero eigenvalues to discard.
    codim = ambient_dim - manifold_dim

    projected_onb = manifold.to_tangent(np.eye(ambient_dim), base_point=x)

    # `eigh` returns eigenvalues in ascending order, so the zero eigenvalues
    # of the normal directions come first and are sliced away.
    eigvals, eigvecs = np.linalg.eigh(projected_onb)
    eigvals = eigvals[codim:]
    eigvecs = eigvecs[:, codim:]

    if not np.allclose(eigvals, 1.0):
        raise RuntimeError("Expected `projected_onb_eigvals` to be close to 1")

    return eigvecs