Coverage for tests/spaces/test_product_discrete_spectrum_space.py: 100%

26 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import numpy as np 

2import pytest 

3 

4from geometric_kernels.kernels import MaternGeometricKernel 

5from geometric_kernels.spaces import ( 

6 Circle, 

7 ProductDiscreteSpectrumSpace, 

8 SpecialUnitary, 

9) 

10from geometric_kernels.utils.product import make_product 

11 

12from ..helper import check_function_with_backend 

13 

14_NUM_LEVELS = 20 

15 

16 

@pytest.mark.parametrize(
    "factor1, factor2", [(Circle(), Circle()), (Circle(), SpecialUnitary(2))], ids=str
)
@pytest.mark.parametrize("lengthscale", [0.1, 0.5, 1.0, 2.0, 5.0])
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_heat_kernel_is_product_of_heat_kernels(factor1, factor2, lengthscale, backend):
    """Check that the heat kernel on a product space factorizes over the factors.

    The Matern kernel with ``nu = inf`` (the heat kernel) on a
    ``ProductDiscreteSpectrumSpace`` of ``factor1`` and ``factor2`` is compared
    against the elementwise product of the heat kernels on each factor. The
    comparison uses the same random samples from each factor.

    The product space and its kernel use ``_NUM_LEVELS**2`` levels, while each
    factor kernel uses ``_NUM_LEVELS``. This is presumably enough to retain
    every pairing of factor levels, so the two sides should agree exactly
    (up to the tolerance of ``check_function_with_backend``).
    """
    # Both factors are sampled with the same number of points. The elementwise
    # product ``K_factor1 * K_factor2`` below needs equally sized Gram matrices,
    # and the i-th product point is built from the i-th sample of each factor.
    num_points = 10

    product = ProductDiscreteSpectrumSpace(
        factor1, factor2, num_levels=_NUM_LEVELS**2, num_levels_per_space=_NUM_LEVELS
    )

    key = np.random.RandomState(0)
    key, xs_factor1 = factor1.random(key, num_points)
    key, xs_factor2 = factor2.random(key, num_points)

    kernel_product = MaternGeometricKernel(product, num=_NUM_LEVELS**2)
    kernel_factor1 = MaternGeometricKernel(factor1, num=_NUM_LEVELS)
    kernel_factor2 = MaternGeometricKernel(factor2, num=_NUM_LEVELS)

    def K_diff(nu, lengthscale, xs_factor1, xs_factor2):
        """Return ``K_product - K_factor1 * K_factor2`` on the paired points."""
        params = {"nu": nu, "lengthscale": lengthscale}

        xs_product = make_product([xs_factor1, xs_factor2])

        K_product = kernel_product.K(params, xs_product, xs_product)
        K_factor1 = kernel_factor1.K(params, xs_factor1, xs_factor1)
        K_factor2 = kernel_factor2.K(params, xs_factor2, xs_factor2)

        # Hadamard (elementwise) product of the two factor Gram matrices.
        return K_product - K_factor1 * K_factor2

    # Check that the heat kernel (nu = inf) on the ProductDiscreteSpectrumSpace
    # coincides with the product of heat kernels on the factors. The result is
    # a (num_points, num_points) Gram-matrix difference, expected to be zero.
    check_function_with_backend(
        backend,
        np.zeros((num_points, num_points)),
        K_diff,
        np.array([np.inf]),
        np.array([lengthscale]),
        xs_factor1,
        xs_factor2,
    )