Coverage for geometric_kernels/spaces/hypersphere.py: 99%

78 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2This module provides the :class:`Hypersphere` space and the respective 

3:class:`~.eigenfunctions.Eigenfunctions` subclass :class:`SphericalHarmonics`. 

4""" 

5 

6import geomstats as gs 

7import lab as B 

8import numpy as np 

9from beartype.typing import List, Optional 

10from spherical_harmonics import SphericalHarmonics as _SphericalHarmonics 

11from spherical_harmonics.fundamental_set import num_harmonics 

12 

13from geometric_kernels.lab_extras import dtype_double 

14from geometric_kernels.spaces.base import DiscreteSpectrumSpace 

15from geometric_kernels.spaces.eigenfunctions import ( 

16 Eigenfunctions, 

17 EigenfunctionsWithAdditionTheorem, 

18) 

19from geometric_kernels.utils.utils import chain 

20 

21 

22class SphericalHarmonics(EigenfunctionsWithAdditionTheorem): 

23 r""" 

24 Eigenfunctions of the Laplace-Beltrami operator on the hypersphere 

25 correspond to the spherical harmonics. 

26 

27 Levels are the whole eigenspaces. 

28 

29 :param dim: 

30 Dimension of the hypersphere. 

31 

32 E.g. dim = 2 means the standard 2-dimensional sphere in $\mathbb{R}^3$. 

33 We only support dim >= 2. For dim = 1, use :class:`~.spaces.Circle`. 

34 

35 :param num_levels: 

36 Specifies the number of levels of the spherical harmonics. 

37 """ 

38 

39 def __init__(self, dim: int, num_levels: int) -> None: 

40 self.dim = dim 

41 self._num_levels = num_levels 

42 self._num_eigenfunctions: Optional[int] = None # To be computed when needed. 

43 self._spherical_harmonics = _SphericalHarmonics( 

44 dimension=dim + 1, 

45 degrees=self._num_levels, 

46 allow_uncomputed_levels=True, 

47 ) 

48 

49 @property 

50 def num_computed_levels(self) -> int: 

51 num = 0 

52 for level in self._spherical_harmonics.harmonic_levels: 

53 if level.is_level_computed: 

54 num += 1 

55 else: 

56 break 

57 return num 

58 

59 def __call__(self, X: B.Numeric, **kwargs) -> B.Numeric: 

60 return self._spherical_harmonics(X) 

61 

62 def _addition_theorem( 

63 self, X: B.Numeric, X2: Optional[B.Numeric] = None, **kwargs 

64 ) -> B.Numeric: 

65 r""" 

66 Returns the result of applying the addition theorem to sum over all 

67 the outer products of eigenfunctions within a level, for each level. 

68 

69 The case of the hypersphere is explicitly considered on 

70 doc:`this documentation page </theory/addition_theorem>`. 

71 

72 :param X: 

73 The first of the two batches of points to evaluate the phi 

74 product at. An array of shape [N, <axis>], where N is the 

75 number of points and <axis> is the shape of the arrays that 

76 represent the points in a given space. 

77 :param X2: 

78 The second of the two batches of points to evaluate the phi 

79 product at. An array of shape [N2, <axis>], where N2 is the 

80 number of points and <axis> is the shape of the arrays that 

81 represent the points in a given space. 

82 

83 Defaults to None, in which case X is used for X2. 

84 :param ``**kwargs``: 

85 Any additional parameters. 

86 

87 :return: 

88 An array of shape [N, N2, L]. 

89 """ 

90 values = [ 

91 level.addition(X, X2)[..., None] # [N, N2, 1] 

92 for level in self._spherical_harmonics.harmonic_levels 

93 ] 

94 return B.concat(*values, axis=-1) # [N, N2, L] 

95 

96 def _addition_theorem_diag(self, X: B.Numeric, **kwargs) -> B.Numeric: 

97 """ 

98 These are certain easy to compute constants. 

99 """ 

100 values = [ 

101 level.addition_at_1(X) # [N, 1] 

102 for level in self._spherical_harmonics.harmonic_levels 

103 ] 

104 return B.concat(*values, axis=1) # [N, L] 

105 

106 @property 

107 def num_eigenfunctions(self) -> int: 

108 if self._num_eigenfunctions is None: 

109 self._num_eigenfunctions = sum(self.num_eigenfunctions_per_level) 

110 return self._num_eigenfunctions 

111 

112 @property 

113 def num_levels(self) -> int: 

114 return self._num_levels 

115 

116 @property 

117 def num_eigenfunctions_per_level(self) -> List[int]: 

118 return [num_harmonics(self.dim + 1, level) for level in range(self.num_levels)] 

119 

120 

121class Hypersphere(DiscreteSpectrumSpace, gs.geometry.hypersphere.Hypersphere): 

122 r""" 

123 The GeometricKernels space representing the d-dimensional hypersphere 

124 $\mathbb{S}_d$ embedded in the (d+1)-dimensional Euclidean space. 

125 

126 The elements of this space are represented by (d+1)-dimensional vectors 

