Coverage for geometric_kernels/spaces/so.py: 88%

161 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2This module provides the :class:`SpecialOrthogonal` space, the respective 

3:class:`~.eigenfunctions.Eigenfunctions` subclass :class:`SOEigenfunctions`, 

4and a class :class:`SOCharacter` for representing characters of the group. 

5""" 

6 

7import itertools 

8import json 

9import math 

10import operator 

11from functools import reduce 

12 

13import lab as B 

14import numpy as np 

15from beartype.typing import List, Tuple 

16 

17from geometric_kernels.lab_extras import ( 

18 complex_conj, 

19 complex_like, 

20 create_complex, 

21 dtype_double, 

22 from_numpy, 

23 qr, 

24 take_along_axis, 

25) 

26from geometric_kernels.spaces.eigenfunctions import Eigenfunctions 

27from geometric_kernels.spaces.lie_groups import ( 

28 CompactMatrixLieGroup, 

29 LieGroupCharacter, 

30 WeylAdditionTheorem, 

31) 

32from geometric_kernels.utils.utils import ( 

33 chain, 

34 fixed_length_partitions, 

35 get_resource_file_path, 

36) 

37 

38 

39class SOEigenfunctions(WeylAdditionTheorem): 

40 def __init__(self, n, num_levels, compute_characters=True): 

41 self.n = n 

42 self.dim = n * (n - 1) // 2 

43 self.rank = n // 2 

44 

45 if self.n % 2 == 0: 

46 self.rho = np.arange(self.rank - 1, -1, -1) 

47 else: 

48 self.rho = np.arange(self.rank - 1, -1, -1) + 0.5 

49 

50 super().__init__(n, num_levels, compute_characters) 

51 

52 def _generate_signatures(self, num_levels: int) -> List[Tuple[int, ...]]: 

53 signatures = [] 

54 # largest LB eigenvalues correspond to partitions of smallest integers 

55 # IF p >> k the number of partitions of p into k parts is O(p^k) 

56 if self.n == 3: 

57 # in this case rank=1, so all partitions are trivial 

58 # and LB eigenvalue of the signature corresponding to p is p 

59 # 200 is more than enough, since the contribution of such term 

60 # even in Matern(0.5) case is less than 200^{-2} 

61 SIGNATURE_SUM = 200 

62 else: 

63 # Roughly speaking this is a bruteforce search through 50^{self.rank} smallest eigenvalues 

64 SIGNATURE_SUM = 50 

65 for signature_sum in range(0, SIGNATURE_SUM): 

66 for i in range(0, self.rank + 1): 

67 for signature in fixed_length_partitions(signature_sum, i): 

68 signature.extend([0] * (self.rank - i)) 

69 signatures.append(tuple(signature)) 

70 if self.n % 2 == 0 and signature[-1] != 0: 

71 signature[-1] = -signature[-1] 

72 signatures.append(tuple(signature)) 

73 

74 eig_and_signature = [ 

75 (round(4 * self._compute_eigenvalue(signature)), signature) 

76 for signature in signatures 

77 ] 

78 

79 eig_and_signature.sort() 

80 signatures = [eig_sgn[1] for eig_sgn in eig_and_signature][:num_levels] 

81 return signatures 

82 

83 def _compute_dimension(self, signature: Tuple[int, ...]) -> int: 

84 if self.n % 2 == 1: 

85 qs = [pk + self.rank - k - 1 / 2 for k, pk in enumerate(signature)] 

86 rep_dim = reduce( 

87 operator.mul, 

88 (2 * qs[k] / math.factorial(2 * k + 1) for k in range(0, self.rank)), 

89 ) * reduce( 

90 operator.mul, 

91 ( 

92 (qs[i] - qs[j]) * (qs[i] + qs[j]) 

93 for i, j in itertools.combinations(range(self.rank), 2) 

94 ), 

95 1, 

96 ) 

97 return int(round(rep_dim)) 

98 else: 

99 qs = [ 

100 pk + self.rank - k - 1 if k != self.rank - 1 else abs(pk) 

101 for k, pk in enumerate(signature) 

102 ] 

103 rep_dim = int( 

104 reduce( 

105 operator.mul, 

106 (2 / math.factorial(2 * k) for k in range(1, self.rank)), 

107 ) 

108 * reduce( 

109 operator.mul, 

110 ( 

111 (qs[i] - qs[j]) * (qs[i] + qs[j]) 

112 for i, j in itertools.combinations(range(self.rank), 2) 

113 ), 

114 1, 

115 ) 

116 ) 

117 return int(round(rep_dim)) 

118 

119 def _compute_eigenvalue(self, signature: Tuple[int, ...]) -> B.Float: 

120 np_sgn = np.array(signature) 

121 rho = self.rho 

122 eigenvalue = np.linalg.norm(rho + np_sgn) ** 2 - np.linalg.norm(rho) ** 2 

123 return eigenvalue 

124 

125 def _compute_character( 

126 self, n: int, signature: Tuple[int, ...] 

127 ) -> LieGroupCharacter: 

128 return SOCharacter(n, signature) 

129 

130 def _torus_representative(self, X: B.Numeric) -> B.Numeric: 

131 gamma = None 

132 

133 if self.n == 3: 

134 # In SO(3) the torus representative is determined by the non-trivial pair of eigenvalues, 

135 # which can be calculated from the trace 

136 trace = B.expand_dims(B.trace(X), axis=-1) 

137 real = (trace - 1) / 2 

138 zeros = real * 0 

139 imag = B.sqrt(B.maximum(1 - real * real, zeros)) 

140 gamma = create_complex(real, imag) 

141 elif self.n % 2 == 1: 

142 # In SO(2n+1) the torus representative is determined by the (unordered) non-trivial eigenvalues 

143 eigvals = B.eig(X, False) 

144 sorted_ind = B.argsort(B.real(eigvals), axis=-1) 

145 eigvals = take_along_axis(eigvals, sorted_ind, -1) 

146 gamma = eigvals[..., 0:-1:2] 

147 else: 

148 # In SO(2n) each unordered set of eigenvalues determines two conjugacy classes 

149 eigvals, eigvecs = B.eig(X) 

150 sorted_ind = B.argsort(B.real(eigvals), axis=-1) 

151 eigvals = take_along_axis(eigvals, sorted_ind, -1) 

152 eigvecs = take_along_axis( 

153 eigvecs, 

154 B.broadcast_to(B.expand_dims(sorted_ind, axis=-2), *eigvecs.shape), 

155 -1, 

156 ) 

157 # c is a matrix transforming x into its canonical form (with 2x2 blocks) 

158 c = B.reshape( 

159 B.stack( 

160 eigvecs[..., ::2] + eigvecs[..., 1::2], 

161 eigvecs[..., ::2] - eigvecs[..., 1::2], 

162 axis=-1, 

163 ), 

164 *eigvecs.shape[:-1], 

165 -1, 

166 ) 

167 # eigenvectors calculated by LAPACK are either real or purely imaginary, make everything real 

168 # WARNING: might depend on the implementation of the eigendecomposition! 

169 c = B.real(c) + B.imag(c) 

170 # normalize s.t. det(c)≈±1, probably unnecessary 

171 c /= math.sqrt(2) 

172 eigvals = B.concat( 

173 B.expand_dims( 

174 B.power(eigvals[..., 0], B.cast(complex_like(c), B.sign(B.det(c)))), 

175 axis=-1, 

176 ), 

177 eigvals[..., 1:], 

178 axis=-1, 

179 ) 

180 gamma = eigvals[..., ::2] 

181 gamma = B.concat(gamma, complex_conj(gamma), axis=-1) 

182 return gamma 

183 

184 def inverse(self, X: B.Numeric) -> B.Numeric: 

185 return SpecialOrthogonal.inverse(X) 

186 

187 

188class SOCharacter(LieGroupCharacter): 

189 """ 

