"""
This module provides the :class:`SpecialOrthogonal` space, the respective
:class:`~.eigenfunctions.Eigenfunctions` subclass :class:`SOEigenfunctions`,
and a class :class:`SOCharacter` for representing characters of the group.
"""
import itertools
import json
import math
import operator
from functools import reduce
import lab as B
import numpy as np
from beartype.typing import List, Tuple
from geometric_kernels.lab_extras import (
complex_conj,
complex_like,
create_complex,
dtype_double,
from_numpy,
qr,
take_along_axis,
)
from geometric_kernels.spaces.eigenfunctions import Eigenfunctions
from geometric_kernels.spaces.lie_groups import (
CompactMatrixLieGroup,
LieGroupCharacter,
WeylAdditionTheorem,
)
from geometric_kernels.utils.utils import (
chain,
fixed_length_partitions,
get_resource_file_path,
)
class SOEigenfunctions(WeylAdditionTheorem):
    """
    Group-specific pieces that :class:`WeylAdditionTheorem` needs for the
    special orthogonal group SO(n): signatures, representation dimensions,
    eigenvalues, characters and torus representatives.

    :param n:
        The order n of the group SO(n).
    :param num_levels:
        Number of levels (signatures) to keep.
    :param compute_characters:
        Forwarded unchanged to :class:`WeylAdditionTheorem`.
    """

    def __init__(self, n, num_levels, compute_characters=True):
        self.n = n
        self.dim = n * (n - 1) // 2
        self.rank = n // 2
        # rho is (rank-1, ..., 1, 0) for even n and the same shifted by 1/2
        # for odd n; it is used by `_compute_eigenvalue`.
        if self.n % 2 == 0:
            self.rho = np.arange(self.rank - 1, -1, -1)
        else:
            self.rho = np.arange(self.rank - 1, -1, -1) + 0.5
        # The attributes above are set first because they are read by the
        # methods defined in this class.
        super().__init__(n, num_levels, compute_characters)

    def _generate_signatures(self, num_levels: int) -> List[Tuple[int, ...]]:
        """
        Enumerate candidate signatures and return the `num_levels` of them
        with the smallest eigenvalues.

        :param num_levels:
            Maximal number of signatures to return.

        :return:
            Signatures (tuples of length `self.rank`) sorted by
            `round(4 * eigenvalue)`, ties broken by the tuple ordering of the
            signatures themselves. Fewer than `num_levels` entries are
            returned if the brute-force search yields fewer candidates.
        """
        signatures = []

        # Smallest eigenvalues correspond to partitions of the smallest
        # integers. If p >> k, the number of partitions of p into k parts
        # is O(p^k).
        if self.n == 3:
            # In this case rank=1, so all partitions are trivial. 200 is more
            # than enough, since the contribution of such a term even in the
            # Matern(0.5) case is less than 200^{-2}.
            max_signature_sum = 200
        else:
            # Roughly speaking, this is a brute-force search through the
            # 50^{self.rank} smallest eigenvalues.
            max_signature_sum = 50

        for signature_sum in range(0, max_signature_sum):
            for i in range(0, self.rank + 1):
                for signature in fixed_length_partitions(signature_sum, i):
                    # Pad with zeros up to length `rank`.
                    signature.extend([0] * (self.rank - i))
                    signatures.append(tuple(signature))
                    # For even n, a non-zero last entry also contributes the
                    # signature with that entry negated.
                    if self.n % 2 == 0 and signature[-1] != 0:
                        signature[-1] = -signature[-1]
                        signatures.append(tuple(signature))

        # Sorting on the integer round(4 * eigenvalue) rather than on the raw
        # float keeps the order independent of floating-point noise.
        eig_and_signature = sorted(
            (round(4 * self._compute_eigenvalue(signature)), signature)
            for signature in signatures
        )
        return [signature for _, signature in eig_and_signature][:num_levels]

    def _compute_dimension(self, signature: Tuple[int, ...]) -> int:
        """
        Compute the dimension of the representation with the given signature.

        :param signature:
            Signature of length `self.rank`.

        :return:
            The dimension, as an integer.
        """
        if self.n % 2 == 1:
            qs = [pk + self.rank - k - 1 / 2 for k, pk in enumerate(signature)]
            prefactor = reduce(
                operator.mul,
                (2 * qs[k] / math.factorial(2 * k + 1) for k in range(0, self.rank)),
                1,
            )
        else:
            qs = [
                pk + self.rank - k - 1 if k != self.rank - 1 else abs(pk)
                for k, pk in enumerate(signature)
            ]
            prefactor = reduce(
                operator.mul,
                (2 / math.factorial(2 * k) for k in range(1, self.rank)),
                1,
            )

        pair_product = reduce(
            operator.mul,
            (
                (qs[i] - qs[j]) * (qs[i] + qs[j])
                for i, j in itertools.combinations(range(self.rank), 2)
            ),
            1,
        )

        # The product is computed in floating point, so it must be rounded to
        # the nearest integer; truncating with int() first would turn a value
        # such as 27.999999 into 27.
        return int(round(prefactor * pair_product))

    def _compute_eigenvalue(self, signature: Tuple[int, ...]) -> B.Float:
        """
        Compute the eigenvalue associated with a signature as
        ``||rho + signature||^2 - ||rho||^2``.

        :param signature:
            Signature of length `self.rank`.

        :return:
            The eigenvalue.
        """
        np_sgn = np.array(signature)
        rho = self.rho
        eigenvalue = np.linalg.norm(rho + np_sgn) ** 2 - np.linalg.norm(rho) ** 2
        return eigenvalue

    def _compute_character(
        self, n: int, signature: Tuple[int, ...]
    ) -> LieGroupCharacter:
        """
        Build the :class:`SOCharacter` of SO(n) for the given signature.
        """
        return SOCharacter(n, signature)

    def _torus_representative(self, X: B.Numeric) -> B.Numeric:
        """
        Compute, for each matrix in `X`, complex values `gamma` derived from
        its eigenvalues, followed by their complex conjugates along the last
        axis.

        :param X:
            Batch of matrices; presumably of shape [..., n, n] with entries
            in SO(n) -- TODO confirm against callers.

        :return:
            Complex array whose last axis holds `gamma` concatenated with
            `complex_conj(gamma)`.
        """
        gamma = None

        if self.n == 3:
            # In SO(3) the torus representative is determined by the
            # non-trivial pair of eigenvalues, which can be calculated from
            # the trace.
            trace = B.expand_dims(B.trace(X), axis=-1)
            real = (trace - 1) / 2
            zeros = real * 0
            # Clamp at zero so that rounding errors cannot produce the square
            # root of a negative number.
            imag = B.sqrt(B.maximum(1 - real * real, zeros))
            gamma = create_complex(real, imag)
        elif self.n % 2 == 1:
            # In SO(2n+1) the torus representative is determined by the
            # (unordered) non-trivial eigenvalues.
            eigvals = B.eig(X, False)
            sorted_ind = B.argsort(B.real(eigvals), axis=-1)
            eigvals = take_along_axis(eigvals, sorted_ind, -1)
            # Take every second eigenvalue in the sorted order, leaving out
            # the last one (the one with the largest real part).
            gamma = eigvals[..., 0:-1:2]
        else:
            # In SO(2n) each unordered set of eigenvalues determines two
            # conjugacy classes.
            eigvals, eigvecs = B.eig(X)
            sorted_ind = B.argsort(B.real(eigvals), axis=-1)
            eigvals = take_along_axis(eigvals, sorted_ind, -1)
            eigvecs = take_along_axis(
                eigvecs,
                B.broadcast_to(B.expand_dims(sorted_ind, axis=-2), *eigvecs.shape),
                -1,
            )
            # c is a matrix transforming x into its canonical form (with 2x2
            # blocks).
            c = B.reshape(
                B.stack(
                    eigvecs[..., ::2] + eigvecs[..., 1::2],
                    eigvecs[..., ::2] - eigvecs[..., 1::2],
                    axis=-1,
                ),
                *eigvecs.shape[:-1],
                -1,
            )
            # Eigenvectors calculated by LAPACK are either real or purely
            # imaginary, make everything real.
            # WARNING: might depend on the implementation of the
            # eigendecomposition!
            c = B.real(c) + B.imag(c)
            # Normalize s.t. det(c)≈±1, probably unnecessary.
            c /= math.sqrt(2)
            # Raise the first eigenvalue to the power sign(det(c)) = ±1,
            # i.e. invert it when det(c) is negative.
            eigvals = B.concat(
                B.expand_dims(
                    B.power(eigvals[..., 0], B.cast(complex_like(c), B.sign(B.det(c)))),
                    axis=-1,
                ),
                eigvals[..., 1:],
                axis=-1,
            )
            gamma = eigvals[..., ::2]

        gamma = B.concat(gamma, complex_conj(gamma), axis=-1)
        return gamma

    def inverse(self, X: B.Numeric) -> B.Numeric:
        """
        Invert the group elements in `X` by delegating to
        :meth:`SpecialOrthogonal.inverse`.
        """
        return SpecialOrthogonal.inverse(X)