127 of unit norm. 

128 

129 Levels are the whole eigenspaces. 

130 

131 .. note:: 

132 We only support d >= 2. For d = 1, use :class:`~.spaces.Circle`. 

133 

134 .. note:: 

135 A tutorial on how to use this space is available in the 

136 :doc:`Hypersphere.ipynb </examples/Hypersphere>` notebook. 

137 

138 :param dim: 

139 Dimension of the hypersphere $\mathbb{S}_d$. 

140 Should satisfy dim >= 2. For dim = 1, use :class:`~.spaces.Circle`. 

141 

142 .. admonition:: Citation 

143 

144 If you use this GeometricKernels space in your research, please consider 

145 citing :cite:t:`borovitskiy2020`. 

146 """ 

147 

148 def __init__(self, dim: int): 

149 if dim < 2: 

150 raise ValueError("Only dim >= 2 is supported. For dim = 1, use Circle.") 

151 super().__init__(dim=dim) 

152 self.dim = dim 

153 

154 def __str__(self): 

155 return f"Hypersphere({self.dim})" 

156 

157 @property 

158 def dimension(self) -> int: 

159 """ 

160 Returns d, the `dim` parameter that was passed down to `__init__`. 

161 """ 

162 return self.dim 

163 

164 def get_eigenfunctions(self, num: int) -> Eigenfunctions: 

165 """ 

166 Returns the :class:`~.SphericalHarmonics` object with `num` levels. 

167 

168 :param num: 

169 Number of levels. 

170 """ 

171 return SphericalHarmonics(self.dim, num) 

172 

173 def get_eigenvalues(self, num: int) -> B.Numeric: 

174 eigenfunctions = SphericalHarmonics(self.dim, num) 

175 eigenvalues = np.array( 

176 [ 

177 level.eigenvalue() 

178 for level in eigenfunctions._spherical_harmonics.harmonic_levels 

179 ] 

180 ) 

181 return B.reshape(eigenvalues, -1, 1) # [num, 1] 

182 

183 def get_repeated_eigenvalues(self, num: int) -> B.Numeric: 

184 eigenfunctions = SphericalHarmonics(self.dim, num) 

185 eigenvalues_per_level = self.get_eigenvalues(num) 

186 

187 eigenvalues = chain( 

188 B.squeeze(eigenvalues_per_level), 

189 eigenfunctions.num_eigenfunctions_per_level, 

190 ) # [J,] 

191 return B.reshape(eigenvalues, -1, 1) # [J, 1] 

192 

193 def ehess2rhess( 

194 self, 

195 x: B.NPNumeric, 

196 egrad: B.NPNumeric, 

197 ehess: B.NPNumeric, 

198 direction: B.NPNumeric, 

199 ) -> B.NPNumeric: 

200 """ 

201 Riemannian Hessian along a given direction from the Euclidean Hessian. 

202 

203 Used to test that the heat kernel does indeed solve the heat equation. 

204 

205 :param x: 

206 A point on the d-dimensional hypersphere. 

207 :param egrad: 

208 Euclidean gradient of a function defined in a neighborhood of the 

209 hypersphere, evaluated at the point `x`. 

210 :param ehess: 

211 Euclidean Hessian of a function defined in a neighborhood of the 

212 hypersphere, evaluated at the point `x`. 

213 :param direction: 

214 Direction to evaluate the Riemannian Hessian at. A tangent vector 

215 at `x`. 

216 

217 :return: 

218 A [dim]-shaped array that contains Hess_f(x)[direction], the value 

219 of the Riemannian Hessian of the function evaluated at `x` along 

220 the `direction`. 

221 

222 See :cite:t:`absil2008` for mathematical details. 

223 """ 

224 normal_gradient = egrad - self.to_tangent(egrad, x) 

225 return ( 

226 self.to_tangent(ehess, x) 

227 - self.metric.inner_product(x, normal_gradient, x) * direction 

228 ) 

229 

230 def random(self, key: B.RandomState, number: int) -> B.Numeric: 

231 """ 

232 Sample uniformly random points on the hypersphere. 

233 

234 Always returns [N, D+1] float64 array of the `key`'s backend. 

235 

236 :param key: 

237 Either `np.random.RandomState`, `tf.random.Generator`, 

238 `torch.Generator` or `jax.tensor` (representing random state). 

239 :param number: 

240 Number N of samples to draw. 

241 

242 :return: 

243 An array of `number` uniformly random samples on the space. 

244 """ 

245 key, random_points = B.random.randn( 

246 key, dtype_double(key), number, self.dimension + 1 

247 ) # (N, d+1) 

248 random_points /= B.sqrt( 

249 B.sum(random_points**2, axis=1, squeeze=False) 

250 ) # (N, d+1) 

251 return key, random_points 

252 

253 @property 

254 def element_shape(self): 

255 """ 

256 :return: 

257 [d+1]. 

258 """ 

259 return [self.dimension + 1] 

260 

261 @property 

262 def element_dtype(self): 

263 """ 

264 :return: 

265 B.Float. 

266 """ 

267 return B.Float