Coverage for tests/spaces/test_eigenfunctions.py: 100%
80 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1import jax
2import lab as B
3import numpy as np
4import pytest
5from opt_einsum import contract as einsum
7from geometric_kernels.kernels.matern_kernel import default_num
8from geometric_kernels.spaces import CompactMatrixLieGroup, Hypersphere, Mesh
9from geometric_kernels.utils.utils import chain
11from ..helper import check_function_with_backend, discrete_spectrum_spaces
# Module-level side effect: switch JAX to 64-bit mode so that the "jax"
# backend parametrizations of the tests below can work with float64 data.
jax.config.update("jax_enable_x64", True)  # enable float64 in JAX
@pytest.fixture(
    params=discrete_spectrum_spaces(),
    ids=str,
)
def inputs(request):
    """
    Returns a tuple (space, eigenfunctions, X, X2, weights) where:
    - space = request.param,
    - eigenfunctions = space.get_eigenfunctions(num_levels), with reasonable num_levels
    - X is a random sample of random size from the space,
    - X2 is another random sample of random size from the space,
    - weights is an array of positive numbers of shape (eigenfunctions.num_levels, 1).

    The sample sizes, X, X2 and weights are all derived from explicitly
    seeded generators, so the fixture yields the same inputs on every run.
    """
    space = request.param
    num_levels = default_num(space)
    if isinstance(space, Hypersphere):
        # For Hypersphere, the maximal number of levels with eigenfunction
        # evaluation support is stored in the num_computed_levels field. We
        # do not use more levels to be able to test eigenfunction evaluation.
        num_levels = min(
            10, num_levels, space.get_eigenfunctions(num_levels).num_computed_levels
        )
    elif isinstance(space, Mesh):
        # We limit the number of levels to 50 to avoid excessive computation
        # and increased numerical error for higher order eigenfunctions.
        num_levels = min(50, num_levels)
    eigenfunctions = space.get_eigenfunctions(num_levels)

    key = np.random.RandomState(0)
    N, N2 = key.randint(low=1, high=100 + 1, size=2)
    key, X = space.random(key, N)
    key, X2 = space.random(key, N2)

    # These weights are used for testing the weighted outer product, so they
    # must be positive (squared uniform values plus a small offset).
    # A dedicated seeded generator is used instead of the global
    # `np.random` state so that the weights, and the tolerance-based
    # comparisons that depend on them, are reproducible across runs.
    weights_rng = np.random.RandomState(1)
    weights = weights_rng.rand(eigenfunctions.num_levels, 1) ** 2 + 1e-5

    return space, eigenfunctions, X, X2, weights
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_call_eigenfunctions(inputs, backend):
    """Calling the eigenfunctions on X must give the right type and an [N, M] shape."""
    space, eigenfunctions, X, _, _ = inputs

    if isinstance(space, CompactMatrixLieGroup):
        pytest.skip(
            "CompactMatrixLieGroup subclasses do not currently support eigenfunction evaluation"
        )

    # One row per input point, one column per eigenfunction.
    expected_shape = (X.shape[0], eigenfunctions.num_eigenfunctions)

    def shape_matches(res, f_out):
        # `res` is the expected shape above; `f_out` is the backend output.
        return f_out.shape == res

    check_function_with_backend(
        backend,
        expected_shape,
        eigenfunctions,
        X,
        compare_to_result=shape_matches,
    )
