Coverage for tests/utils/test_kernel_formulas.py: 100%

30 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1from math import log, tanh 

2 

3import numpy as np 

4import pytest 

5from sklearn.metrics.pairwise import rbf_kernel 

6 

7from geometric_kernels.spaces import HypercubeGraph 

8from geometric_kernels.utils.kernel_formulas import hypercube_graph_heat_kernel 

9 

10from ..helper import check_function_with_backend 

11 

12 

@pytest.mark.parametrize("d", [1, 5, 10])
@pytest.mark.parametrize("lengthscale", [1.0, 5.0, 10.0])
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_hypercube_graph_heat_kernel(d, lengthscale, backend):
    """Check ``hypercube_graph_heat_kernel`` against two identities.

    1. With the unnormalized Laplacian, the heat kernel on the hypercube graph
       coincides with the RBF kernel restricted to binary vectors, using
       ``gamma = -log(tanh(lengthscale**2 / 2))``.
    2. The hypercube ``{0, 1}^d`` is a product of the smaller hypercubes
       ``{0, 1}^split`` and ``{0, 1}^(d - split)``. So the heat kernel on the
       full cube must equal the elementwise product of the heat kernels
       computed on each block of coordinates, for every pair of inputs.
    """
    space = HypercubeGraph(d)

    key = np.random.RandomState(0)
    # Between 1 and min(2**d, 10) points for each of the two input sets.
    N, N2 = key.randint(low=1, high=min(2**d, 10) + 1, size=2)
    key, X = space.random(key, N)
    key, X2 = space.random(key, N2)

    # For binary vectors ||x - y||^2 is the Hamming distance, so
    # exp(-gamma * ||x - y||^2) == tanh(lengthscale**2 / 2) ** hamming(x, y).
    gamma = -log(tanh(lengthscale**2 / 2))
    expected_rbf = rbf_kernel(X, X2, gamma=gamma)

    def heat_kernel(lengthscale, X, X2):
        return hypercube_graph_heat_kernel(
            lengthscale, X, X2, normalized_laplacian=False
        )

    # Checks that the heat kernel on the hypercube graph coincides with the RBF
    # restricted onto binary vectors, with appropriately redefined length scale.
    check_function_with_backend(
        backend,
        expected_rbf,
        heat_kernel,
        np.array([lengthscale]),
        X,
        X2,
        atol=1e-2,
    )

    # Column at which the coordinates are split into two sub-hypercubes; the
    # product check needs at least one coordinate on each side.
    split = 3
    if d > split:
        lengthscale_arr = np.array([lengthscale])
        K_first = heat_kernel(lengthscale_arr, X[:, :split], X2[:, :split])
        K_second = heat_kernel(lengthscale_arr, X[:, split:], X2[:, split:])

        # Elementwise product over all (N, N2) pairs, not just the first one.
        expected_product = K_first * K_second

        # Checks that the heat kernel of the product is equal to the product
        # of heat kernels.
        check_function_with_backend(
            backend,
            expected_product,
            heat_kernel,
            lengthscale_arr,
            X,
            X2,
        )