Coverage for tests / spaces / test_hamming_graph.py: 100%

51 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-15 13:41 +0000

1import lab as B 

2import numpy as np 

3import pytest 

4from plum import Tuple 

5 

6from geometric_kernels.kernels import MaternGeometricKernel 

7from geometric_kernels.spaces import HammingGraph, HypercubeGraph 

8from geometric_kernels.utils.kernel_formulas import hamming_graph_heat_kernel 

9 

10from ..helper import check_function_with_backend 

11 

12 

@pytest.fixture(params=[(1, 2), (2, 2), (5, 2), (10, 2), (10, 4)])
def inputs(request) -> Tuple[B.Numeric]:
    """
    Returns a tuple (space, eigenfunctions, X, X2, weights) where:
    - space is a HammingGraph object with (dim, n_cat) equal to request.param,
    - eigenfunctions is the respective Eigenfunctions object with at most 5 levels,
    - X is a random sample of random size from the space,
    - X2 is another random sample of random size from the space,
    - weights is an array of positive numbers of shape (eigenfunctions.num_levels, 1).

    All randomness (sample sizes, samples and weights) is drawn from a single
    ``np.random.RandomState(0)``, so the fixture is reproducible across runs.
    """
    d, q = request.param
    space = HammingGraph(dim=d, n_cat=q)
    # space.dim + 1 is the maximal number of levels (see
    # test_numbers_of_eigenfunctions); cap at 5 to keep the tests cheap.
    eigenfunctions = space.get_eigenfunctions(min(space.dim + 1, 5))

    key = np.random.RandomState(0)
    # Sample sizes lie in [1, min(q**d, 10)]: never more than the number of
    # points in the space, and never more than 10.
    N, N2 = key.randint(low=1, high=min(q**d, 10) + 1, size=2)
    key, X = space.random(key, N)
    key, X2 = space.random(key, N2)

    # These weights are used for testing the weighted outerproduct, they
    # should be positive. Draw them from the same seeded state as above rather
    # than from the unseeded global ``np.random``, so they are reproducible.
    weights = key.rand(eigenfunctions.num_levels, 1) ** 2 + 1e-5

    return space, eigenfunctions, X, X2, weights

37 

38 

def test_numbers_of_eigenfunctions(inputs):
    """
    When every level is requested, the eigenfunctions must number exactly
    n_cat**dim, i.e. one per categorical vector of the Hamming graph.
    """
    space, eigenfunctions, _, _, _ = inputs
    max_levels = space.dim + 1

    # Only the maximal-level case has a known total to compare against.
    if eigenfunctions.num_levels != max_levels:
        return

    expected_total = space.n_cat**space.dim
    assert eigenfunctions.num_eigenfunctions == expected_total

47 

48 

@pytest.mark.parametrize("nu", [1.5, np.inf])
@pytest.mark.parametrize("lengthscale", [1.0, 5.0, 10.0])
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_reduces_to_hypercube_when_q_equals_2(inputs, nu, lengthscale, backend):
    """
    For n_cat == 2, the eigenvalues and Matern kernel values on the
    HammingGraph must agree with those on a HypercubeGraph of the same
    dimension. The hypercube side receives the samples cast to bool.
    """
    hamming, eigenfunctions, X, X2, _ = inputs

    if hamming.n_cat != 2:
        pytest.skip("Only applicable when n_cat=2")

    hypercube = HypercubeGraph(hamming.dim)

    # The eigenvalue comparison does not depend on nu, lengthscale or backend.
    # It is restricted to one (nu, lengthscale) pair to avoid redundant work,
    # but it still runs once for every (fixture param, backend) combination.
    if nu == 1.5 and lengthscale == 1.0:
        np.testing.assert_allclose(
            hamming.get_eigenvalues(eigenfunctions.num_levels),
            hypercube.get_eigenvalues(eigenfunctions.num_levels),
            rtol=1e-10,
        )

    params = {"nu": np.array([nu]), "lengthscale": np.array([lengthscale])}

    # Reference values: the hypercube kernel evaluated on boolean samples.
    expected = MaternGeometricKernel(hypercube).K(
        params, X.astype(bool), X2.astype(bool)
    )

    check_function_with_backend(
        backend,
        expected,
        MaternGeometricKernel(hamming).K,
        params,
        X,
        X2,
        atol=1e-2,
    )

86 

87 

@pytest.mark.parametrize("lengthscale", [1.0, 5.0, 10.0])
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_against_analytic_heat_kernel(inputs, lengthscale, backend):
    """
    MaternGeometricKernel on a HammingGraph with nu = infinity must coincide
    with the closed-form Hamming-graph heat kernel from
    ``hamming_graph_heat_kernel``.
    """
    space, _, X, X2, _ = inputs
    ell = np.array([lengthscale])
    params = {"nu": np.array([np.inf]), "lengthscale": ell}

    # Closed-form reference values.
    expected = hamming_graph_heat_kernel(ell, X, X2, q=space.n_cat)

    check_function_with_backend(
        backend,
        expected,
        MaternGeometricKernel(space).K,
        params,
        X,
        X2,
        atol=1e-2,
    )
108 )