def test_numbers_of_eigenfunctions(inputs):
    """Check that the per-level eigenfunction counts are mutually consistent."""
    _, eigenfunctions, _, _, _ = inputs
    num_levels = eigenfunctions.num_levels
    per_level = eigenfunctions.num_eigenfunctions_per_level

    # Check that the length of the `num_eigenfunctions_per_level` list is correct.
    assert len(per_level) == num_levels
    # Check that the first eigenspace is 1-dimensional.
    assert per_level[0] == 1

    # Check that dimensions of eigenspaces are always positive; report the
    # offending level so failures are easy to diagnose.
    for level, count in enumerate(per_level):
        assert count > 0, f"eigenspace at level {level} has dimension {count}"

    # Check that `num_eigenfunctions_per_level` sum up to the total number of
    # eigenfunctions.
    assert sum(per_level) == eigenfunctions.num_eigenfunctions
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_orthonormality(inputs, backend):
    """Check orthonormality of the eigenfunctions via a Monte Carlo Gram matrix.

    The Gram matrix Phi^T Phi / N over N random points should approximate
    the identity. This assumes `space.random` samples from the measure with
    respect to which the eigenfunctions are orthonormal; confirm per space.
    """
    space, eigenfunctions, _, _, _ = inputs

    if isinstance(space, CompactMatrixLieGroup):
        pytest.skip(
            "CompactMatrixLieGroup subclasses do not currently support eigenfunction evaluation"
        )

    key = np.random.RandomState(0)
    _, xs = space.random(key, 10000)

    def empirical_gram(points):
        # Evaluate the eigenfunctions once and reuse the result: evaluating
        # on 10000 points is the expensive part of this test.
        phi = eigenfunctions(points)  # [N, M]
        return B.matmul(B.T(phi), phi) / points.shape[0]

    # Check that the eigenfunctions are orthonormal by comparing a Monte Carlo
    # approximation of the inner product with the identity matrix.
    check_function_with_backend(
        backend,
        np.eye(eigenfunctions.num_eigenfunctions),
        empirical_gram,
        xs,
        atol=0.4,  # very loose, but helps make sure the diagonal is close to 1 while the rest is close to 0
    )
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_with_addition_theorem(inputs, backend):
    """Compare `weighted_outerproduct` with an explicit per-eigenfunction sum.

    `weighted_outerproduct` is based on the addition theorem. The reference
    value is computed directly as sum_i w_i phi_i(x) phi_i(x2), where each
    per-level weight is repeated for every eigenfunction in that level.
    """
    space, eigenfunctions, X, X2, weights = inputs

    if isinstance(space, CompactMatrixLieGroup):
        pytest.skip(
            "CompactMatrixLieGroup subclasses do not currently support eigenfunction evaluation"
        )

    # `weights` has shape (num_levels, 1). Taking `weights[:, 0]` always
    # yields a 1-D array; `weights.squeeze()` would collapse to a 0-d array
    # when there is only a single level, which `chain` cannot slice.
    chained_weights = chain(
        weights[:, 0], eigenfunctions.num_eigenfunctions_per_level
    )

    Phi_X = eigenfunctions(X)  # [N, M]
    Phi_X2 = eigenfunctions(X2)  # [N2, M]
    expected = einsum("ni,ki,i->nk", Phi_X, Phi_X2, chained_weights)  # [N, N2]

    # Check that `weighted_outerproduct`, which is based on the addition theorem,
    # returns the same result as the direct computation involving individual
    # eigenfunctions.
    check_function_with_backend(
        backend, expected, eigenfunctions.weighted_outerproduct, weights, X, X2, atol=1e-2
    )
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_with_addition_theorem_one_input(inputs, backend):
    """Calling `weighted_outerproduct` without X2 must behave as if X2=X."""
    _, eigenfunctions, X, _, weights = inputs

    # Reference: pass X explicitly as both point sets.
    expected_x2_is_x = eigenfunctions.weighted_outerproduct(weights, X, X)

    # The function under test receives only (weights, X).
    check_function_with_backend(
        backend,
        expected_x2_is_x,
        eigenfunctions.weighted_outerproduct,
        weights,
        X,
        atol=1e-2,
    )
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_diag(inputs, backend):
    """`weighted_outerproduct_diag` must equal the diagonal of `weighted_outerproduct`."""
    _, eigenfunctions, X, _, weights = inputs

    full_matrix = eigenfunctions.weighted_outerproduct(weights, X, X)  # [N, N]
    expected_diagonal = np.diag(full_matrix)  # [N,]

    check_function_with_backend(
        backend, expected_diagonal, eigenfunctions.weighted_outerproduct_diag, weights, X
    )
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_against_phi_product(inputs, backend):
    """`weighted_outerproduct` must be the weighted sum of `phi_product` over levels."""
    _, eigenfunctions, X, X2, weights = inputs

    # Per-level sums of phi(x) * phi(x2); the einsum below treats the last
    # axis as the level axis, which is contracted against the weights.
    per_level_products = eigenfunctions.phi_product(X, X2)
    expected = np.einsum("id,...nki->...nk", weights, per_level_products)

    check_function_with_backend(
        backend,
        expected,
        eigenfunctions.weighted_outerproduct,
        weights,
        X,
        X2,
    )
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_diag_against_phi_product(inputs, backend):
    """`weighted_outerproduct_diag` must be the weighted sum of `phi_product_diag`."""
    _, eigenfunctions, X, _, weights = inputs

    # Contract the level axis (last axis) of the diagonal products with the weights.
    per_level_diag = eigenfunctions.phi_product_diag(X)
    expected = np.einsum("id,ni->n", weights, per_level_diag)  # [N,]

    check_function_with_backend(
        backend,
        expected,
        eigenfunctions.weighted_outerproduct_diag,
        weights,
        X,
    )