Coverage for tests/spaces/test_eigenfunctions.py: 100%

80 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import jax 

2import lab as B 

3import numpy as np 

4import pytest 

5from opt_einsum import contract as einsum 

6 

7from geometric_kernels.kernels.matern_kernel import default_num 

8from geometric_kernels.spaces import CompactMatrixLieGroup, Hypersphere, Mesh 

9from geometric_kernels.utils.utils import chain 

10 

11from ..helper import check_function_with_backend, discrete_spectrum_spaces 

12 

# JAX computes in float32 by default; enable float64 so that JAX computations in
# these tests run at the same precision as the NumPy float64 reference values
# they are compared against.
jax.config.update("jax_enable_x64", True)  # enable float64 in JAX

14 

15 

@pytest.fixture(
    params=discrete_spectrum_spaces(),
    ids=str,
)
def inputs(request):
    """
    Returns a tuple (space, eigenfunctions, X, X2, weights) where:
    - space = request.param,
    - eigenfunctions = space.get_eigenfunctions(num_levels), with reasonable num_levels
    - X is a random sample of random size from the space,
    - X2 is another random sample of random size from the space,
    - weights is an array of positive numbers of shape (eigenfunctions.num_levels, 1).

    The sample sizes, the samples' random key and the weights all come from
    fixed-seed ``np.random.RandomState`` generators, so the fixture does not
    depend on the global NumPy random state or on test execution order.
    """
    space = request.param
    num_levels = default_num(space)
    if isinstance(space, Hypersphere):
        # For Hypersphere, the maximal number of levels with eigenfunction
        # evaluation support is stored in the num_computed_levels field. We
        # do not use more levels to be able to test eigenfunction evaluation.
        num_levels = min(
            10, num_levels, space.get_eigenfunctions(num_levels).num_computed_levels
        )
    elif isinstance(space, Mesh):
        # We limit the number of levels to 50 to avoid excessive computation
        # and increased numerical error for higher order eigenfunctions.
        num_levels = min(50, num_levels)
    eigenfunctions = space.get_eigenfunctions(num_levels)

    key = np.random.RandomState(0)
    N, N2 = key.randint(low=1, high=100 + 1, size=2)
    key, X = space.random(key, N)
    key, X2 = space.random(key, N2)

    # These weights are used for testing the weighted outerproduct, they
    # should be positive. Draw them from their own seeded generator instead
    # of the global `np.random` state so that every run (and every test
    # ordering) sees the same weights.
    weights_key = np.random.RandomState(1)
    weights = weights_key.rand(eigenfunctions.num_levels, 1) ** 2 + 1e-5

    return space, eigenfunctions, X, X2, weights

54 

55 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_call_eigenfunctions(inputs, backend):
    """
    Evaluating the eigenfunctions on a batch of points must succeed in every
    backend and yield an array of shape [N, num_eigenfunctions].
    """
    space, eigenfunctions, points, _, _ = inputs

    if isinstance(space, CompactMatrixLieGroup):
        pytest.skip(
            "CompactMatrixLieGroup subclasses do not currently support eigenfunction evaluation"
        )

    expected_shape = (points.shape[0], eigenfunctions.num_eigenfunctions)

    def has_expected_shape(shape, output):
        # Receives the reference value passed below and the function output;
        # only the output's shape is checked, not its values.
        return output.shape == shape

    check_function_with_backend(
        backend,
        expected_shape,
        eigenfunctions,
        points,
        compare_to_result=has_expected_shape,
    )

73 

74 

def test_numbers_of_eigenfunctions(inputs):
    """
    Sanity-check the per-level eigenspace dimensions reported by the
    eigenfunctions object.
    """
    _, eigenfunctions, _, _, _ = inputs
    per_level = eigenfunctions.num_eigenfunctions_per_level

    # There is exactly one entry per level.
    assert len(per_level) == eigenfunctions.num_levels
    # The first eigenspace is 1-dimensional.
    assert per_level[0] == 1

    # No eigenspace is empty.
    for dimension in per_level:
        assert dimension > 0

    # The per-level counts add up to the total number of eigenfunctions.
    assert sum(per_level) == eigenfunctions.num_eigenfunctions

91 

