Coverage for geometric_kernels / utils / kernel_formulas / hamming_graph.py: 95%

20 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-15 13:41 +0000

1""" 

2Implements the closed form expression for the heat kernel on the Hamming graph. 

3 

4The implementation is provided mainly for testing purposes. 

5""" 

6 

7from math import sqrt 

8 

9import lab as B 

10from beartype.typing import Optional 

11 

12from geometric_kernels.lab_extras import float_like 

13from geometric_kernels.utils.utils import _check_1_vector, _check_matrix 

14 

15 

def hamming_graph_heat_kernel(
    lengthscale: B.Numeric,
    X: B.Numeric,
    X2: Optional[B.Numeric] = None,
    q: int = 2,
    normalized_laplacian: bool = True,
):
    """
    Analytic formula for the heat kernel on the Hamming graph, see
    Equation (3) in :cite:t:`doumont2025`.

    The heat kernel factorizes over coordinates, so the returned
    (unit-diagonal) kernel depends on a pair of inputs only through their
    Hamming distance h(x, x'), the number of coordinates in which they differ:

        k(x, x') = r ** h(x, x'),
        r = (1 - exp(-beta * q)) / (1 + (q - 1) * exp(-beta * q)),
        beta = lengthscale**2 / 2.

    :param lengthscale:
        The length scale of the kernel, an array of shape [1].
    :param X:
        A batch of inputs, an array of shape [N, d].
    :param X2:
        A batch of inputs, an array of shape [N2, d]. If None, defaults to X.
    :param q:
        The alphabet size (number of categories). Defaults to 2 (hypercube).
    :param normalized_laplacian:
        Whether to use normalized Laplacian scaling. When True, the
        lengthscale is divided by sqrt(d * (q - 1)) before use.

    :return:
        The kernel matrix, an array of shape [N, N2].

    :raises ValueError:
        If X and X2 have a different number of columns.
    """
    if X2 is None:
        X2 = X

    _check_1_vector(lengthscale, "lengthscale")
    _check_matrix(X, "X")
    _check_matrix(X2, "X2")

    d = X.shape[-1]
    # Without this check, an X2 with a single column would silently broadcast
    # against every column of X in the elementwise comparison below.
    if X2.shape[-1] != d:
        raise ValueError(
            f"X and X2 must have the same number of columns, got {d} and "
            f"{X2.shape[-1]}."
        )

    if normalized_laplacian:
        # Every vertex of the Hamming graph has degree d * (q - 1). Dividing
        # the Laplacian by that degree is equivalent to rescaling time, i.e.
        # dividing the lengthscale by sqrt(d * (q - 1)).
        lengthscale = lengthscale / sqrt(d * (q - 1))

    beta = lengthscale**2 / 2

    # Disagreement indicator: 1 when coordinates differ, 0 when they match.
    # Shape: [N, N2, d]
    disagreement = B.cast(float_like(X), X[:, None, :] != X2[None, :, :])
    # Hamming distance between every pair of rows. Shape: [N, N2]
    hamming_distance = B.sum(disagreement, axis=-1)

    # Per-coordinate kernel ratio k(a, b) / k(a, a) for a != b on the
    # complete graph with q vertices.
    exp_neg_beta_q = B.exp(-beta * q)
    factor_disagree = (1 - exp_neg_beta_q) / (1 + (q - 1) * exp_neg_beta_q)

    # Raise to the Hamming distance directly instead of exp(h * log(factor)).
    # If factor_disagree rounds to 0 (tiny lengthscale), the log form gives
    # 0 * -inf = NaN for matching pairs. The power form gives 0**0 == 1.
    return factor_disagree**hamming_distance  # Shape: [N, N2]