Coverage for geometric_kernels/spaces/su.py: 95%
115 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1"""
2This module provides the :class:`SpecialUnitary` space and the respective
3:class:`~.eigenfunctions.Eigenfunctions` subclass :class:`SUEigenfunctions`.
4"""
6import itertools
7import json
8import math
9import operator
10from functools import reduce
12import lab as B
13import numpy as np
14from beartype.typing import List, Tuple
16from geometric_kernels.lab_extras import (
17 complex_conj,
18 complex_like,
19 create_complex,
20 dtype_double,
21 from_numpy,
22 qr,
23)
24from geometric_kernels.spaces.eigenfunctions import Eigenfunctions
25from geometric_kernels.spaces.lie_groups import (
26 CompactMatrixLieGroup,
27 LieGroupCharacter,
28 WeylAdditionTheorem,
29)
30from geometric_kernels.utils.utils import chain, get_resource_file_path
class SUEigenfunctions(WeylAdditionTheorem):
    """
    Eigenfunctions of the special unitary group SU(n), organised into levels
    indexed by signatures of irreducible unitary representations, built on
    top of :class:`~.lie_groups.WeylAdditionTheorem`.

    :param n:
        The order n of the group SU(n).
    :param num_levels:
        The number of levels (signatures) to generate.
    :param compute_characters:
        Forwarded unchanged to the parent constructor (presumably controls
        whether characters are computed for the generated signatures).
    """

    def __init__(self, n, num_levels, compute_characters=True):
        self.n = n
        self.dim = n**2 - 1
        self.rank = n - 1

        # rho = ((n-1)/2, (n-3)/2, ..., -(n-1)/2): the Weyl vector (half-sum of
        # positive roots) of SU(n), used by `_compute_eigenvalue`.
        self.rho = np.arange(self.n - 1, -self.n, -2) * 0.5

        # Attributes above are set before the parent constructor runs, since
        # the hooks it may call (`_generate_signatures`, ...) read them.
        super().__init__(n, num_levels, compute_characters)

    def _generate_signatures(self, num_levels: int) -> List[Tuple[int, ...]]:
        """
        Generate the `num_levels` signatures with the smallest eigenvalues.

        Candidates are the non-increasing tuples ``(l_1, ..., l_{n-1}, 0)``
        with entries in ``[0, sign_vals_lim]``. The bound decreases as n
        grows, which limits the number of candidates (their count grows
        combinatorially with the rank).

        :param num_levels:
            The number of signatures to return.
        :return:
            Signatures sorted by eigenvalue (ties broken by comparing the
            signature tuples), truncated to `num_levels` entries.
        """
        sign_vals_lim = 100 if self.n in (1, 2) else 30 if self.n == 3 else 10
        # Drawing from a descending range, combinations_with_replacement yields
        # non-increasing tuples.
        signatures = list(
            itertools.combinations_with_replacement(
                range(sign_vals_lim, -1, -1), r=self.rank
            )
        )
        # Fix the last entry to 0; `_compute_eigenvalue` subtracts the mean, so
        # the eigenvalue does not depend on this normalization.
        signatures = [sgn + (0,) for sgn in signatures]

        # Every component of rho + (signature - mean) is a multiple of 1/(2n),
        # so 4 * n**2 times the eigenvalue is an integer up to float error.
        # Rounding it yields an exact sort key, so equal eigenvalues tie exactly.
        eig_and_signature = [
            (
                round(4 * (len(signature) ** 2) * self._compute_eigenvalue(signature)),
                signature,
            )
            for signature in signatures
        ]

        eig_and_signature.sort()
        signatures = [eig_sgn[1] for eig_sgn in eig_and_signature][:num_levels]
        return signatures

    def _compute_dimension(self, signature: Tuple[int, ...]) -> int:
        """
        Dimension of the representation with the given signature, via the
        Weyl dimension formula  prod_{i<j} (l_i - l_j + j - i) / (j - i).

        Uses exact integer arithmetic. The product of all (j - i) over i < j
        equals 1! * 2! * ... * (n-1)!, and the quotient is an exact integer.
        For n = 1 both products are empty and the result is 1.

        :param signature:
            A tuple of n integers.
        :return:
            The dimension as a Python int.
        """
        n = self.n
        numerator = reduce(
            operator.mul,
            (
                int(signature[i]) - int(signature[j]) + (j - i)
                for i in range(n)
                for j in range(i + 1, n)
            ),
            1,  # initializer: keeps the empty product (n = 1) well-defined
        )
        denominator = reduce(
            operator.mul, (math.factorial(k) for k in range(1, n)), 1
        )
        return numerator // denominator

    def _compute_eigenvalue(self, signature: Tuple[int, ...]) -> B.Float:
        """
        Laplace-Beltrami eigenvalue for the given signature.

        Computes ``||rho + l||^2 - ||rho||^2``, where ``l`` is the signature
        shifted to have zero mean.

        :param signature:
            A tuple of n integers.
        :return:
            The eigenvalue as a NumPy float.
        """
        normalized_signature = np.asarray(signature, dtype=np.float64) - np.mean(
            signature
        )
        lb_eigenvalue = (
            np.linalg.norm(self.rho + normalized_signature) ** 2
            - np.linalg.norm(self.rho) ** 2
        )
        return lb_eigenvalue

    def _compute_character(
        self, n: int, signature: Tuple[int, ...]
    ) -> LieGroupCharacter:
        """Return the :class:`SUCharacter` for `signature` of SU(n)."""
        return SUCharacter(n, signature)

    def _torus_representative(self, X):
        """
        Return the eigenvalues of `X` (no eigenvectors).

        NOTE(review): assumes the second argument of lab's ``B.eig`` is its
        ``compute_eigvecs`` flag.
        """
        return B.eig(X, False)

    def inverse(self, X: B.Numeric) -> B.Numeric:
        """Inverse of `X`; delegates to :meth:`SpecialUnitary.inverse`."""
        return SpecialUnitary.inverse(X)
