Coverage for tests / spaces / test_eigenfunctions.py: 100%

83 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-15 13:41 +0000

1import jax 

2import lab as B 

3import numpy as np 

4import pytest 

5from opt_einsum import contract as einsum 

6 

7from geometric_kernels.kernels.matern_kernel import default_num 

8from geometric_kernels.spaces import ( 

9 CompactMatrixLieGroup, 

10 HammingGraph, 

11 Hypersphere, 

12 Mesh, 

13) 

14from geometric_kernels.utils.utils import chain 

15 

16from ..helper import check_function_with_backend, discrete_spectrum_spaces 

17 

# Switch JAX to double precision (float64) process-wide for every test in this
# module.
jax.config.update("jax_enable_x64", True)

19 

20 

@pytest.fixture(
    params=discrete_spectrum_spaces(),
    ids=str,
)
def inputs(request):
    """
    Returns a tuple (space, eigenfunctions, X, X2, weights) where:
    - space = request.param,
    - eigenfunctions = space.get_eigenfunctions(num_levels), with reasonable num_levels
    - X is a random sample of random size from the space,
    - X2 is another random sample of random size from the space,
    - weights is an array of positive numbers of shape (eigenfunctions.num_levels, 1).

    The sample sizes, the samples and the weights are all drawn from seeded
    NumPy generators, so the fixture does not depend on global RNG state.
    """
    space = request.param
    num_levels = default_num(space)
    if isinstance(space, Hypersphere):
        # For Hypersphere, the maximal number of levels with eigenfunction
        # evaluation support is stored in the num_computed_levels field. We
        # do not use more levels to be able to test eigenfunction evaluation.
        num_levels = min(
            10, num_levels, space.get_eigenfunctions(num_levels).num_computed_levels
        )
    elif isinstance(space, Mesh):
        # We limit the number of levels to 50 to avoid excessive computation
        # and increased numerical error for higher order eigenfunctions.
        num_levels = min(50, num_levels)
    eigenfunctions = space.get_eigenfunctions(num_levels)

    key = np.random.RandomState(0)
    N, N2 = key.randint(low=1, high=100 + 1, size=2)
    key, X = space.random(key, N)
    key, X2 = space.random(key, N2)

    # These weights are used for testing the weighted outerproduct, they
    # should be positive. They come from their own seeded generator (not the
    # global `np.random` state) so that results are reproducible and do not
    # depend on which tests ran before.
    weights_key = np.random.RandomState(1)
    weights = weights_key.rand(eigenfunctions.num_levels, 1) ** 2 + 1e-5

    return space, eigenfunctions, X, X2, weights

59 

60 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_call_eigenfunctions(inputs, backend):
    """
    Evaluating the eigenfunctions on X must succeed for every backend and
    produce an output of shape (number of points, num_eigenfunctions).
    """
    space, eigenfunctions, points, _, _ = inputs

    unsupported_spaces = (CompactMatrixLieGroup, HammingGraph)
    if isinstance(space, unsupported_spaces):
        pytest.skip(
            f"{type(space).__name__} does not currently support eigenfunction evaluation"
        )

    expected_shape = (points.shape[0], eigenfunctions.num_eigenfunctions)

    def has_expected_shape(shape, output):
        # Only the shape is compared here, not the values.
        return output.shape == shape

    check_function_with_backend(
        backend,
        expected_shape,
        eigenfunctions,
        points,
        compare_to_result=has_expected_shape,
    )

77 

78 

79def test_numbers_of_eigenfunctions(inputs): 

80 _, eigenfunctions, _, _, _ = inputs 

81 num_levels = eigenfunctions.num_levels 

82 # Check that the length of the `num_eigenfunctions_per_level` list is correct. 

83 assert len(eigenfunctions.num_eigenfunctions_per_level) == num_levels 

84 # Check that the first eigenspace is 1-dimensional. 

85 assert eigenfunctions.num_eigenfunctions_per_level[0] == 1 

86 

87 # Check that dimensions of eigenspaces are always positive. 

88 for i in range(num_levels): 

89 assert eigenfunctions.num_eigenfunctions_per_level[i] > 0 

90 

91 num_eigenfunctions_manual = sum(eigenfunctions.num_eigenfunctions_per_level) 

92 # Check that `num_eigenfunctions_per_level` sum up to the total number of 

93 # eigenfunctions. 

94 assert num_eigenfunctions_manual == eigenfunctions.num_eigenfunctions 

95 