92 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_orthonormality(inputs, backend):
    """
    Check that the eigenfunctions are (approximately) orthonormal.

    Their pairwise inner products are estimated by averaging over 10000 points
    drawn with ``space.random``, and the result is compared with the identity.
    """
    space, eigenfunctions, _, _, _ = inputs

    if isinstance(space, CompactMatrixLieGroup):
        pytest.skip(
            "CompactMatrixLieGroup subclasses do not currently support eigenfunction evaluation"
        )

    key = np.random.RandomState(0)
    key, xs = space.random(key, 10000)

    def gram_matrix(points):
        # Evaluate the eigenfunctions once and reuse the result for both
        # factors; evaluating them on 10000 points is the costly part here.
        phi = eigenfunctions(points)  # [N, num_eigenfunctions]
        return B.matmul(B.T(phi), phi) / points.shape[0]

    # Check that the eigenfunctions are orthonormal by comparing a Monte Carlo
    # approximation of the inner product with the identity matrix.
    check_function_with_backend(
        backend,
        np.eye(eigenfunctions.num_eigenfunctions),
        gram_matrix,
        xs,
        atol=0.4,  # very loose, but helps make sure the diagonal is close to 1 while the rest is close to 0
    )

114 

115 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_with_addition_theorem(inputs, backend):
    """
    Compare ``weighted_outerproduct`` (based on the addition theorem) with a
    direct computation from the individual eigenfunctions:
    ``sum_i w_i * phi_i(X) * phi_i(X2)``, where ``w_i`` is the weight of the
    level that eigenfunction ``i`` belongs to.
    """
    space, eigenfunctions, X, X2, weights = inputs

    if isinstance(space, CompactMatrixLieGroup):
        pytest.skip(
            "CompactMatrixLieGroup subclasses do not currently support eigenfunction evaluation"
        )

    # Expand the per-level weights to one weight per eigenfunction. Use
    # `weights[:, 0]` rather than `weights.squeeze()`: squeezing a (1, 1)
    # array (a single level) would give a 0-d array instead of shape (1,).
    chained_weights = chain(
        weights[:, 0], eigenfunctions.num_eigenfunctions_per_level
    )

    Phi_X = eigenfunctions(X)
    Phi_X2 = eigenfunctions(X2)
    result = einsum("ni,ki,i->nk", Phi_X, Phi_X2, chained_weights)

    # Check that `weighted_outerproduct`, which is based on the addition theorem,
    # returns the same result as the direct computation involving individual
    # eigenfunctions.
    check_function_with_backend(
        backend, result, eigenfunctions.weighted_outerproduct, weights, X, X2, atol=1e-2
    )

139 

140 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_with_addition_theorem_one_input(inputs, backend):
    """
    Calling ``weighted_outerproduct(weights, X)`` without ``X2`` must give the
    same matrix as passing ``X`` explicitly for both inputs.
    """
    _, eigenfunctions, X, _, weights = inputs

    # Reference value: the explicit two-input call with X2 = X.
    expected = eigenfunctions.weighted_outerproduct(weights, X, X)

    check_function_with_backend(
        backend, expected, eigenfunctions.weighted_outerproduct, weights, X, atol=1e-2
    )

157 

158 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_diag(inputs, backend):
    """
    ``weighted_outerproduct_diag`` must agree with the diagonal of the full
    ``weighted_outerproduct`` matrix computed on ``(X, X)``.
    """
    _, eigenfunctions, X, _, weights = inputs

    full_matrix = eigenfunctions.weighted_outerproduct(weights, X, X)
    expected_diag = np.diag(full_matrix)

    check_function_with_backend(
        backend, expected_diag, eigenfunctions.weighted_outerproduct_diag, weights, X
    )

174 

175 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_against_phi_product(inputs, backend):
    """
    ``weighted_outerproduct`` must equal the weighted sum over levels of the
    per-level products returned by ``phi_product``.
    """
    _, eigenfunctions, X, X2, weights = inputs

    per_level_products = eigenfunctions.phi_product(X, X2)
    # Contract the trailing (level) axis of `phi_product` with the weights;
    # the size-1 second axis of `weights` is summed away as well.
    expected = np.einsum("id,...nki->...nk", weights, per_level_products)

    check_function_with_backend(
        backend, expected, eigenfunctions.weighted_outerproduct, weights, X, X2
    )

188 

189 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_diag_against_phi_product(inputs, backend):
    """
    ``weighted_outerproduct_diag`` must equal the weighted sum over levels of
    the per-level diagonal products returned by ``phi_product_diag``.
    """
    _, eigenfunctions, X, _, weights = inputs

    per_level_diag = eigenfunctions.phi_product_diag(X)
    expected = np.einsum("id,ni->n", weights, per_level_diag)  # [N,]

    check_function_with_backend(
        backend, expected, eigenfunctions.weighted_outerproduct_diag, weights, X
    )