Coverage for tests/kernels/test_product_kernel.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import numpy as np 

2import pytest 

3 

4from geometric_kernels.kernels import MaternGeometricKernel, ProductGeometricKernel 

5from geometric_kernels.spaces import Circle, SpecialUnitary 

6from geometric_kernels.utils.product import make_product 

7 

8from ..helper import check_function_with_backend 

9 

10 

@pytest.mark.parametrize(
    "factor1, factor2", [(Circle(), Circle()), (Circle(), SpecialUnitary(2))], ids=str
)
@pytest.mark.parametrize("nu, lengthscale", [(1 / 2, 2.0), (5 / 2, 1.0), (np.inf, 0.1)])
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_kernel_is_product_of_heat_kernels(factor1, factor2, nu, lengthscale, backend):
    """Check that a product kernel with shared hyperparameters factorizes.

    Ten points are sampled on each factor space from a fixed-seed
    ``RandomState`` and zipped into product-space points with
    ``make_product``. The Gram matrix of ``ProductGeometricKernel`` on those
    points must equal the elementwise product of the two factor Gram
    matrices when every kernel receives the same ``nu``/``lengthscale``
    (i.e. no ARD). The check runs once per backend.

    Note: despite the test name, the parametrization covers several Matérn
    smoothness values; ``nu=np.inf`` is presumably the heat-kernel case.
    """
    rng = np.random.RandomState(0)
    rng, xs_factor1 = factor1.random(rng, 10)
    rng, xs_factor2 = factor2.random(rng, 10)

    # Build the factor kernels first (factor1, then factor2), then combine them.
    factor_kernels = (MaternGeometricKernel(factor1), MaternGeometricKernel(factor2))
    product_kernel = ProductGeometricKernel(*factor_kernels)

    def product_minus_factorwise(nu, lengthscale, xs_factor1, xs_factor2):
        # Identical hyperparameters are handed to every kernel (no ARD).
        shared_params = {"nu": nu, "lengthscale": lengthscale}
        zipped_points = make_product([xs_factor1, xs_factor2])

        gram_product = product_kernel.K(shared_params, zipped_points, zipped_points)
        gram_first, gram_second = (
            kern.K(shared_params, pts, pts)
            for kern, pts in zip(factor_kernels, (xs_factor1, xs_factor2))
        )
        # Should vanish identically if the product kernel factorizes.
        return gram_product - gram_first * gram_second

    num_rows, num_cols = xs_factor1.shape[0], xs_factor2.shape[0]
    check_function_with_backend(
        backend,
        np.zeros((num_rows, num_cols)),
        product_minus_factorwise,
        np.array([nu]),
        np.array([lengthscale]),
        xs_factor1,
        xs_factor2,
    )