Coverage for geometric_kernels/spaces/so.py: 88%
161 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1"""
2This module provides the :class:`SpecialOrthogonal` space, the respective
3:class:`~.eigenfunctions.Eigenfunctions` subclass :class:`SOEigenfunctions`,
4and a class :class:`SOCharacter` for representing characters of the group.
5"""
7import itertools
8import json
9import math
10import operator
11from functools import reduce
13import lab as B
14import numpy as np
15from beartype.typing import List, Tuple
17from geometric_kernels.lab_extras import (
18 complex_conj,
19 complex_like,
20 create_complex,
21 dtype_double,
22 from_numpy,
23 qr,
24 take_along_axis,
25)
26from geometric_kernels.spaces.eigenfunctions import Eigenfunctions
27from geometric_kernels.spaces.lie_groups import (
28 CompactMatrixLieGroup,
29 LieGroupCharacter,
30 WeylAdditionTheorem,
31)
32from geometric_kernels.utils.utils import (
33 chain,
34 fixed_length_partitions,
35 get_resource_file_path,
36)
class SOEigenfunctions(WeylAdditionTheorem):
    """
    Eigenfunctions of the Laplace-Beltrami operator on SO(n), grouped into
    levels indexed by signatures (highest weights) of irreducible unitary
    representations, evaluated through :class:`WeylAdditionTheorem`.

    :param n:
        The order n of the group SO(n).
    :param num_levels:
        The number of levels (signatures) to generate.
    :param compute_characters:
        Forwarded unchanged to :class:`WeylAdditionTheorem`.
    """

    def __init__(self, n, num_levels, compute_characters=True):
        self.n = n
        self.dim = n * (n - 1) // 2
        self.rank = n // 2

        # rho = (r-1, ..., 1, 0) for SO(2r) and (r-1/2, ..., 3/2, 1/2) for
        # SO(2r+1); it enters the eigenvalue formula |rho + sig|^2 - |rho|^2.
        if self.n % 2 == 0:
            self.rho = np.arange(self.rank - 1, -1, -1)
        else:
            self.rho = np.arange(self.rank - 1, -1, -1) + 0.5

        super().__init__(n, num_levels, compute_characters)

    def _generate_signatures(self, num_levels: int) -> List[Tuple[int, ...]]:
        """
        Enumerate candidate signatures by brute force and return the
        `num_levels` of them with the smallest Laplace-Beltrami eigenvalues.

        :param num_levels:
            Number of signatures to return.
        :return:
            A list of signatures (tuples of length `rank`), sorted by
            increasing eigenvalue.
        """
        signatures = []
        # The smallest eigenvalues correspond to partitions of the smallest
        # integers. If p >> k, the number of partitions of p into k parts grows
        # polynomially in p.
        if self.n == 3:
            # In this case rank = 1, so all partitions are trivial and the
            # eigenvalue of the signature corresponding to p grows with p.
            # 200 is more than enough, since the contribution of such a term,
            # even in the Matern(0.5) case, is less than 200^{-2}.
            SIGNATURE_SUM = 200
        else:
            # Roughly speaking, this is a brute-force search through the
            # 50^{self.rank} smallest eigenvalues.
            SIGNATURE_SUM = 50
        for signature_sum in range(0, SIGNATURE_SUM):
            for i in range(0, self.rank + 1):
                for partition in fixed_length_partitions(signature_sum, i):
                    # Build a fresh list instead of extending the generator's
                    # object in place; pad with zeros up to length `rank`.
                    signature = list(partition) + [0] * (self.rank - i)
                    signatures.append(tuple(signature))
                    # For SO(2r), flipping the sign of the last entry gives
                    # another valid signature.
                    if self.n % 2 == 0 and signature[-1] != 0:
                        signature[-1] = -signature[-1]
                        signatures.append(tuple(signature))

        # Eigenvalues are multiples of 1/4 (rho may be half-integral), so
        # 4 * eigenvalue is rounded to an integer to make ties exact.
        eig_and_signature = [
            (round(4 * self._compute_eigenvalue(signature)), signature)
            for signature in signatures
        ]

        eig_and_signature.sort()
        signatures = [eig_sgn[1] for eig_sgn in eig_and_signature][:num_levels]
        return signatures

    def _compute_dimension(self, signature: Tuple[int, ...]) -> int:
        """
        Dimension of the irreducible representation with the given signature,
        computed via the Weyl dimension formula.

        :param signature:
            A signature of length `rank`.
        :return:
            The (integer) dimension of the representation.
        """
        if self.n % 2 == 1:
            qs = [pk + self.rank - k - 1 / 2 for k, pk in enumerate(signature)]
            prefactor = reduce(
                operator.mul,
                (2 * qs[k] / math.factorial(2 * k + 1) for k in range(0, self.rank)),
                1,
            )
        else:
            qs = [
                pk + self.rank - k - 1 if k != self.rank - 1 else abs(pk)
                for k, pk in enumerate(signature)
            ]
            prefactor = reduce(
                operator.mul,
                (2 / math.factorial(2 * k) for k in range(1, self.rank)),
                1,
            )
        # Product over pairs i < j of (q_i - q_j)(q_i + q_j), shared by both cases.
        pairwise = reduce(
            operator.mul,
            (
                (qs[i] - qs[j]) * (qs[i] + qs[j])
                for i, j in itertools.combinations(range(self.rank), 2)
            ),
            1,
        )
        # Round (never truncate) the float product: int() alone would turn a
        # value like 5.9999999 into 5.
        return int(round(prefactor * pairwise))

    def _compute_eigenvalue(self, signature: Tuple[int, ...]) -> B.Float:
        """
        Laplace-Beltrami eigenvalue |rho + signature|^2 - |rho|^2 associated
        with the given signature.
        """
        np_sgn = np.array(signature)
        rho = self.rho
        eigenvalue = np.linalg.norm(rho + np_sgn) ** 2 - np.linalg.norm(rho) ** 2
        return eigenvalue

    def _compute_character(
        self, n: int, signature: Tuple[int, ...]
    ) -> LieGroupCharacter:
        """Return the :class:`SOCharacter` for the given order and signature."""
        return SOCharacter(n, signature)

    def _torus_representative(self, X: B.Numeric) -> B.Numeric:
        """
        Compute, for each matrix in `X`, a vector of complex eigenvalues that
        identifies its conjugacy class, in the form consumed by
        :class:`SOCharacter`: one eigenvalue per conjugate pair followed by
        their complex conjugates.

        :param X:
            A batch of SO(n) matrices (presumably of shape [..., n, n]).
        :return:
            Complex array whose last dimension is the representative.
        """
        gamma = None

        if self.n == 3:
            # In SO(3) the torus representative is determined by the non-trivial
            # pair of eigenvalues, which can be calculated from the trace.
            trace = B.expand_dims(B.trace(X), axis=-1)
            real = (trace - 1) / 2
            zeros = real * 0
            # Clamp at zero to guard against tiny negative values from rounding.
            imag = B.sqrt(B.maximum(1 - real * real, zeros))
            gamma = create_complex(real, imag)
        elif self.n % 2 == 1:
            # In SO(2n+1) the torus representative is determined by the
            # (unordered) non-trivial eigenvalues.
            eigvals = B.eig(X, False)
            sorted_ind = B.argsort(B.real(eigvals), axis=-1)
            eigvals = take_along_axis(eigvals, sorted_ind, -1)
            gamma = eigvals[..., 0:-1:2]
        else:
            # In SO(2n) each unordered set of eigenvalues determines two
            # conjugacy classes, so eigenvectors are needed to pick one.
            eigvals, eigvecs = B.eig(X)
            sorted_ind = B.argsort(B.real(eigvals), axis=-1)
            eigvals = take_along_axis(eigvals, sorted_ind, -1)
            eigvecs = take_along_axis(
                eigvecs,
                B.broadcast_to(B.expand_dims(sorted_ind, axis=-2), *eigvecs.shape),
                -1,
            )
            # c is a matrix transforming x into its canonical form (with 2x2 blocks).
            c = B.reshape(
                B.stack(
                    eigvecs[..., ::2] + eigvecs[..., 1::2],
                    eigvecs[..., ::2] - eigvecs[..., 1::2],
                    axis=-1,
                ),
                *eigvecs.shape[:-1],
                -1,
            )
            # Eigenvectors calculated by LAPACK are either real or purely
            # imaginary; make everything real.
            # WARNING: might depend on the implementation of the eigendecomposition!
            c = B.real(c) + B.imag(c)
            # Normalize s.t. det(c) is approximately +-1 (probably unnecessary).
            c /= math.sqrt(2)
            # Raise the first eigenvalue to the power sign(det(c)), which
            # conjugates it when det(c) < 0 and so selects the conjugacy class.
            eigvals = B.concat(
                B.expand_dims(
                    B.power(eigvals[..., 0], B.cast(complex_like(c), B.sign(B.det(c)))),
                    axis=-1,
                ),
                eigvals[..., 1:],
                axis=-1,
            )
            gamma = eigvals[..., ::2]
        gamma = B.concat(gamma, complex_conj(gamma), axis=-1)
        return gamma

    def inverse(self, X: B.Numeric) -> B.Numeric:
        """Group inverse; delegates to :meth:`SpecialOrthogonal.inverse`."""
        return SpecialOrthogonal.inverse(X)
