Coverage for geometric_kernels/spaces/hypersphere.py: 99%
78 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1"""
2This module provides the :class:`Hypersphere` space and the respective
3:class:`~.eigenfunctions.Eigenfunctions` subclass :class:`SphericalHarmonics`.
4"""
import geomstats as gs
import lab as B
import numpy as np
from beartype.typing import List, Optional, Tuple
from spherical_harmonics import SphericalHarmonics as _SphericalHarmonics
from spherical_harmonics.fundamental_set import num_harmonics

from geometric_kernels.lab_extras import dtype_double
from geometric_kernels.spaces.base import DiscreteSpectrumSpace
from geometric_kernels.spaces.eigenfunctions import (
    Eigenfunctions,
    EigenfunctionsWithAdditionTheorem,
)
from geometric_kernels.utils.utils import chain
class SphericalHarmonics(EigenfunctionsWithAdditionTheorem):
    r"""
    Spherical harmonics: the eigenfunctions of the Laplace-Beltrami operator
    on the hypersphere.

    Each level is a whole eigenspace.

    :param dim:
        Dimension of the hypersphere.

        E.g. dim = 2 means the standard 2-dimensional sphere in $\mathbb{R}^3$.
        We only support dim >= 2. For dim = 1, use :class:`~.spaces.Circle`.

    :param num_levels:
        Specifies the number of levels of the spherical harmonics.
    """

    def __init__(self, dim: int, num_levels: int) -> None:
        self.dim = dim
        self._num_levels = num_levels
        # Lazily filled in by the `num_eigenfunctions` property.
        self._num_eigenfunctions: Optional[int] = None
        # The wrapped object is given `dim + 1`, i.e. the dimension of the
        # Euclidean space the hypersphere is embedded in.
        self._spherical_harmonics = _SphericalHarmonics(
            dimension=dim + 1,
            degrees=num_levels,
            allow_uncomputed_levels=True,
        )

    @property
    def num_computed_levels(self) -> int:
        """
        Number of leading levels flagged as computed, i.e. the length of the
        initial run of levels whose `is_level_computed` is truthy.
        """
        count = 0
        for level in self._spherical_harmonics.harmonic_levels:
            if not level.is_level_computed:
                # Only the uninterrupted prefix of computed levels counts.
                break
            count += 1
        return count

    def __call__(self, X: B.Numeric, **kwargs) -> B.Numeric:
        """
        Evaluate the wrapped spherical harmonics at `X`.

        :param X:
            Points to evaluate the eigenfunctions at.
        :param ``**kwargs``:
            Any additional parameters. They are not used here.

        :return:
            Whatever the wrapped spherical harmonics object returns for `X`.
        """
        harmonics = self._spherical_harmonics
        return harmonics(X)

    def _addition_theorem(
        self, X: B.Numeric, X2: Optional[B.Numeric] = None, **kwargs
    ) -> B.Numeric:
        r"""
        Applies the addition theorem level by level: for every level, returns
        the sum of the outer products of the eigenfunctions of that level.

        The case of the hypersphere is explicitly considered on
        :doc:`this documentation page </theory/addition_theorem>`.

        :param X:
            The first of the two batches of points to evaluate the phi
            product at. An array of shape [N, <axis>], where N is the
            number of points and <axis> is the shape of the arrays that
            represent the points in a given space.
        :param X2:
            The second of the two batches of points to evaluate the phi
            product at. An array of shape [N2, <axis>], where N2 is the
            number of points and <axis> is the shape of the arrays that
            represent the points in a given space.

            Defaults to None, in which case X is used for X2.
        :param ``**kwargs``:
            Any additional parameters.

        :return:
            An array of shape [N, N2, L].
        """
        per_level = []
        for level in self._spherical_harmonics.harmonic_levels:
            # A trailing axis is appended so that levels can be stacked.
            per_level.append(level.addition(X, X2)[..., None])  # [N, N2, 1]
        return B.concat(*per_level, axis=-1)  # [N, N2, L]

    def _addition_theorem_diag(self, X: B.Numeric, **kwargs) -> B.Numeric:
        """
        Diagonal counterpart of the addition theorem. The per-level values
        are certain easy to compute constants.

        :param X:
            A batch of N points.
        :param ``**kwargs``:
            Any additional parameters.

        :return:
            An array of shape [N, L].
        """
        per_level = []
        for level in self._spherical_harmonics.harmonic_levels:
            per_level.append(level.addition_at_1(X))  # [N, 1]
        return B.concat(*per_level, axis=1)  # [N, L]

    @property
    def num_eigenfunctions(self) -> int:
        """
        Total number of eigenfunctions over all levels, computed on first
        access and cached afterwards.
        """
        total = self._num_eigenfunctions
        if total is None:
            total = sum(self.num_eigenfunctions_per_level)
            self._num_eigenfunctions = total
        return total

    @property
    def num_levels(self) -> int:
        """
        Number of levels, as passed to `__init__`.
        """
        return self._num_levels

    @property
    def num_eigenfunctions_per_level(self) -> List[int]:
        """
        Number of eigenfunctions in each level, one entry per level index in
        `range(num_levels)`.
        """
        ambient_dim = self.dim + 1
        return [
            num_harmonics(ambient_dim, degree) for degree in range(self.num_levels)
        ]
