Coverage for geometric_kernels/utils/manifold_utils.py: 91%

56 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1"""Utilities for dealing with manifolds.""" 

2 

3import lab as B 

4import numpy as np 

5from beartype.typing import Optional 

6 

7from geometric_kernels.lab_extras import from_numpy 

8 

9 

def minkowski_inner_product(vector_a: B.Numeric, vector_b: B.Numeric) -> B.Numeric:
    r"""
    Computes the Minkowski inner product of vectors.

    .. math:: \langle a, b \rangle = -a_0 b_0 + a_1 b_1 + \ldots + a_n b_n.

    The first coordinate carries the minus sign. This is the convention
    `hyperbolic_distance` relies on when it negates this inner product to
    obtain the hyperbolic cosine of the distance.

    :param vector_a:
        An [..., n+1]-shaped array of vectors, typically points in the
        hyperbolic space $\mathbb{H}_n$ (hyperboloid model).
    :param vector_b:
        An [..., n+1]-shaped array of vectors with the same shape as
        `vector_a`.

    :return:
        An [...,]-shaped array of inner products, contracted over the last
        axis.

    :raises ValueError:
        If `vector_a` and `vector_b` have different shapes, or if the last
        axis has length 1 (i.e. n == 0).
    """
    if B.shape(vector_a) != B.shape(vector_b):
        raise ValueError("`vector_a` and `vector_b` must have the same shapes.")
    n = vector_a.shape[-1] - 1
    if n == 0:
        # A last axis of length 1 leaves only the "time" coordinate, so the
        # form would have no spatial part; this is what the check rejects.
        raise ValueError("Must have at least 1 point.")
    # Signature diag(-1, 1, ..., 1), cast to the input dtype so the product
    # below does not upcast or fail on mixed dtypes.
    diagonal = from_numpy(vector_a, [-1.0] + [1.0] * n)  # (n+1)
    diagonal = B.cast(B.dtype(vector_a), diagonal)
    return B.einsum("...i,...i->...", diagonal * vector_a, vector_b)

32 

33 

def hyperbolic_distance(
    x1: B.Numeric, x2: B.Numeric, diag: Optional[bool] = False
) -> B.Numeric:
    """
    Compute the hyperbolic distance between `x1` and `x2`.

    This mirrors `geomstats.geometry.hyperboloid.HyperbolicMetric`, but is
    written against `lab` so it works with any `lab`-supported backend. The
    distance is computed as arccosh(c) with
    c = -<x1, x2> / sqrt(<x1, x1> <x2, x2>), where <., .> is the Minkowski
    inner product, and c is clamped to [1, 1e24] before the arccosh.

    :param x1:
        An [N, n+1]-shaped array of points in the hyperbolic space. When
        `diag` is False, a single [n+1]-shaped point is also accepted and
        treated as a batch with N = 1.
    :param x2:
        An [M, n+1]-shaped array of points in the hyperbolic space. When
        `diag` is False, a single [n+1]-shaped point is also accepted and
        treated as a batch with M = 1.
    :param diag:
        If True, compute elementwise distances between `x1[i]` and `x2[i]`.
        Requires N = M.

        Default False.

    :return:
        An [N, M]-shaped array if diag=False or [N,]-shaped array
        if diag=True, cast to the dtype of `x1`.
    """
    if diag:
        # Elementwise mode: row i of `x1` is paired with row i of `x2`.
        left, right = x1, x2
    else:
        # Pairwise mode: promote lone points to one-row batches first.
        x1 = B.expand_dims(x1) if B.rank(x1) == 1 else x1
        x2 = B.expand_dims(x2) if B.rank(x2) == 1 else x2
        n_rows, n_cols = x1.shape[0], x2.shape[0]
        # Replicate both batches onto a shared (N, M, n+1) grid so that every
        # (i, j) pair lines up along the last axis.
        left = B.tile(x1[..., None, :], 1, n_cols, 1)
        right = B.tile(x2[None], n_rows, 1, 1)

    self_left = minkowski_inner_product(left, left)
    self_right = minkowski_inner_product(right, right)
    cross = minkowski_inner_product(left, right)
    cosh_dist = -cross / B.sqrt(self_left * self_right)

    # Clamp into [1, 1e24]. The lower bound keeps the square root below away
    # from negative arguments caused by round-off; the upper bound caps very
    # large values.
    dtype = B.dtype(cosh_dist)
    lower = B.cast(dtype, from_numpy(cosh_dist, [1.0]))
    upper = B.cast(dtype, from_numpy(cosh_dist, [1e24]))
    cosh_dist = B.where(cosh_dist < lower, lower, cosh_dist)
    cosh_dist = B.where(cosh_dist > upper, upper, cosh_dist)

    # arccosh(c) = log(c + sqrt(c^2 - 1)) for c >= 1.
    distance = B.log(cosh_dist + B.sqrt(cosh_dist**2 - 1))
    return B.cast(B.dtype(left), distance)