class SOCharacter(LieGroupCharacter):
    """
    The class that represents a character of the SO(n) group.

    These characters are always real-valued.

    These are polynomials whose coefficients are precomputed and stored in a
    file. By default, there are 20 precomputed characters for n from 3 to 8.
    If you want more, use the `compute_characters.py` script.

    :param n:
        The order n of the SO(n) group.
    :param signature:
        The signature that determines a particular character (and an
        irreducible unitary representation along with it).
    """

    def __init__(self, n: int, signature: Tuple[int, ...]):
        self.signature = signature
        self.n = n
        self.coeffs, self.monoms = self._load()

    def _load(self):
        """
        Load the polynomial coefficients and monomial exponents of this
        character from the bundled `precomputed_characters.json` file.

        :return:
            A pair `(coeffs, monoms)` of numpy arrays.
        :raises KeyError:
            If no precomputed entry exists for this group or signature.
        """
        group_name = "SO({})".format(self.n)
        with get_resource_file_path("precomputed_characters.json") as file_path:
            with file_path.open("r") as file:
                character_formulas = json.load(file)
        try:
            cs, ms = character_formulas[group_name][str(self.signature)]
        except KeyError:
            # Report the requested signature itself: if the whole group is
            # missing, the failed key would be the group name, not the signature.
            raise KeyError(
                "Unable to retrieve character parameters for signature {} of {}, "
                "perhaps it is not precomputed. "
                "Run compute_characters.py with changed parameters.".format(
                    self.signature, group_name
                )
            ) from None
        coeffs, monoms = (np.array(data) for data in (cs, ms))
        return coeffs, monoms

    def __call__(self, gammas: B.Numeric) -> B.Numeric:
        """
        Evaluate the character as sum_k coeffs[k] * prod_i gammas_i ** monoms[k][i].

        :param gammas:
            Torus representatives; the last dimension indexes the variables
            of the monomials.
        :return:
            Character values with shape `gammas.shape[:-1]`.
        """
        char_val = B.zeros(B.dtype(gammas), *gammas.shape[:-1])
        for coeff, monom in zip(self.coeffs, self.monoms):
            char_val += coeff * B.prod(
                gammas ** B.cast(B.dtype(gammas), from_numpy(gammas, monom)), axis=-1
            )
        return char_val
