Coverage for tests/utils/test_kernel_formulas.py: 100%
30 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1from math import log, tanh
3import numpy as np
4import pytest
5from sklearn.metrics.pairwise import rbf_kernel
7from geometric_kernels.spaces import HypercubeGraph
8from geometric_kernels.utils.kernel_formulas import hypercube_graph_heat_kernel
10from ..helper import check_function_with_backend
@pytest.mark.parametrize("d", [1, 5, 10])
@pytest.mark.parametrize("lengthscale", [1.0, 5.0, 10.0])
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_hypercube_graph_heat_kernel(d, lengthscale, backend):
    """Validate ``hypercube_graph_heat_kernel`` against two references.

    1. With ``normalized_laplacian=False``, the heat kernel on the hypercube
       graph must agree (to ``atol=1e-2``) with an RBF kernel restricted to
       binary vectors, using ``gamma = -log(tanh(lengthscale**2 / 2))``.
    2. When ``d > 5``, the kernel evaluated on the full vectors must equal the
       product of kernels evaluated on the first 3 coordinates and on the
       remaining coordinates (heat kernel of a product = product of heat
       kernels).
    """
    cube = HypercubeGraph(d)
    ls_arr = np.array([lengthscale])

    # Seeded generator: the two sample sizes are drawn first, then the two
    # point sets, so each parametrization sees a reproducible dataset.
    rng = np.random.RandomState(0)
    max_rows = min(2**d, 10)
    n_x, n_x2 = rng.randint(low=1, high=max_rows + 1, size=2)
    rng, X = cube.random(rng, n_x)
    rng, X2 = cube.random(rng, n_x2)

    def kernel_under_test(lengthscale, X, X2):
        # Thin wrapper that pins the unnormalized Laplacian; the backend
        # harness supplies its own (converted) arguments.
        return hypercube_graph_heat_kernel(
            lengthscale, X, X2, normalized_laplacian=False
        )

    # Reference 1: RBF restricted onto binary vectors, with the length scale
    # re-expressed through ``gamma``.
    rbf_gamma = -log(tanh(lengthscale**2 / 2))
    expected_rbf = rbf_kernel(X, X2, gamma=rbf_gamma)
    check_function_with_backend(
        backend,
        expected_rbf,
        kernel_under_test,
        ls_arr,
        X,
        X2,
        atol=1e-2,
    )

    # The factorization check below is only exercised for large enough cubes.
    if d <= 5:
        return

    # Reference 2: split the first row of each point set at coordinate
    # ``split`` and multiply the kernels of the two sub-cubes.
    split = 3
    row_x, row_x2 = X[:1, :], X2[:1, :]
    k_head = hypercube_graph_heat_kernel(
        ls_arr, row_x[:, :split], row_x2[:, :split], normalized_laplacian=False
    )
    k_tail = hypercube_graph_heat_kernel(
        ls_arr, row_x[:, split:], row_x2[:, split:], normalized_laplacian=False
    )
    expected_product = k_head * k_tail

    check_function_with_backend(
        backend,
        expected_product,
        kernel_under_test,
        ls_arr,
        row_x,
        row_x2,
    )