Coverage for geometric_kernels / utils / kernel_formulas / hamming_graph.py: 95%
20 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-15 13:41 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-15 13:41 +0000
1"""
2Implements the closed form expression for the heat kernel on the Hamming graph.
4The implementation is provided mainly for testing purposes.
5"""
7from math import sqrt
9import lab as B
10from beartype.typing import Optional
12from geometric_kernels.lab_extras import float_like
13from geometric_kernels.utils.utils import _check_1_vector, _check_matrix
def hamming_graph_heat_kernel(
    lengthscale: B.Numeric,
    X: B.Numeric,
    X2: Optional[B.Numeric] = None,
    q: int = 2,
    normalized_laplacian: bool = True,
):
    r"""
    Analytic formula for the heat kernel on the Hamming graph, see
    Equation (3) in :cite:t:`doumont2025`.

    The kernel is normalized so that $k(x, x) = 1$ and depends only on the
    Hamming distance $h(x, x')$, i.e. the number of coordinates in which
    $x$ and $x'$ differ:

    .. math::
        k(x, x') = \left(
            \frac{1 - e^{-\beta q}}{1 + (q - 1) e^{-\beta q}}
        \right)^{h(x, x')},
        \qquad \beta = \kappa^2 / 2,

    where $\kappa$ is the (possibly rescaled, see `normalized_laplacian`)
    length scale.

    :param lengthscale:
        The length scale of the kernel, an array of shape [1].
    :param X:
        A batch of inputs, an array of shape [N, d]. Entries are treated as
        category labels: only equality between entries matters, and values
        are not checked against the alphabet size `q`.
    :param X2:
        A batch of inputs, an array of shape [N2, d]. If None, defaults to X.
    :param q:
        The alphabet size (number of categories). Defaults to 2 (hypercube).
    :param normalized_laplacian:
        Whether to use normalized Laplacian scaling. If True, the length
        scale is divided by $\sqrt{d (q - 1)}$, where $d (q - 1)$ is the
        vertex degree of the Hamming graph.

    :return:
        The kernel matrix, an array of shape [N, N2].
    """
    if X2 is None:
        X2 = X

    _check_1_vector(lengthscale, "lengthscale")
    _check_matrix(X, "X")
    _check_matrix(X2, "X2")

    d = X.shape[-1]

    if normalized_laplacian:
        # Dividing the Laplacian by the vertex degree d * (q - 1) is
        # equivalent to shrinking the length scale by sqrt(d * (q - 1)).
        lengthscale = lengthscale / sqrt(d * (q - 1))

    beta = lengthscale**2 / 2

    # Disagreement indicator: 1 when coordinates differ, 0 when they match.
    # Shape: [N, N2, d]
    disagreement = B.cast(float_like(X), X[:, None, :] != X2[None, :, :])
    # Hamming distance between every pair of inputs. Shape: [N, N2]
    hamming_distance = B.sum(disagreement, axis=-1)

    exp_neg_beta_q = B.exp(-beta * q)
    # Per-coordinate multiplicative factor contributed by one disagreeing
    # coordinate; lies in [0, 1) for beta >= 0.
    factor_disagree = (1 - exp_neg_beta_q) / (1 + (q - 1) * exp_neg_beta_q)

    # Raise the factor to the Hamming distance directly instead of computing
    # exp(sum(log(factor) * disagreement)). For a zero or tiny length scale,
    # `factor_disagree` becomes exactly 0. Then log(0) * 0 = -inf * 0 = NaN,
    # which would poison the whole matrix, including the diagonal. With the
    # power form, 0 ** 0 = 1 and 0 ** h = 0 for h > 0, which is the correct
    # small-length-scale limit.
    return factor_disagree**hamming_distance  # Shape: [N, N2]