Coverage for tests/spaces/test_product_discrete_spectrum_space.py: 100%
26 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1import numpy as np
2import pytest
4from geometric_kernels.kernels import MaternGeometricKernel
5from geometric_kernels.spaces import (
6 Circle,
7 ProductDiscreteSpectrumSpace,
8 SpecialUnitary,
9)
10from geometric_kernels.utils.product import make_product
12from ..helper import check_function_with_backend
# Number of levels used for each factor space and each per-factor kernel in
# this module; the product space and its kernel use the square of this value.
_NUM_LEVELS = 20
@pytest.mark.parametrize(
    "factor1, factor2",
    [
        (Circle(), Circle()),
        (Circle(), SpecialUnitary(2)),
    ],
    ids=str,
)
@pytest.mark.parametrize("lengthscale", [0.1, 0.5, 1.0, 2.0, 5.0])
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_heat_kernel_is_product_of_heat_kernels(factor1, factor2, lengthscale, backend):
    """Check that the heat kernel on a product space factorizes.

    A ``ProductDiscreteSpectrumSpace`` is built from ``factor1`` and
    ``factor2``. With ``nu`` set to infinity, the kernel matrix on the product
    space is compared against the ``*`` product of the kernel matrices on the
    two factors: their difference is expected to be a matrix of zeros for the
    given ``lengthscale`` under the given ``backend``.
    """
    levels_per_factor = _NUM_LEVELS
    levels_on_product = _NUM_LEVELS**2

    product_space = ProductDiscreteSpectrumSpace(
        factor1,
        factor2,
        num_levels=levels_on_product,
        num_levels_per_space=levels_per_factor,
    )

    # One seeded random state is threaded through both draws, factor1 first,
    # so the sampled inputs are reproducible.
    rng = np.random.RandomState(0)
    rng, xs_factor1 = factor1.random(rng, 10)
    rng, xs_factor2 = factor2.random(rng, 10)

    kernel_on_product = MaternGeometricKernel(product_space, num=levels_on_product)
    kernel_on_factor1 = MaternGeometricKernel(factor1, num=levels_per_factor)
    kernel_on_factor2 = MaternGeometricKernel(factor2, num=levels_per_factor)

    def product_minus_factors(nu, lengthscale, xs_factor1, xs_factor2):
        # Difference between the product-space kernel matrix and the product
        # of the two per-factor kernel matrices, all with the same parameters.
        hyperparameters = {"nu": nu, "lengthscale": lengthscale}

        paired_inputs = make_product([xs_factor1, xs_factor2])

        gram_product = kernel_on_product.K(hyperparameters, paired_inputs, paired_inputs)
        gram_factor1 = kernel_on_factor1.K(hyperparameters, xs_factor1, xs_factor1)
        gram_factor2 = kernel_on_factor2.K(hyperparameters, xs_factor2, xs_factor2)

        return gram_product - gram_factor1 * gram_factor2

    # The heat kernel on the ProductDiscreteSpectrumSpace should coincide with
    # the product of heat kernels on the factors, so the difference computed
    # above is expected to vanish.
    expected_difference = np.zeros((xs_factor1.shape[0], xs_factor2.shape[0]))
    check_function_with_backend(
        backend,
        expected_difference,
        product_minus_factors,
        np.array([np.inf]),
        np.array([lengthscale]),
        xs_factor1,
        xs_factor2,
    )