Coverage for geometric_kernels/spaces/su.py: 95%

115 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2This module provides the :class:`SpecialUnitary` space and the respective 

3:class:`~.eigenfunctions.Eigenfunctions` subclass :class:`SUEigenfunctions`. 

4""" 

5 

6import itertools 

7import json 

8import math 

9import operator 

10from functools import reduce 

11 

12import lab as B 

13import numpy as np 

14from beartype.typing import List, Tuple 

15 

16from geometric_kernels.lab_extras import ( 

17 complex_conj, 

18 complex_like, 

19 create_complex, 

20 dtype_double, 

21 from_numpy, 

22 qr, 

23) 

24from geometric_kernels.spaces.eigenfunctions import Eigenfunctions 

25from geometric_kernels.spaces.lie_groups import ( 

26 CompactMatrixLieGroup, 

27 LieGroupCharacter, 

28 WeylAdditionTheorem, 

29) 

30from geometric_kernels.utils.utils import chain, get_resource_file_path 

31 

32 

class SUEigenfunctions(WeylAdditionTheorem):
    """
    Eigenfunctions of the Laplace-Beltrami operator on the special unitary
    group SU(n), organized into levels.

    Each level corresponds to a signature: a non-increasing tuple of n
    non-negative integers whose last entry is 0. Each signature indexes an
    irreducible unitary representation of SU(n).

    :param n:
        The order n of the group SU(n).
    :param num_levels:
        The number of levels (signatures) to use.
    :param compute_characters:
        Whether to compute the characters; forwarded to the
        :class:`~.lie_groups.WeylAdditionTheorem` constructor.
    """

    def __init__(self, n, num_levels, compute_characters=True):
        self.n = n
        self.dim = n**2 - 1
        self.rank = n - 1

        # Weyl vector (half-sum of positive roots) of SU(n):
        # ((n-1)/2, (n-3)/2, ..., -(n-1)/2).
        self.rho = np.arange(self.n - 1, -self.n, -2) * 0.5

        super().__init__(n, num_levels, compute_characters)

    def _generate_signatures(self, num_levels: int) -> List[Tuple[int, ...]]:
        """
        Return the `num_levels` signatures with the smallest eigenvalues.

        Candidates are all non-increasing tuples of length n with entries in
        [0, sign_vals_lim] and last entry 0. The cap shrinks as n grows
        (presumably to keep the enumeration tractable). Ties in the eigenvalue
        are broken by comparing the signatures lexicographically.

        :param num_levels:
            The number of signatures to return.
        :return:
            A list of signatures sorted by increasing eigenvalue.
        """
        sign_vals_lim = 100 if self.n in (1, 2) else 30 if self.n == 3 else 10
        signatures = list(
            itertools.combinations_with_replacement(
                range(sign_vals_lim, -1, -1), r=self.rank
            )
        )
        signatures = [sgn + (0,) for sgn in signatures]

        # The entries of rho + (signature - mean) are multiples of 1/(2n).
        # Hence 4 n^2 times the eigenvalue is an integer, and rounding removes
        # float noise so that equal eigenvalues compare exactly when sorting.
        eig_and_signature = [
            (
                round(4 * (len(signature) ** 2) * self._compute_eigenvalue(signature)),
                signature,
            )
            for signature in signatures
        ]

        eig_and_signature.sort()
        signatures = [eig_sgn[1] for eig_sgn in eig_and_signature][:num_levels]
        return signatures

    def _compute_dimension(self, signature: Tuple[int, ...]) -> int:
        """
        Dimension of the irreducible representation with the given signature.

        Uses the Weyl dimension formula

            dim = prod_{1 <= i < j <= n} (s_i - s_j + j - i) / (j - i).

        The formula is evaluated in exact integer arithmetic. The denominators
        multiply to prod_{i=1}^{n-1} (n - i)!, and for a non-increasing
        signature the quotient is always an integer. Float division followed
        by rounding could lose precision for large signatures.

        :param signature:
            A non-increasing tuple of n integers.
        :return:
            The dimension of the representation (1 for n = 1).
        """
        # The initializer 1 keeps the empty product (n = 1) well-defined.
        numerator = reduce(
            operator.mul,
            (
                int(signature[i - 1]) - int(signature[j - 1]) + j - i
                for i in range(1, self.n)
                for j in range(i + 1, self.n + 1)
            ),
            1,
        )
        denominator = reduce(
            operator.mul,
            (math.factorial(self.n - i) for i in range(1, self.n)),
            1,
        )
        return numerator // denominator

    def _compute_eigenvalue(self, signature: Tuple[int, ...]) -> B.Float:
        """
        Laplace-Beltrami eigenvalue associated with the given signature.

        The eigenvalue is ||rho + l||^2 - ||rho||^2, where l is the signature
        shifted to have zero mean.

        :param signature:
            A tuple of n integers.
        :return:
            The eigenvalue, as a float.
        """
        normalized_signature = np.asarray(signature, dtype=np.float64) - np.mean(
            signature
        )
        lb_eigenvalue = (
            np.linalg.norm(self.rho + normalized_signature) ** 2
            - np.linalg.norm(self.rho) ** 2
        )
        return lb_eigenvalue

    def _compute_character(
        self, n: int, signature: Tuple[int, ...]
    ) -> LieGroupCharacter:
        """Return the :class:`SUCharacter` of SU(n) for `signature`."""
        return SUCharacter(n, signature)

    def _torus_representative(self, X):
        """
        Return the eigenvalues of X, without eigenvectors.

        The `False` argument presumably disables eigenvector computation in
        `B.eig`; TODO confirm against lab's `B.eig` signature.
        """
        return B.eig(X, False)

    def inverse(self, X: B.Numeric) -> B.Numeric:
        """Invert group elements; delegates to :meth:`SpecialUnitary.inverse`."""
        return SpecialUnitary.inverse(X)

101 

102 

class SUCharacter(LieGroupCharacter):
    """
    The class that represents a character of the SU(n) group.

    Many of the characters on SU(n) are complex-valued.

    These are polynomials whose coefficients are precomputed and stored in a
    file. By default, there are 20 precomputed characters for n from 2 to 6.
    If you want more, use the `compute_characters.py` script.

    :param n:
        The order n of the SU(n) group.
    :param signature:
        The signature that determines a particular character (and an
        irreducible unitary representation along with it).

    :raises KeyError:
        If no precomputed character exists for this `n` and `signature`.
    """

    def __init__(self, n: int, signature: Tuple[int, ...]):
        self.signature = signature
        self.n = n
        self.coeffs, self.monoms = self._load()

    def _load(self):
        """
        Load this character's polynomial from ``precomputed_characters.json``.

        Entries are looked up under the group name ``"SU(n)"`` and then under
        ``str(self.signature)``.

        :return:
            A pair ``(coeffs, monoms)`` of numpy arrays holding the polynomial
            coefficients and the corresponding monomial exponents.
        :raises KeyError:
            If either the group or the signature is missing from the file.
        """
        group_name = "SU({})".format(self.n)
        with get_resource_file_path("precomputed_characters.json") as file_path:
            with file_path.open("r") as file:
                character_formulas = json.load(file)
        try:
            cs, ms = character_formulas[group_name][str(self.signature)]
        except KeyError:
            # Report the requested signature explicitly. The caught exception's
            # argument would be the group name when the whole group is missing.
            raise KeyError(
                "Unable to retrieve character parameters for signature {} of {}, "
                "perhaps it is not precomputed. "
                "Run compute_characters.py with changed parameters.".format(
                    self.signature, group_name
                )
            ) from None
        return np.array(cs), np.array(ms)

    def __call__(self, gammas: B.Numeric) -> B.Numeric:
        """
        Evaluate the character polynomial.

        The result is sum_k coeffs[k] * prod_i gammas[..., i] ** monoms[k][i].

        :param gammas:
            Array whose last axis holds the variables of the polynomial.
            These are presumably the eigenvalues of the group elements; TODO
            confirm against the caller.
        :return:
            Array of shape ``gammas.shape[:-1]`` with the character values.
        """
        # The target dtype of the exponents does not change inside the loop.
        complex_dtype = complex_like(gammas)
        char_val = B.zeros(B.dtype(gammas), *gammas.shape[:-1])
        for coeff, monom in zip(self.coeffs, self.monoms):
            exponents = B.cast(complex_dtype, from_numpy(gammas, monom))
            char_val += coeff * B.prod(B.power(gammas, exponents), axis=-1)
        return char_val

153 

154 

class SpecialUnitary(CompactMatrixLieGroup):
    r"""
    The GeometricKernels space representing the special unitary group SU(n)
    consisting of n by n complex unitary matrices with unit determinant.

    The elements of this space are represented as n x n unitary
    matrices with complex entries and unit determinant.

    .. note::
        A tutorial on how to use this space is available in the
        :doc:`SpecialUnitary.ipynb </examples/SpecialUnitary>` notebook.

    :param n:
        The order n of the group SU(n).

    .. note::
        We only support n >= 2. Mathematically, SU(1) is trivial, consisting
        of a single element (the identity), chances are you do not need it.
        For large values of n, you might need to run the `compute_characters.py`
        script to precompute the necessary mathematical quantities beyond the
        ones provided by default.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please consider
        citing :cite:t:`azangulov2024a`.
    """

    def __init__(self, n: int):
        if n < 2:
            raise ValueError(f"Only n >= 2 is supported. n = {n} was provided.")
        self.n = n
        self.dim = n**2 - 1
        self.rank = n - 1
        super().__init__()

    def __str__(self):
        return f"SpecialUnitary({self.n})"

    @property
    def dimension(self) -> int:
        """
        The dimension of the space, as that of a Riemannian manifold.

        :return:
            n^2-1 where n is the order of the group SU(n).
        """
        return self.dim

    @staticmethod
    def inverse(X: B.Numeric) -> B.Numeric:
        """
        Invert (batches of) unitary matrices.

        For a unitary matrix, the inverse equals the conjugate transpose.
        `B.transpose` only transposes the last two dimensions, so leading
        batch dimensions are preserved.
        """
        return complex_conj(B.transpose(X))

    def get_eigenfunctions(self, num: int) -> Eigenfunctions:
        """
        Returns the :class:`~.SUEigenfunctions` object with `num` levels
        and order n.

        :param num:
            Number of levels.
        """
        return SUEigenfunctions(self.n, num)

    def get_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the first `num` levels, one per level.

        :param num:
            Number of levels.
        :return:
            Array of shape [num, 1].
        """
        eigenfunctions = SUEigenfunctions(self.n, num)
        eigenvalues = np.array(list(eigenfunctions._eigenvalues))
        return B.reshape(eigenvalues, -1, 1)  # [num, 1]

    def get_repeated_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the first `num` levels, repeated by multiplicity.

        Each level's eigenvalue is paired with the count dim^2, where dim is
        the dimension of the level's representation. The pairs are passed to
        `chain`, which presumably repeats each eigenvalue that many times.

        :param num:
            Number of levels.
        :return:
            Array of shape [M, 1].
        """
        eigenfunctions = SUEigenfunctions(self.n, num)
        eigenvalues = chain(
            eigenfunctions._eigenvalues,
            [rep_dim**2 for rep_dim in eigenfunctions._dimensions],
        )
        return B.reshape(eigenvalues, -1, 1)  # [M, 1]

    def random(self, key: B.RandomState, number: int):
        """
        Sample `number` Haar-distributed random elements of SU(n).

        For n = 2, a uniform point on the unit 3-sphere S^3 in R^4 = C^2 is
        mapped to [[a, b], [-conj(b), conj(a)]].

        For n > 2, a complex Ginibre matrix is QR-decomposed and the phases of
        diag(R) are absorbed into Q, which makes Q Haar on U(n) (Mezzadri, 2007).
        The first column is then rotated so that the determinant becomes 1.

        :param key:
            Random state of the backend.
        :param number:
            Number of samples.
        :return:
            A pair (key, samples), where samples has shape [number, n, n].
        """
        if self.n == 2:
            # SU(2) = {[[a, b], [-conj(b), conj(a)]] : |a|^2 + |b|^2 = 1}, which
            # is diffeomorphic to the unit 3-sphere S^3.
            key, sphere_point = B.random.randn(key, dtype_double(key), number, 4)
            sphere_point /= B.reshape(
                B.sqrt(B.einsum("ij,ij->i", sphere_point, sphere_point)), -1, 1
            )
            a = create_complex(sphere_point[..., 0], sphere_point[..., 1])
            b = create_complex(sphere_point[..., 2], sphere_point[..., 3])
            r1 = B.stack(a, -complex_conj(b), axis=-1)
            r2 = B.stack(b, complex_conj(a), axis=-1)
            # Stacking along the last axis makes r1 and r2 the columns of q.
            q = B.stack(r1, r2, axis=-1)
            return key, q

        key, real = B.random.randn(key, dtype_double(key), number, self.n, self.n)
        key, imag = B.random.randn(key, dtype_double(key), number, self.n, self.n)
        h = create_complex(real, imag) / B.sqrt(2)
        q, r = qr(h, mode="complete")
        r_diag = B.einsum("...ii->...i", r)
        # Multiply column j of Q by the phase d_j = R_jj / |R_jj|. The result is
        # the Q factor of the unique QR decomposition with positive diag(R),
        # whatever sign/phase convention the backend's QR uses.
        r_diag_phase = r_diag / B.cast(B.dtype(r_diag), B.abs(r_diag))
        q = q * r_diag_phase[:, None]
        q_det = B.det(q)
        # Rotate the first column by the conjugate phase of det(q). The
        # determinant then becomes |det(q)| = 1.
        q_det_inv_phase = complex_conj(q_det / B.cast(B.dtype(q_det), B.abs(q_det)))
        q_new = q[:, :, 0] * q_det_inv_phase[:, None]
        q_new = B.concat(q_new[:, :, None], q[:, :, 1:], axis=-1)
        return key, q_new

    @property
    def element_shape(self):
        """
        :return:
            [n, n].
        """
        return [self.n, self.n]

    @property
    def element_dtype(self):
        """
        :return:
            B.Complex.
        """
        return B.Complex