190 The class that represents a character of the SO(n) group. 

191 

192 These characters are always real-valued. 

193 

194 These are polynomials whose coefficients are precomputed and stored in a 

195 file. By default, there are 20 precomputed characters for n from 3 to 8. 

196 If you want more, use the `compute_characters.py` script. 

197 

198 :param n: 

199 The order n of the SO(n) group. 

200 :param signature: 

201 The signature that determines a particular character (and an 

202 irreducible unitary representation along with it). 

203 """ 

204 

205 def __init__(self, n: int, signature: Tuple[int, ...]): 

206 self.signature = signature 

207 self.n = n 

208 self.coeffs, self.monoms = self._load() 

209 

210 def _load(self): 

211 group_name = "SO({})".format(self.n) 

212 with get_resource_file_path("precomputed_characters.json") as file_path: 

213 with file_path.open("r") as file: 

214 character_formulas = json.load(file) 

215 try: 

216 cs, ms = character_formulas[group_name][str(self.signature)] 

217 coeffs, monoms = (np.array(data) for data in (cs, ms)) 

218 return coeffs, monoms 

219 except KeyError as e: 

220 raise KeyError( 

221 "Unable to retrieve character parameters for signature {} of {}, " 

222 "perhaps it is not precomputed." 

223 "Run compute_characters.py with changed parameters.".format( 

224 e.args[0], group_name 

225 ) 

226 ) from None 

227 

228 def __call__(self, gammas: B.Numeric) -> B.Numeric: 

229 char_val = B.zeros(B.dtype(gammas), *gammas.shape[:-1]) 

230 for coeff, monom in zip(self.coeffs, self.monoms): 

231 char_val += coeff * B.prod( 

232 gammas ** B.cast(B.dtype(gammas), from_numpy(gammas, monom)), axis=-1 

233 ) 

234 return char_val 

235 

236 

237class SpecialOrthogonal(CompactMatrixLieGroup): 

238 r""" 

239 The GeometricKernels space representing the special orthogonal group SO(n) 

240 consisting of n by n orthogonal matrices with unit determinant. 

241 

242 The elements of this space are represented as n x n orthogonal 

243 matrices with real entries and unit determinant. 

244 

245 .. note:: 

246 A tutorial on how to use this space is available in the 