class SUCharacter(LieGroupCharacter):
    """
    The class that represents a character of the SU(n) group.

    Many of the characters on SU(n) are complex-valued.

    These are polynomials whose coefficients are precomputed and stored in a
    file. By default, there are 20 precomputed characters for n from 2 to 6.
    If you want more, use the `compute_characters.py` script.

    :param n:
        The order n of the SU(n) group.
    :param signature:
        The signature that determines a particular character (and an
        irreducible unitary representation along with it).
    """

    def __init__(self, n: int, signature: Tuple[int, ...]):
        self.signature = signature
        self.n = n
        self.coeffs, self.monoms = self._load()

    def _load(self):
        """
        Load this character's polynomial from ``precomputed_characters.json``.

        The entry is looked up under ``"SU(n)"`` and then under
        ``str(self.signature)``.

        :return:
            A pair ``(coeffs, monoms)`` of NumPy arrays: the coefficients and
            the exponent vectors of the polynomial's monomials.
        :raises KeyError:
            If the file has no entry for SU(n) or none for this signature.
        """
        group_name = "SU({})".format(self.n)
        with get_resource_file_path("precomputed_characters.json") as file_path:
            with file_path.open("r") as file:
                character_formulas = json.load(file)
        try:
            cs, ms = character_formulas[group_name][str(self.signature)]
        except KeyError:
            # Report the requested signature itself: the missing key may be the
            # group name rather than the signature.
            raise KeyError(
                "Unable to retrieve character parameters for signature {} of {}, "
                "perhaps it is not precomputed. "
                "Run compute_characters.py with changed parameters.".format(
                    self.signature, group_name
                )
            ) from None
        coeffs, monoms = (np.array(data) for data in (cs, ms))
        return coeffs, monoms

    def __call__(self, gammas: B.Numeric) -> B.Numeric:
        """
        Evaluate the character polynomial at `gammas`.

        :param gammas:
            Array whose last axis holds the polynomial's variables (presumably
            the eigenvalues of group elements -- TODO confirm against caller).
        :return:
            ``sum_k coeffs[k] * prod_i gammas[..., i] ** monoms[k][i]``, with
            the last axis of `gammas` reduced away.
        """
        char_val = B.zeros(B.dtype(gammas), *gammas.shape[:-1])
        for coeff, monom in zip(self.coeffs, self.monoms):
            # Exponents are cast to complex_like(gammas) before taking powers.
            exponents = B.cast(complex_like(gammas), from_numpy(gammas, monom))
            char_val += coeff * B.prod(B.power(gammas, exponents), axis=-1)
        return char_val