88 

89 

def manifold_laplacian(x: B.Numeric, manifold, egrad, ehess):
    r"""
    Computes the manifold Laplacian of a given function at a given point x.
    The manifold Laplacian equals the trace of the manifold Hessian, i.e.,
    $\Delta_M f(x) = \sum_{i=1}^{d} \nabla^2 f(x)(e_i, e_i)$, where
    $[e_i]_{i=1}^{d}$ is an orthonormal basis of the tangent space at x.

    .. warning::
        This function only works for hyperspheres out of the box. We will
        need to change that in the future.

    .. todo::
        See warning above.

    :param x:
        A point on the manifold at which to compute the Laplacian.
    :param manifold:
        A geomstats manifold. Its ``dim``, ``ehess2rhess`` and
        ``metric.inner_product`` are used, plus whatever `tangent_onb` needs.
    :param egrad:
        A callable; ``egrad(x)`` must return the Euclidean gradient of the
        function at x.
    :param ehess:
        A callable; ``ehess(x, v)`` is called with each tangent basis vector
        ``v`` (converted to the backend of `x`) and should return the
        Euclidean Hessian of the function at x applied to ``v``. Its output
        is passed on to ``manifold.ehess2rhess``.

    :return:
        Manifold Laplacian of the given function at x.

    See :cite:t:`jost2011` (Chapter 3.1) for mathematical details.
    """
    dim = manifold.dim

    # The numpy view of x and the Euclidean gradient do not depend on the
    # basis direction, so compute them once instead of once per direction.
    x_np = B.to_numpy(x)
    egrad_x = B.to_numpy(egrad(x))

    onb = tangent_onb(manifold, x_np)  # columns form the tangent basis
    result = 0.0
    for j in range(dim):
        cur_vec = onb[:, j]
        ehess_x = B.to_numpy(ehess(x, from_numpy(x, cur_vec)))
        # Riemannian Hessian applied to the basis vector e_j.
        hess_vec_prod = manifold.ehess2rhess(x_np, egrad_x, ehess_x, cur_vec)
        # Accumulate the diagonal term <Hess f(x) e_j, e_j> of the trace.
        result += manifold.metric.inner_product(
            hess_vec_prod, cur_vec, base_point=x_np
        )

    return result

132 

133 

def tangent_onb(manifold, x):
    r"""
    Computes an orthonormal basis on the tangent space at x.

    The standard basis of the ambient space is projected onto the tangent
    space at x; the eigenvectors of that projection matrix with eigenvalue 1
    are returned as the basis.

    .. warning::
        This function only works for hyperspheres out of the box. We will
        need to change that in the future.

    .. todo::
        See warning above.

    :param manifold:
        A geomstats manifold. Only ``manifold.dim`` and
        ``manifold.to_tangent`` are used, and the ambient dimension is
        assumed to be ``manifold.dim + 1``.
    :param x:
        A point on the manifold.

    :return:
        An [d+1, d]-shaped array, where d = ``manifold.dim``, whose columns
        form an orthonormal basis of the tangent space to `manifold` at `x`.

    :raises RuntimeError:
        If the retained eigenvalues of the projection are not all close to 1.
    """
    manifold_dim = manifold.dim
    ambient_dim = manifold_dim + 1
    ambient_onb = np.eye(ambient_dim)

    # Projecting the identity yields the matrix of the projection onto the
    # tangent space at x. `eigh` below assumes this matrix is symmetric,
    # which holds for an orthogonal projection.
    projected_onb = manifold.to_tangent(ambient_onb, base_point=x)

    eigvals, eigvecs = np.linalg.eigh(projected_onb)

    # `eigh` sorts eigenvalues in ascending order, so the first
    # (ambient_dim - manifold_dim) entries are the (near-)zero eigenvalues
    # of the normal directions; drop them and keep the tangent directions.
    num_normal = ambient_dim - manifold_dim
    eigvals = eigvals[num_normal:]
    eigvecs = eigvecs[:, num_normal:]

    if not np.all(np.isclose(eigvals, 1.0)):
        raise RuntimeError("Expected `projected_onb_eigvals` to be close to 1")

    return eigvecs