Coverage for tests/spaces/test_eigenfunctions.py: 100% (83 statements)
import jax
import lab as B
import numpy as np
import pytest
from opt_einsum import contract as einsum

from geometric_kernels.kernels.matern_kernel import default_num
from geometric_kernels.spaces import (
    CompactMatrixLieGroup,
    HammingGraph,
    Hypersphere,
    Mesh,
)
from geometric_kernels.utils.utils import chain

from ..helper import check_function_with_backend, discrete_spectrum_spaces

jax.config.update("jax_enable_x64", True)  # enable float64 in JAX
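
# A note on `check_function_with_backend` (from ..helper): based on how it is
# used below, it is assumed to (1) convert the trailing array arguments to the
# given backend, (2) call the provided function on them, and (3) compare the
# output against the expected result, by default via np.allclose with the
# given `atol`, or via a custom `compare_to_result` callback when one is
# provided. This is an assumption about the helper, not its specification.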


@pytest.fixture(
    params=discrete_spectrum_spaces(),
    ids=str,
)
def inputs(request):
    """
    Returns a tuple (space, eigenfunctions, X, X2, weights) where:
    - space = request.param,
    - eigenfunctions = space.get_eigenfunctions(num_levels), with a reasonable num_levels,
    - X is a random sample of random size from the space,
    - X2 is another random sample of random size from the space,
    - weights is an array of positive numbers of shape (eigenfunctions.num_levels, 1).
    """
    space = request.param
    num_levels = default_num(space)
    if isinstance(space, Hypersphere):
        # For Hypersphere, the maximal number of levels that support
        # eigenfunction evaluation is stored in the num_computed_levels field.
        # We do not use more levels than that, so that eigenfunction
        # evaluation can still be tested.
        num_levels = min(
            10, num_levels, space.get_eigenfunctions(num_levels).num_computed_levels
        )
    elif isinstance(space, Mesh):
        # We limit the number of levels to 50 to avoid excessive computation
        # and the increased numerical error of higher-order eigenfunctions.
        num_levels = min(50, num_levels)
    eigenfunctions = space.get_eigenfunctions(num_levels)

    key = np.random.RandomState(0)
    N, N2 = key.randint(low=1, high=100 + 1, size=2)
    key, X = space.random(key, N)
    key, X2 = space.random(key, N2)

    # These weights are used for testing the weighted outer product; they
    # should be positive. We draw them from the seeded random state to keep
    # the fixture deterministic.
    weights = key.rand(eigenfunctions.num_levels, 1) ** 2 + 1e-5

    return space, eigenfunctions, X, X2, weights
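
# For illustration (assuming discrete_spectrum_spaces() includes a
# Hypersphere(2)), the fixture might yield:
#   space          -> Hypersphere(2)
#   eigenfunctions -> its eigenfunctions object, with num_levels <= 10
#   X, X2          -> arrays of shape [N, 3] and [N2, 3], 1 <= N, N2 <= 100
#   weights        -> array of shape [num_levels, 1], entries in (1e-5, 1 + 1e-5)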


@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_call_eigenfunctions(inputs, backend):
    space, eigenfunctions, X, _, _ = inputs

    if isinstance(space, (CompactMatrixLieGroup, HammingGraph)):
        space_name = type(space).__name__
        pytest.skip(f"{space_name} does not currently support eigenfunction evaluation")

    # Check that the eigenfunctions can be called, returning the right type and shape.
    check_function_with_backend(
        backend,
        (X.shape[0], eigenfunctions.num_eigenfunctions),
        eigenfunctions,
        X,
        compare_to_result=lambda res, f_out: f_out.shape == res,
    )


def test_numbers_of_eigenfunctions(inputs):
    _, eigenfunctions, _, _, _ = inputs
    num_levels = eigenfunctions.num_levels
    # Check that the length of the `num_eigenfunctions_per_level` list is correct.
    assert len(eigenfunctions.num_eigenfunctions_per_level) == num_levels
    # Check that the first eigenspace is 1-dimensional.
    assert eigenfunctions.num_eigenfunctions_per_level[0] == 1

    # Check that the dimensions of the eigenspaces are always positive.
    for i in range(num_levels):
        assert eigenfunctions.num_eigenfunctions_per_level[i] > 0

    num_eigenfunctions_manual = sum(eigenfunctions.num_eigenfunctions_per_level)
    # Check that the entries of `num_eigenfunctions_per_level` sum up to the
    # total number of eigenfunctions.
    assert num_eigenfunctions_manual == eigenfunctions.num_eigenfunctions
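
# Example: on the two-dimensional sphere, level l corresponds to the degree-l
# spherical harmonics, which span a (2l + 1)-dimensional eigenspace; the
# level-0 eigenspace consists of the constant functions, hence is 1-dimensional.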


@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_orthonormality(inputs, backend):
    space, eigenfunctions, _, _, _ = inputs

    if isinstance(space, (CompactMatrixLieGroup, HammingGraph)):
        space_name = type(space).__name__
        pytest.skip(f"{space_name} does not currently support eigenfunction evaluation")

    key = np.random.RandomState(0)
    key, xs = space.random(key, 10000)

    # Check that the eigenfunctions are orthonormal by comparing a Monte Carlo
    # approximation of the inner product with the identity matrix.
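    # Assuming the samples xs are (approximately) uniform with respect to the
    # normalized measure mu on the space, orthonormality means
    #   int phi_i phi_j dmu = delta_ij,
    # and its Monte Carlo estimate is
    #   (1 / N) sum_n phi_i(x_n) phi_j(x_n) = (Phi^T Phi / N)[i, j],
    # which is exactly what the lambda below computes.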
    check_function_with_backend(
        backend,
        np.eye(eigenfunctions.num_eigenfunctions),
        lambda xs: B.matmul(B.T(eigenfunctions(xs)), eigenfunctions(xs)) / xs.shape[0],
        xs,
        atol=0.4,  # very loose, but helps make sure the diagonal is close to 1 while the rest is close to 0
    )


@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_with_addition_theorem(inputs, backend):
    space, eigenfunctions, X, X2, weights = inputs

    if isinstance(space, (CompactMatrixLieGroup, HammingGraph)):
        space_name = type(space).__name__
        pytest.skip(f"{space_name} does not currently support eigenfunction evaluation")
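
    # `chain(values, counts)` is assumed to repeat each values[l] exactly
    # counts[l] times, e.g. chain([w0, w1], [1, 3]) -> [w0, w1, w1, w1].
    # Here it turns the per-level weights into per-eigenfunction weights.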
    chained_weights = chain(
        weights.squeeze(), eigenfunctions.num_eigenfunctions_per_level
    )

    Phi_X = eigenfunctions(X)
    Phi_X2 = eigenfunctions(X2)
    result = einsum("ni,ki,i->nk", Phi_X, Phi_X2, chained_weights)

    # Check that `weighted_outerproduct`, which is based on the addition
    # theorem, returns the same result as the direct computation involving
    # individual eigenfunctions. The addition theorem expresses
    # sum_{i in level l} phi_i(x) phi_i(x') in closed form (on the sphere, for
    # instance, as a Gegenbauer polynomial of the inner product of x and x'),
    # so `weighted_outerproduct` never evaluates individual eigenfunctions.
    check_function_with_backend(
        backend, result, eigenfunctions.weighted_outerproduct, weights, X, X2, atol=1e-2
    )


@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_with_addition_theorem_one_input(inputs, backend):
    _, eigenfunctions, X, _, weights = inputs

    result = eigenfunctions.weighted_outerproduct(weights, X, X)

    # Check that `weighted_outerproduct`, when given only X (but not X2),
    # uses X2=X.
    check_function_with_backend(
        backend,
        result,
        eigenfunctions.weighted_outerproduct,
        weights,
        X,
        atol=1e-2,
    )


@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_diag(inputs, backend):
    _, eigenfunctions, X, _, weights = inputs

    result = np.diag(eigenfunctions.weighted_outerproduct(weights, X, X))

    # Check that `weighted_outerproduct_diag` returns the same result as the
    # diagonal of the full `weighted_outerproduct`.
    check_function_with_backend(
        backend,
        result,
        eigenfunctions.weighted_outerproduct_diag,
        weights,
        X,
    )


@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_against_phi_product(inputs, backend):
    _, eigenfunctions, X, X2, weights = inputs
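
    # `phi_product(X, X2)` is assumed to return an array of shape
    # [N, N2, num_levels] whose [n, k, l] entry is
    # sum_{i in level l} phi_i(X[n]) phi_i(X2[k]); the einsum below then
    # contracts the level dimension with the per-level weights.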
    sum_phi_phi_for_level = eigenfunctions.phi_product(X, X2)

    result = np.einsum("id,...nki->...nk", weights, sum_phi_phi_for_level)

    # Check that `weighted_outerproduct` returns the weighted sum of `phi_product`.
    check_function_with_backend(
        backend, result, eigenfunctions.weighted_outerproduct, weights, X, X2
    )


@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_weighted_outerproduct_diag_against_phi_product(inputs, backend):
    _, eigenfunctions, X, _, weights = inputs
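
    # Similarly, `phi_product_diag(X)` is assumed to return an array of shape
    # [N, num_levels] whose [n, l] entry is sum_{i in level l} phi_i(X[n])**2.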
    phi_product_diag = eigenfunctions.phi_product_diag(X)

    result = np.einsum("id,ni->n", weights, phi_product_diag)  # [N,]

    # Check that `weighted_outerproduct_diag` returns the weighted sum
    # of `phi_product_diag`.
    check_function_with_backend(
        backend, result, eigenfunctions.weighted_outerproduct_diag, weights, X
    )