class SOCharacter(LieGroupCharacter):
    """
    The class that represents a character of the SO(n) group.

    These characters are always real-valued.

    These are polynomials whose coefficients are precomputed and stored in a
    file. By default, there are 20 precomputed characters for n from 3 to 8.
    If you want more, use the `compute_characters.py` script.

    :param n:
        The order n of the SO(n) group.
    :param signature:
        The signature that determines a particular character (and an
        irreducible unitary representation along with it).
    """

    def __init__(self, n: int, signature: Tuple[int, ...]):
        self.signature = signature
        self.n = n
        self.coeffs, self.monoms = self._load()

    def _load(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Read the polynomial describing this character from the
        `precomputed_characters.json` resource file.

        :return:
            A pair `(coeffs, monoms)` of numpy arrays built from the entry
            stored under the key `"SO(n)"` and then `str(self.signature)`.

        :raises KeyError:
            If the resource file has no entry for this group or signature.
        """
        group_name = "SO({})".format(self.n)

        with get_resource_file_path("precomputed_characters.json") as file_path:
            with file_path.open("r") as file:
                character_formulas = json.load(file)

        try:
            cs, ms = character_formulas[group_name][str(self.signature)]
        except KeyError:
            # Report the requested signature explicitly: the failing key may
            # be the group name rather than the signature.
            raise KeyError(
                "Unable to retrieve character parameters for signature {} of {}, "
                "perhaps it is not precomputed. "
                "Run compute_characters.py with changed parameters.".format(
                    self.signature, group_name
                )
            ) from None

        return np.array(cs), np.array(ms)

    def __call__(self, gammas: B.Numeric) -> B.Numeric:
        """
        Evaluate the character polynomial at `gammas`.

        Each pair `(coeff, monom)` contributes
        `coeff * prod(gammas ** monom, axis=-1)` to the result.

        :param gammas:
            Array whose last axis is matched against the exponents in each
            monomial; presumably the output of a torus representative
            computation -- TODO confirm against callers.

        :return:
            Array of shape `gammas.shape[:-1]` with the dtype of `gammas`.
        """
        char_val = B.zeros(B.dtype(gammas), *gammas.shape[:-1])
        for coeff, monom in zip(self.coeffs, self.monoms):
            # Convert the exponents to the backend and dtype of `gammas`
            # before exponentiating.
            char_val += coeff * B.prod(
                gammas ** B.cast(B.dtype(gammas), from_numpy(gammas, monom)), axis=-1
            )
        return char_val
class SpecialOrthogonal(CompactMatrixLieGroup):
    r"""
    The GeometricKernels space representing the special orthogonal group SO(n)
    consisting of n by n orthogonal matrices with unit determinant.

    The elements of this space are represented as n x n orthogonal
    matrices with real entries and unit determinant.

    .. note::
        A tutorial on how to use this space is available in the
        :doc:`SpecialOrthogonal.ipynb </examples/SpecialOrthogonal>` notebook.

    :param n:
        The order n of the group SO(n).

    .. note::
        We only support n >= 3. Mathematically, SO(2) is equivalent to the
        unit circle, which is available as the :class:`~.spaces.Circle` space.

        For larger values of n, you might need to run the
        `utils/compute_characters.py` script to precompute the necessary
        mathematical quantities beyond the ones provided by default. Same
        can be required for larger numbers of levels.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please consider
        citing :cite:t:`azangulov2024a`.
    """

    def __init__(self, n: int):
        if n < 3:
            raise ValueError("Only n >= 3 is supported. For n = 2, use Circle.")

        self.n = n
        self.dim = n * (n - 1) // 2
        self.rank = n // 2
        super().__init__()

    def __str__(self):
        return f"SpecialOrthogonal({self.n})"

    @property
    def dimension(self) -> int:
        """
        The dimension of the space, as that of a Riemannian manifold.

        :return:
            n(n-1)/2 where n is the order of the group SO(n).
        """
        return self.dim

    @staticmethod
    def inverse(X: B.Numeric) -> B.Numeric:
        """
        Invert the matrices in `X` by transposing them.

        :param X:
            Batch of matrices.

        :return:
            `X` with its last two dimensions swapped.
        """
        # B.transpose only swaps the last two dims, so batches are handled.
        return B.transpose(X)

    def get_eigenfunctions(self, num: int) -> Eigenfunctions:
        """
        Returns the :class:`~.SOEigenfunctions` object with `num` levels
        and order n.

        :param num:
            Number of levels.
        """
        return SOEigenfunctions(self.n, num)

    def get_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the first `num` levels, one per level.

        :param num:
            Number of levels.

        :return:
            Array of shape [num, 1].
        """
        levels = SOEigenfunctions(self.n, num)
        per_level = np.array(list(levels._eigenvalues))
        return B.reshape(per_level, -1, 1)  # [num, 1]

    def get_repeated_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the first `num` levels, combined via `chain` with the
        squared representation dimension of each level.

        :param num:
            Number of levels.

        :return:
            Array of shape [J, 1].
        """
        levels = SOEigenfunctions(self.n, num)
        squared_dims = [rep_dim**2 for rep_dim in levels._dimensions]
        repeated = chain(levels._eigenvalues, squared_dims)
        return B.reshape(repeated, -1, 1)  # [J, 1]

    def random(self, key: B.RandomState, number: int):
        """
        Sample `number` random elements of the group.

        :param key:
            Random state, threaded through the backend's random functions.
        :param number:
            Number of samples to draw.

        :return:
            A pair `(key, samples)` with the updated random state and the
            sampled matrices stacked along the first axis.
        """
        if self.n == 2:  # for the bright future where we support SO(2).
            # SO(2) = S^1
            # NOTE(review): the angles are drawn with `randn` -- confirm that
            # this is the intended distribution before enabling n = 2.
            key, angles = B.random.randn(key, dtype_double(key), number, 1)
            angles = 2 * math.pi * angles
            cos_a = B.cos(angles)
            sin_a = B.sin(angles)
            top = B.stack(cos_a, sin_a, axis=-1)
            bottom = B.stack(-sin_a, cos_a, axis=-1)
            rotations = B.concat(top, bottom, axis=-2)
            return key, rotations

        if self.n == 3:
            # Explicit parametrization via the double cover SU(2) = S^3:
            # normalize a Gaussian 4-vector and use it as a unit quaternion.
            key, quat = B.random.randn(key, dtype_double(key), number, 4)
            quat /= B.reshape(B.sqrt(B.einsum("ij,ij->i", quat, quat)), -1, 1)

            x, y, z, w = (B.reshape(quat[..., i], -1, 1) for i in range(4))
            xx, yy, zz = x**2, y**2, z**2
            xy, xz, xw, yz, yw, zw = x * y, x * z, x * w, y * z, y * w, z * w

            part_1 = B.stack(1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw), axis=1)
            part_2 = B.stack(2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw), axis=1)
            part_3 = B.stack(2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy), axis=1)
            rotations = B.concat(part_1, part_2, part_3, axis=-1)
            return key, rotations

        # General case: QR-decompose Gaussian matrices using the `qr` helper
        # imported from lab_extras.
        key, gaussian = B.random.randn(key, dtype_double(key), number, self.n, self.n)
        q, r = qr(gaussian, mode="complete")

        # Scale each column of q by the sign of the matching diagonal entry
        # of r.
        diag_signs = B.sign(B.einsum("...ii->...i", r))
        q *= diag_signs[:, None]

        # Multiply the first column by sign(det(q)) so that the determinant
        # becomes positive.
        det_signs = B.sign(B.det(q))
        first_column = q[:, :, 0] * det_signs[:, None]
        rotations = B.concat(first_column[:, :, None], q[:, :, 1:], axis=-1)
        return key, rotations

    @property
    def element_shape(self):
        """
        :return:
            [n, n].
        """
        return [self.n, self.n]

    @property
    def element_dtype(self):
        """
        :return:
            B.Float.
        """
        return B.Float