247 :doc:`SpecialOrthogonal.ipynb </examples/SpecialOrthogonal>` notebook. 

248 

249 :param n: 

250 The order n of the group SO(n). 

251 

252 .. note:: 

253 We only support n >= 3. Mathematically, SO(2) is equivalent to the 

254 unit circle, which is available as the :class:`~.spaces.Circle` space. 

255 

256 For larger values of n, you might need to run the 

257 `utils/compute_characters.py` script to precompute the necessary 

258 mathematical quantities beyond the ones provided by default. Same 

259 can be required for larger numbers of levels. 

260 

261 .. admonition:: Citation 

262 

263 If you use this GeometricKernels space in your research, please consider 

264 citing :cite:t:`azangulov2024a`. 

265 """ 

266 

267 def __init__(self, n: int): 

268 if n < 3: 

269 raise ValueError("Only n >= 3 is supported. For n = 2, use Circle.") 

270 self.n = n 

271 self.dim = n * (n - 1) // 2 

272 self.rank = n // 2 

273 super().__init__() 

274 

275 def __str__(self): 

276 return f"SpecialOrthogonal({self.n})" 

277 

278 @property 

279 def dimension(self) -> int: 

280 """ 

281 The dimension of the space, as that of a Riemannian manifold. 

282 

283 :return: 

284 floor(n(n-1)/2) where n is the order of the group SO(n). 

285 """ 

286 return self.dim 

287 

288 @staticmethod 

289 def inverse(X: B.Numeric) -> B.Numeric: 

290 return B.transpose(X) # B.transpose only inverses the last two dims. 

291 

292 def get_eigenfunctions(self, num: int) -> Eigenfunctions: 

293 """ 

294 Returns the :class:`~.SOEigenfunctions` object with `num` levels 

295 and order n. 

296 

297 :param num: 

298 Number of levels. 

299 """ 

300 return SOEigenfunctions(self.n, num) 

301 

302 def get_eigenvalues(self, num: int) -> B.Numeric: 

303 eigenfunctions = SOEigenfunctions(self.n, num) 

304 eigenvalues = np.array( 

305 [eigenvalue for eigenvalue in eigenfunctions._eigenvalues] 

306 ) 

307 return B.reshape(eigenvalues, -1, 1) # [num, 1] 

308 

309 def get_repeated_eigenvalues(self, num: int) -> B.Numeric: 

310 eigenfunctions = SOEigenfunctions(self.n, num) 

311 eigenvalues = chain( 

312 eigenfunctions._eigenvalues, 

313 [rep_dim**2 for rep_dim in eigenfunctions._dimensions], 

314 ) 

315 return B.reshape(eigenvalues, -1, 1) # [J, 1] 

316 

317 def random(self, key: B.RandomState, number: int): 

318 if self.n == 2: # for the bright future where we support SO(2). 

319 # SO(2) = S^1 

320 key, thetas = B.random.randn(key, dtype_double(key), number, 1) 

321 thetas = 2 * math.pi * thetas 

322 c = B.cos(thetas) 

323 s = B.sin(thetas) 

324 r1 = B.stack(c, s, axis=-1) 

325 r2 = B.stack(-s, c, axis=-1) 

326 q = B.concat(r1, r2, axis=-2) 

327 return key, q 

328 elif self.n == 3: 

329 # explicit parametrization via the double cover SU(2) = S^3 

330 key, sphere_point = B.random.randn(key, dtype_double(key), number, 4) 

331 sphere_point /= B.reshape( 

332 B.sqrt(B.einsum("ij,ij->i", sphere_point, sphere_point)), -1, 1 

333 ) 

334 

335 x, y, z, w = (B.reshape(sphere_point[..., i], -1, 1) for i in range(4)) 

336 xx, yy, zz = x**2, y**2, z**2 

337 xy, xz, xw, yz, yw, zw = x * y, x * z, x * w, y * z, y * w, z * w 

338 

339 r1 = B.stack(1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw), axis=1) 

340 r2 = B.stack(2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw), axis=1) 

341 r3 = B.stack(2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy), axis=1) 

342 

343 q = B.concat(r1, r2, r3, axis=-1) 

344 return key, q 

345 else: 

346 # qr decomposition is not in the lab package, so numpy is used. 

347 key, h = B.random.randn(key, dtype_double(key), number, self.n, self.n) 

348 q, r = qr(h, mode="complete") 

349 r_diag_sign = B.sign(B.einsum("...ii->...i", r)) 

350 q *= r_diag_sign[:, None] 

351 q_det_sign = B.sign(B.det(q)) 

352 q_new = q[:, :, 0] * q_det_sign[:, None] 

353 q_new = B.concat(q_new[:, :, None], q[:, :, 1:], axis=-1) 

354 return key, q_new 

355 

356 @property 

357 def element_shape(self): 

358 """ 

359 :return: 

360 [n, n]. 

361 """ 

362 return [self.n, self.n] 

363 

364 @property 

365 def element_dtype(self): 

366 """ 

367 :return: 

368 B.Float. 

369 """ 

370 return B.Float