class Hypersphere(DiscreteSpectrumSpace, gs.geometry.hypersphere.Hypersphere):
    r"""
    The GeometricKernels space representing the d-dimensional hypersphere
    $\mathbb{S}_d$ embedded in the (d+1)-dimensional Euclidean space.

    The elements of this space are represented by (d+1)-dimensional vectors
    of unit norm.

    Levels are the whole eigenspaces.

    .. note::
        We only support d >= 2. For d = 1, use :class:`~.spaces.Circle`.

    .. note::
        A tutorial on how to use this space is available in the
        :doc:`Hypersphere.ipynb </examples/Hypersphere>` notebook.

    :param dim:
        Dimension of the hypersphere $\mathbb{S}_d$.
        Should satisfy dim >= 2. For dim = 1, use :class:`~.spaces.Circle`.

    :raises ValueError:
        If `dim` is less than 2.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please consider
        citing :cite:t:`borovitskiy2020`.
    """

    def __init__(self, dim: int):
        if dim < 2:
            raise ValueError("Only dim >= 2 is supported. For dim = 1, use Circle.")
        super().__init__(dim=dim)
        self.dim = dim

    def __str__(self):
        return f"Hypersphere({self.dim})"

    @property
    def dimension(self) -> int:
        """
        Returns d, the `dim` parameter that was passed down to `__init__`.
        """
        return self.dim

    def get_eigenfunctions(self, num: int) -> Eigenfunctions:
        """
        Returns the :class:`~.SphericalHarmonics` object with `num` levels.

        :param num:
            Number of levels.
        """
        return SphericalHarmonics(self.dim, num)

    def get_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues corresponding to the first `num` levels, one per level.

        :param num:
            Number of levels.

        :return:
            A [num, 1]-shaped array.
        """
        eigenfunctions = SphericalHarmonics(self.dim, num)
        eigenvalues = np.array(
            [
                level.eigenvalue()
                for level in eigenfunctions._spherical_harmonics.harmonic_levels
            ]
        )
        return B.reshape(eigenvalues, -1, 1)  # [num, 1]

    def get_repeated_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the first `num` levels, where the eigenvalue of each
        level is repeated once per eigenfunction of that level.

        :param num:
            Number of levels.

        :return:
            A [J, 1]-shaped array, where J is the total number of
            eigenfunctions in the first `num` levels.
        """
        eigenfunctions = SphericalHarmonics(self.dim, num)
        eigenvalues_per_level = self.get_eigenvalues(num)  # [num, 1]

        # Flatten with an explicit reshape rather than an axis-less squeeze:
        # squeezing a [1, 1] array (num == 1) collapses it to a 0-d array,
        # while `chain` is handed one entry per level.
        eigenvalues = chain(
            B.reshape(eigenvalues_per_level, -1),  # [num,]
            eigenfunctions.num_eigenfunctions_per_level,
        )  # [J,]
        return B.reshape(eigenvalues, -1, 1)  # [J, 1]

    def ehess2rhess(
        self,
        x: B.NPNumeric,
        egrad: B.NPNumeric,
        ehess: B.NPNumeric,
        direction: B.NPNumeric,
    ) -> B.NPNumeric:
        """
        Riemannian Hessian along a given direction from the Euclidean Hessian.

        Used to test that the heat kernel does indeed solve the heat equation.

        :param x:
            A point on the d-dimensional hypersphere.
        :param egrad:
            Euclidean gradient of a function defined in a neighborhood of the
            hypersphere, evaluated at the point `x`.
        :param ehess:
            Euclidean Hessian of a function defined in a neighborhood of the
            hypersphere, evaluated at the point `x`.
        :param direction:
            Direction to evaluate the Riemannian Hessian at. A tangent vector
            at `x`.

        :return:
            A [dim]-shaped array that contains Hess_f(x)[direction], the value
            of the Riemannian Hessian of the function evaluated at `x` along
            the `direction`.

        See :cite:t:`absil2008` for mathematical details.
        """
        # What remains of `egrad` once its tangent part at `x` is removed.
        normal_gradient = egrad - self.to_tangent(egrad, x)
        return (
            self.to_tangent(ehess, x)
            - self.metric.inner_product(x, normal_gradient, x) * direction
        )

    def random(
        self, key: B.RandomState, number: int
    ) -> Tuple[B.RandomState, B.Numeric]:
        """
        Sample uniformly random points on the hypersphere.

        The samples are always an [N, D+1] float64 array of the `key`'s backend.

        :param key:
            Either `np.random.RandomState`, `tf.random.Generator`,
            `torch.Generator` or `jax.tensor` (representing random state).
        :param number:
            Number N of samples to draw.

        :return:
            A tuple `(key, samples)`: the random state returned by the
            sampler, followed by an array of `number` uniformly random
            samples on the space.
        """
        key, random_points = B.random.randn(
            key, dtype_double(key), number, self.dimension + 1
        )  # (N, d+1)
        # Divide every row by its Euclidean norm to obtain unit vectors.
        random_points /= B.sqrt(
            B.sum(random_points**2, axis=1, squeeze=False)
        )  # (N, d+1)
        return key, random_points

    @property
    def element_shape(self):
        """
        :return:
            [d+1].
        """
        return [self.dimension + 1]

    @property
    def element_dtype(self):
        """
        :return:
            B.Float.
        """
        return B.Float