Coverage for tests/spaces/test_lie_groups.py: 98%

59 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import lab as B 

2import numpy as np 

3import pytest 

4 

5from geometric_kernels.kernels.matern_kernel import default_num 

6from geometric_kernels.lab_extras import complex_conj 

7from geometric_kernels.spaces import SpecialOrthogonal, SpecialUnitary 

8 

9from ..helper import check_function_with_backend, compact_matrix_lie_groups 

10 

11 

@pytest.fixture(
    params=compact_matrix_lie_groups(),
    ids=str,
)
def inputs(request):
    """
    Provide a ``(space, eigenfunctions, X, X2)`` tuple for one compact matrix Lie group.

    - ``space`` is the parametrized group, i.e. ``request.param``;
    - ``eigenfunctions`` is ``space.get_eigenfunctions(L)`` with
      ``L = min(10, default_num(space))`` levels;
    - ``X`` and ``X2`` are two random batches drawn from ``space`` one after
      the other with a ``RandomState(0)`` generator; each batch size is drawn
      from ``1..100`` (inclusive).
    """
    group = request.param
    levels = min(10, default_num(group))
    eigfuncs = group.get_eigenfunctions(levels)

    # Fixed seed so the fixture is reproducible across runs.
    rng = np.random.RandomState(0)
    sizes = rng.randint(low=1, high=101, size=2)
    rng, first_batch = group.random(rng, sizes[0])
    rng, second_batch = group.random(rng, sizes[1])

    return group, eigfuncs, first_batch, second_batch

34 

35 

def get_dtype(group):
    """
    Return the NumPy dtype used for matrices of ``group`` in these tests.

    :param group: a ``SpecialOrthogonal`` or ``SpecialUnitary`` space.
    :return: ``np.double`` for ``SpecialOrthogonal``, ``np.cdouble`` for
        ``SpecialUnitary``.
    :raises ValueError: if ``group`` is of any other type.
    """
    if isinstance(group, SpecialOrthogonal):
        return np.double
    if isinstance(group, SpecialUnitary):
        return np.cdouble
    # Name the offending type so a failing parametrization is easy to diagnose.
    raise ValueError(
        f"Unsupported group type {type(group).__name__!r}; expected "
        "SpecialOrthogonal or SpecialUnitary."
    )

43 

44 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_group_inverse(inputs, backend):
    """Check that ``X @ group.inverse(X)`` is the identity for every sample in ``X``."""
    group, _, X, _ = inputs
    n = group.n

    # One identity matrix per sample, in the dtype matching the group.
    identities = np.broadcast_to(
        np.eye(n, dtype=get_dtype(group)),
        (X.shape[0], n, n),
    )

    def product_with_inverse(X):
        return B.matmul(X, group.inverse(X))

    check_function_with_backend(backend, identities, product_with_inverse, X)

58 

59 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_character_conj_invariant(inputs, backend):
    """
    Check that every character is a class function, i.e. that
    ``chi(X) == chi(G X G^{-1})`` for element-wise paired samples ``X``, ``G``.
    """
    group, eigenfunctions, X, G = inputs

    # The two fixture samples generally differ in size; pair them by keeping
    # only the common leading part of each.
    n_xs = min(X.shape[0], G.shape[0])
    X, G = X[:n_xs, :, :], G[:n_xs, :, :]

    def character_gap(x, g, character):
        # character at x minus character at the conjugate g x g^{-1}, both
        # evaluated via torus representatives; expected to vanish.
        conjugate = B.matmul(B.matmul(g, x), group.inverse(g))
        conjugate_torus = eigenfunctions._torus_representative(conjugate)
        x_torus = eigenfunctions._torus_representative(x)
        return character(x_torus) - character(conjugate_torus)

    for character in eigenfunctions._characters:
        # check_function_with_backend is called inside the iteration, so the
        # lambda always sees the current `character`.
        check_function_with_backend(
            backend,
            np.zeros((n_xs,)),
            lambda X, G: character_gap(X, G, character),
            X,
            G,
            atol=1e-3,
        )

86 

87 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_character_at_identity(inputs, backend):
    """
    Check that each character evaluated at the identity has real part equal to
    the corresponding entry of ``eigenfunctions._dimensions`` and imaginary
    part ``0``.
    """
    group, eigenfunctions, _, _ = inputs

    # A batch holding a single identity matrix, built once and reused.
    identity = np.eye(group.n, dtype=get_dtype(group))[None, ...]

    for chi, dim in zip(eigenfunctions._characters, eigenfunctions._dimensions):
        # B.real / B.imag produce real values, so the expectations are real
        # (np.double) for every group, including complex SpecialUnitary.
        check_function_with_backend(
            backend,
            np.array([dim], dtype=np.double),
            lambda X: B.real(chi(eigenfunctions._torus_representative(X))),
            identity,
        )

        check_function_with_backend(
            backend,
            np.array([0], dtype=np.double),
            lambda X: B.imag(chi(eigenfunctions._torus_representative(X))),
            identity,
        )

106 

107 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_characters_orthogonal(inputs, backend):
    """
    Check character orthogonality: the Monte Carlo Gram matrix
    ``mean_x conj(chi_i(x)) chi_j(x)`` over random group samples is close to
    the identity matrix.
    """
    group, eigenfunctions, _, _ = inputs

    num_samples = 10000
    key = np.random.RandomState(0)
    # NOTE(review): the orthogonality relations hold under the Haar measure;
    # this assumes group.random samples from it.
    _, X = group.random(key, num_samples)

    def all_char_vals(X):
        gammas = eigenfunctions._torus_representative(X)
        values = [
            chi(gammas)[..., None]  # [num_samples, 1]
            for chi in eigenfunctions._characters
        ]
        # [num_samples, len(eigenfunctions._characters)]
        return B.concat(*values, axis=-1)

    def gram_matrix(X):
        # Evaluate the characters once and reuse them for both Gram factors
        # (the previous form computed them twice over all samples).
        vals = all_char_vals(X)
        return complex_conj(B.T(vals)) @ vals / num_samples

    check_function_with_backend(
        backend,
        np.eye(eigenfunctions.num_levels, dtype=get_dtype(group)),
        gram_matrix,
        X,
        atol=0.4,  # very loose, but helps make sure the diagonal is close to 1 while the rest is close to 0
    )