96 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_orthonormality(inputs, backend):
    """
    Check that the eigenfunctions are approximately orthonormal.

    The Gram matrix of the eigenfunctions is approximated by Monte Carlo over
    10000 random points of the space and compared with the identity matrix.
    """
    space, eigenfunctions, _, _, _ = inputs

    if isinstance(space, (CompactMatrixLieGroup, HammingGraph)):
        space_name = type(space).__name__
        pytest.skip(f"{space_name} does not currently support eigenfunction evaluation")

    key = np.random.RandomState(0)
    _, xs = space.random(key, 10000)

    def monte_carlo_gram(xs):
        # Evaluate the eigenfunctions once and reuse the result for both
        # factors of the product; evaluating them twice doubles the cost.
        phi = eigenfunctions(xs)  # shape [N, num_eigenfunctions]
        return B.matmul(B.T(phi), phi) / xs.shape[0]

    # Check that the eigenfunctions are orthonormal by comparing a Monte Carlo
    # approximation of the inner product with the identity matrix.
    check_function_with_backend(
        backend,
        np.eye(eigenfunctions.num_eigenfunctions),
        monte_carlo_gram,
        xs,
        atol=0.4,  # very loose, but helps make sure the diagonal is close to 1 while the rest is close to 0
    )

117 

118 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_with_addition_theorem(inputs, backend):
    """
    Check `weighted_outerproduct` against a direct eigenfunction expansion.

    The reference is sum_i w_i * phi_i(x) * phi_i(x2) over individual
    eigenfunctions i. Here `chain` expands the per-level weights to one
    weight per eigenfunction using `num_eigenfunctions_per_level`.
    """
    space, eigenfunctions, X, X2, weights = inputs

    if isinstance(space, (CompactMatrixLieGroup, HammingGraph)):
        space_name = type(space).__name__
        pytest.skip(f"{space_name} does not currently support eigenfunction evaluation")

    # `weights` has shape [num_levels, 1]. Take its single column explicitly
    # rather than calling `squeeze()`: `squeeze()` would turn a (1, 1) array
    # into a 0-d scalar when there is only one level.
    chained_weights = chain(
        weights[:, 0], eigenfunctions.num_eigenfunctions_per_level
    )

    Phi_X = eigenfunctions(X)
    Phi_X2 = eigenfunctions(X2)
    result = einsum("ni,ki,i->nk", Phi_X, Phi_X2, chained_weights)

    # Check that `weighted_outerproduct`, which is based on the addition theorem,
    # returns the same result as the direct computation involving individual
    # eigenfunctions.
    check_function_with_backend(
        backend, result, eigenfunctions.weighted_outerproduct, weights, X, X2, atol=1e-2
    )

141 

142 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_with_addition_theorem_one_input(inputs, backend):
    """
    Omitting X2 in `weighted_outerproduct(weights, X)` must behave exactly
    like passing X a second time, i.e. X2 defaults to X.
    """
    _, eigenfunctions, X, _, weights = inputs

    outerproduct = eigenfunctions.weighted_outerproduct
    expected = outerproduct(weights, X, X)  # explicit X2 = X as reference

    check_function_with_backend(backend, expected, outerproduct, weights, X, atol=1e-2)

159 

160 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_diag(inputs, backend):
    """
    `weighted_outerproduct_diag(weights, X)` must equal the diagonal of the
    full matrix `weighted_outerproduct(weights, X, X)`.
    """
    _, eigenfunctions, X, _, weights = inputs

    full_matrix = eigenfunctions.weighted_outerproduct(weights, X, X)
    expected_diagonal = np.diag(full_matrix)

    check_function_with_backend(
        backend,
        expected_diagonal,
        eigenfunctions.weighted_outerproduct_diag,
        weights,
        X,
    )

176 

177 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_against_phi_product(inputs, backend):
    """
    `weighted_outerproduct(weights, X, X2)` must equal the per-level
    `phi_product(X, X2)` contracted against the level weights.
    """
    _, eigenfunctions, X, X2, weights = inputs

    # The einsum contracts the last axis of `phi_product` (indexed by `i`)
    # with the first (level) axis of `weights`, and sums out the trailing
    # size-1 axis `d` of `weights`.
    per_level_products = eigenfunctions.phi_product(X, X2)
    expected = np.einsum("id,...nki->...nk", weights, per_level_products)

    check_function_with_backend(
        backend,
        expected,
        eigenfunctions.weighted_outerproduct,
        weights,
        X,
        X2,
    )

190 

191 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_diag_against_phi_product(inputs, backend):
    """
    `weighted_outerproduct_diag(weights, X)` must equal `phi_product_diag(X)`
    contracted against the level weights.
    """
    _, eigenfunctions, X, _, weights = inputs

    diag_products = eigenfunctions.phi_product_diag(X)
    # Contract the level axis `i` with the weights; the result has shape [N,].
    expected = np.einsum("id,ni->n", weights, diag_products)

    check_function_with_backend(
        backend,
        expected,
        eigenfunctions.weighted_outerproduct_diag,
        weights,
        X,
    )