class SpecialOrthogonal(CompactMatrixLieGroup):
    r"""
    The GeometricKernels space representing the special orthogonal group SO(n)
    consisting of n by n orthogonal matrices with unit determinant.

    The elements of this space are represented as n x n orthogonal
    matrices with real entries and unit determinant.

    .. note::
        A tutorial on how to use this space is available in the
        :doc:`SpecialOrthogonal.ipynb </examples/SpecialOrthogonal>` notebook.

    :param n:
        The order n of the group SO(n).

    .. note::
        We only support n >= 3. Mathematically, SO(2) is equivalent to the
        unit circle, which is available as the :class:`~.spaces.Circle` space.

        For larger values of n, you might need to run the
        `utils/compute_characters.py` script to precompute the necessary
        mathematical quantities beyond the ones provided by default. Same
        can be required for larger numbers of levels.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please consider
        citing :cite:t:`azangulov2024a`.
    """

    def __init__(self, n: int):
        if n < 3:
            raise ValueError("Only n >= 3 is supported. For n = 2, use Circle.")
        self.n = n
        self.dim = n * (n - 1) // 2
        self.rank = n // 2
        super().__init__()

    def __str__(self):
        return f"SpecialOrthogonal({self.n})"

    @property
    def dimension(self) -> int:
        """
        The dimension of the space, as that of a Riemannian manifold.

        :return:
            n(n-1)/2 where n is the order of the group SO(n).
        """
        return self.dim

    @staticmethod
    def inverse(X: B.Numeric) -> B.Numeric:
        """
        Inverse of (a batch of) orthogonal matrices, i.e. their transpose.
        """
        return B.transpose(X)  # B.transpose only inverses the last two dims.

    def get_eigenfunctions(self, num: int) -> Eigenfunctions:
        """
        Returns the :class:`~.SOEigenfunctions` object with `num` levels
        and order n.

        :param num:
            Number of levels.
        """
        return SOEigenfunctions(self.n, num)

    def get_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the first `num` levels, one per level.

        :param num:
            Number of levels.
        :return:
            Array of shape [num, 1].
        """
        eigenfunctions = SOEigenfunctions(self.n, num)
        eigenvalues = np.array(list(eigenfunctions._eigenvalues))
        return B.reshape(eigenvalues, -1, 1)  # [num, 1]

    def get_repeated_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the first `num` levels, each repeated dim^2 times,
        where dim is the dimension of the level's representation.

        :param num:
            Number of levels.
        :return:
            Array of shape [J, 1].
        """
        eigenfunctions = SOEigenfunctions(self.n, num)
        eigenvalues = chain(
            eigenfunctions._eigenvalues,
            [rep_dim**2 for rep_dim in eigenfunctions._dimensions],
        )
        return B.reshape(eigenvalues, -1, 1)  # [J, 1]

    def random(self, key: B.RandomState, number: int):
        """
        Sample `number` uniformly (Haar) distributed random elements of SO(n).

        :param key:
            Random state.
        :param number:
            Number of samples.
        :return:
            `(key, samples)` with samples of shape [number, n, n].
        """
        if self.n == 2:  # for the bright future where we support SO(2).
            # SO(2) = S^1: rotation by an angle drawn uniformly from [0, 2*pi).
            key, thetas = B.random.rand(key, dtype_double(key), number, 1)
            thetas = 2 * math.pi * thetas
            c = B.cos(thetas)
            s = B.sin(thetas)
            r1 = B.stack(c, s, axis=-1)
            r2 = B.stack(-s, c, axis=-1)
            q = B.concat(r1, r2, axis=-2)
            return key, q
        elif self.n == 3:
            # Explicit parametrization via the double cover SU(2) = S^3: a
            # normalized Gaussian vector is uniform on S^3 (a unit quaternion).
            key, sphere_point = B.random.randn(key, dtype_double(key), number, 4)
            sphere_point /= B.reshape(
                B.sqrt(B.einsum("ij,ij->i", sphere_point, sphere_point)), -1, 1
            )

            x, y, z, w = (B.reshape(sphere_point[..., i], -1, 1) for i in range(4))
            xx, yy, zz = x**2, y**2, z**2
            xy, xz, xw, yz, yw, zw = x * y, x * z, x * w, y * z, y * w, z * w

            r1 = B.stack(1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw), axis=1)
            r2 = B.stack(2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw), axis=1)
            r3 = B.stack(2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy), axis=1)

            q = B.concat(r1, r2, r3, axis=-1)
            return key, q
        else:
            # QR of a Gaussian matrix (via lab_extras.qr, since lab has no qr).
            key, h = B.random.randn(key, dtype_double(key), number, self.n, self.n)
            q, r = qr(h, mode="complete")
            # Fix the sign ambiguity of QR so that q is Haar-distributed on O(n).
            r_diag_sign = B.sign(B.einsum("...ii->...i", r))
            q *= r_diag_sign[:, None]
            # Flip the first column where det(q) = -1 to land in SO(n).
            q_det_sign = B.sign(B.det(q))
            q_new = q[:, :, 0] * q_det_sign[:, None]
            q_new = B.concat(q_new[:, :, None], q[:, :, 1:], axis=-1)
            return key, q_new

    @property
    def element_shape(self):
        """
        :return:
            [n, n].
        """
        return [self.n, self.n]

    @property
    def element_dtype(self):
        """
        :return:
            B.Float.
        """
        return B.Float