class SpecialUnitary(CompactMatrixLieGroup):
    r"""
    The GeometricKernels space representing the special unitary group SU(n)
    consisting of n by n complex unitary matrices with unit determinant.

    The elements of this space are represented as n x n unitary
    matrices with complex entries and unit determinant.

    .. note::
        A tutorial on how to use this space is available in the
        :doc:`SpecialUnitary.ipynb </examples/SpecialUnitary>` notebook.

    :param n:
        The order n of the group SU(n).

    .. note::
        We only support n >= 2. Mathematically, SU(1) is trivial, consisting
        of a single element (the identity); chances are you do not need it.
        For large values of n, you might need to run the `compute_characters.py`
        script to precompute the necessary mathematical quantities beyond the
        ones provided by default.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please consider
        citing :cite:t:`azangulov2024a`.
    """

    def __init__(self, n: int):
        if n < 2:
            raise ValueError(f"Only n >= 2 is supported. n = {n} was provided.")
        self.n = n
        self.dim = n**2 - 1
        self.rank = n - 1
        super().__init__()

    def __str__(self):
        return f"SpecialUnitary({self.n})"

    @property
    def dimension(self) -> int:
        """
        The dimension of the space, as that of a Riemannian manifold.

        :return:
            n^2 - 1 where n is the order of the group SU(n).
        """
        return self.dim

    @staticmethod
    def inverse(X: B.Numeric) -> B.Numeric:
        """
        Inverse of unitary matrices: their conjugate transpose.

        :param X:
            A matrix or a batch of matrices. B.transpose swaps only the last
            two axes, so leading batch axes are preserved.
        """
        return complex_conj(B.transpose(X))

    def get_eigenfunctions(self, num: int) -> Eigenfunctions:
        """
        Returns the :class:`~.SUEigenfunctions` object with `num` levels
        and order n.

        :param num:
            Number of levels.
        """
        return SUEigenfunctions(self.n, num)

    def get_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the first `num` levels, one per level.

        :param num:
            Number of levels.
        :return:
            An array of shape [num, 1].
        """
        eigenfunctions = SUEigenfunctions(self.n, num)
        eigenvalues = np.array(list(eigenfunctions._eigenvalues))
        return B.reshape(eigenvalues, -1, 1)  # [num, 1]

    def get_repeated_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the first `num` levels, each paired with the squared
        dimension of its representation via the ``chain`` helper (presumably
        repeating each eigenvalue that many times -- confirm against
        ``geometric_kernels.utils.utils.chain``).

        :param num:
            Number of levels.
        :return:
            An array of shape [M, 1].
        """
        eigenfunctions = SUEigenfunctions(self.n, num)
        eigenvalues = chain(
            eigenfunctions._eigenvalues,
            [rep_dim**2 for rep_dim in eigenfunctions._dimensions],
        )
        return B.reshape(eigenvalues, -1, 1)  # [M, 1]

    def random(self, key: B.RandomState, number: int):
        """
        Sample `number` random elements of SU(n).

        For n = 2, normalized Gaussian vectors (uniform points of S^3) are
        mapped to matrices [[a, b], [-conj(b), conj(a)]], which are unitary
        with determinant |a|^2 + |b|^2 = 1.

        For n > 2, a complex Gaussian matrix is QR-decomposed. The columns of
        Q are multiplied by the phases of diag(R) (Mezzadri's normalization),
        and the first column is rescaled so that the determinant becomes 1.

        :param key:
            Random state or key.
        :param number:
            Number of samples.
        :return:
            A tuple ``(key, samples)``, `samples` of shape [number, n, n].
        """
        if self.n == 2:
            # explicit parametrization via the double cover SU(2) = S_3
            key, sphere_point = B.random.randn(key, dtype_double(key), number, 4)
            sphere_point /= B.reshape(
                B.sqrt(B.einsum("ij,ij->i", sphere_point, sphere_point)), -1, 1
            )
            a = create_complex(sphere_point[..., 0], sphere_point[..., 1])
            b = create_complex(sphere_point[..., 2], sphere_point[..., 3])
            # r1 and r2 become the columns of q.
            r1 = B.stack(a, -complex_conj(b), axis=-1)
            r2 = B.stack(b, complex_conj(a), axis=-1)
            q = B.stack(r1, r2, axis=-1)
            return key, q

        key, real = B.random.randn(key, dtype_double(key), number, self.n, self.n)
        key, imag = B.random.randn(key, dtype_double(key), number, self.n, self.n)
        h = create_complex(real, imag) / B.sqrt(2)
        q, r = qr(h, mode="complete")
        # Mezzadri's normalization: Q @ diag(phase(diag(R))) is the unitary
        # factor for which R has a positive real diagonal. This removes the
        # dependence on the QR routine's phase convention. Multiplying by the
        # *conjugate* phase is only equivalent when diag(R) is real.
        r_diag = B.einsum("...ii->...i", r)
        r_diag_phase = r_diag / B.cast(B.dtype(r_diag), B.abs(r_diag))
        q = q * r_diag_phase[:, None]  # column j of q scaled by phase j
        # Scaling one column by conj(phase(det q)) turns det q into
        # |det q| = 1, mapping U(n) onto SU(n).
        q_det = B.det(q)
        q_det_inv_phase = complex_conj(q_det / B.cast(B.dtype(q_det), B.abs(q_det)))
        first_col = q[:, :, 0] * q_det_inv_phase[:, None]
        q_new = B.concat(first_col[:, :, None], q[:, :, 1:], axis=-1)
        return key, q_new

    @property
    def element_shape(self):
        """
        :return:
            [n, n].
        """
        return [self.n, self.n]

    @property
    def element_dtype(self):
        """
        :return:
            B.Complex.
        """
        return B.Complex