Coverage for tests/spaces/test_lie_groups.py: 98%
59 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1import lab as B
2import numpy as np
3import pytest
5from geometric_kernels.kernels.matern_kernel import default_num
6from geometric_kernels.lab_extras import complex_conj
7from geometric_kernels.spaces import SpecialOrthogonal, SpecialUnitary
9from ..helper import check_function_with_backend, compact_matrix_lie_groups
@pytest.fixture(
    params=compact_matrix_lie_groups(),
    ids=str,
)
def inputs(request):
    """
    Build a tuple (space, eigenfunctions, X, X2) for a compact matrix Lie group:
    - space: the group under test (request.param),
    - eigenfunctions: space.get_eigenfunctions(num_levels) for a modest num_levels,
    - X, X2: two independent random samples of random sizes from the space.
    """
    group = request.param
    eigenfunctions = group.get_eigenfunctions(min(10, default_num(group)))

    rng = np.random.RandomState(0)
    size_x, size_x2 = rng.randint(low=1, high=100 + 1, size=2)
    rng, X = group.random(rng, size_x)
    rng, X2 = group.random(rng, size_x2)

    return group, eigenfunctions, X, X2
def get_dtype(group):
    """
    Return the NumPy dtype of the matrix entries of the given group.

    :param group: a compact matrix Lie group instance, either
        SpecialOrthogonal (real matrices) or SpecialUnitary (complex matrices).
    :return: np.double for SO(n), np.cdouble for SU(n).
    :raises ValueError: if the group is of an unsupported type.
    """
    if isinstance(group, SpecialOrthogonal):
        return np.double
    elif isinstance(group, SpecialUnitary):
        return np.cdouble
    else:
        # Fix: the original raised a bare ValueError() with no message,
        # which made failures hard to diagnose.
        raise ValueError(f"Unsupported group type: {type(group).__name__}")
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_group_inverse(inputs, backend):
    """Check that X @ group.inverse(X) yields identity matrices, batched over X."""
    group, _, X, _ = inputs

    # Expected result: a batch of X.shape[0] identity matrices of size n x n.
    identity_batch = np.broadcast_to(
        np.eye(group.n, dtype=get_dtype(group)),
        (X.shape[0], group.n, group.n),
    )

    check_function_with_backend(
        backend,
        identity_batch,
        lambda X: B.matmul(X, group.inverse(X)),
        X,
    )
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_character_conj_invariant(inputs, backend):
    """Characters are class functions: chi(g x g^{-1}) should equal chi(x)."""
    group, eigenfunctions, X, G = inputs

    # Make the two batches the same size so elements can be paired up.
    batch = min(X.shape[0], G.shape[0])
    X, G = X[:batch, :, :], G[:batch, :, :]

    def character_difference(X, G, chi):
        # Conjugate each x by the paired g and compare character values.
        conjugated = B.matmul(B.matmul(G, X), group.inverse(G))
        return chi(eigenfunctions._torus_representative(X)) - chi(
            eigenfunctions._torus_representative(conjugated)
        )

    for chi in eigenfunctions._characters:
        check_function_with_backend(
            backend,
            np.zeros((batch,)),
            lambda X, G, chi=chi: character_difference(X, G, chi),
            X,
            G,
            atol=1e-3,
        )
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_character_at_identity(inputs, backend):
    """At the identity element, each character equals its representation dimension (a real number)."""
    group, eigenfunctions, _, _ = inputs

    dtype = get_dtype(group)
    # A batch containing a single n x n identity matrix.
    identity = np.eye(group.n, dtype=dtype)[None, ...]

    for chi, dim in zip(eigenfunctions._characters, eigenfunctions._dimensions):
        # Real part of chi(e) equals the dimension of the representation.
        check_function_with_backend(
            backend,
            np.array([dim], dtype=dtype),
            lambda X: B.real(chi(eigenfunctions._torus_representative(X))),
            identity,
        )
        # Imaginary part of chi(e) vanishes.
        check_function_with_backend(
            backend,
            np.array([0], dtype=dtype),
            lambda X: B.imag(chi(eigenfunctions._torus_representative(X))),
            identity,
        )
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_characters_orthogonal(inputs, backend):
    """
    Monte Carlo check that the characters are orthonormal: the empirical
    Gram matrix E[conj(chi_i(x)) * chi_j(x)] over random group samples
    should approximate the identity matrix.
    """
    group, eigenfunctions, _, _ = inputs

    num_samples = 10000
    key = np.random.RandomState(0)
    _, X = group.random(key, num_samples)

    def all_char_vals(X):
        # Stack all character values into one [num_samples, num_levels] matrix.
        gammas = eigenfunctions._torus_representative(X)
        values = [
            chi(gammas)[..., None]  # [num_samples, 1]
            for chi in eigenfunctions._characters
        ]
        return B.concat(*values, axis=-1)

    def empirical_gram(X):
        # Fix: compute the character values once and reuse them for both
        # factors — the original evaluated all_char_vals(X) twice per call.
        vals = all_char_vals(X)
        return complex_conj(B.T(vals)) @ vals / num_samples

    check_function_with_backend(
        backend,
        np.eye(eigenfunctions.num_levels, dtype=get_dtype(group)),
        empirical_gram,
        X,
        atol=0.4,  # very loose, but helps make sure the diagonal is close to 1 while the rest is close to 0
    )