Coverage for geometric_kernels/kernels/matern_kernel_hamming_graph.py: 71%

35 statements  

coverage.py v7.13.0, created at 2025-12-15 13:41 +0000

1r""" 

2This module provides the :class:`MaternKernelHammingGraph` kernel, a subclass of 

3:class:`MaternKarhunenLoeveKernel` for :class:`HammingGraph` and 

4:class:`HypercubeGraph` spaces implementing the closed-form formula for the 

5heat kernel when $\nu = \infty$. 

6""" 

import lab as B
import numpy as np
from beartype.typing import Dict, Optional, Union

from geometric_kernels.kernels.karhunen_loeve import MaternKarhunenLoeveKernel
from geometric_kernels.spaces.eigenfunctions import Eigenfunctions
from geometric_kernels.spaces.hamming_graph import HammingGraph
from geometric_kernels.spaces.hypercube_graph import HypercubeGraph
from geometric_kernels.utils.kernel_formulas.hamming_graph import (
    hamming_graph_heat_kernel,
)
from geometric_kernels.utils.utils import _check_1_vector, _check_field_in_params


class MaternKernelHammingGraph(MaternKarhunenLoeveKernel):
    r"""
    For $\nu = \infty$, there exists a closed-form formula for the heat kernel
    on Hamming graphs :class:`HammingGraph` (including the binary hypercube case
    :class:`HypercubeGraph`). This class extends :class:`MaternKarhunenLoeveKernel`
    to implement this formula in the case of $\nu = \infty$ for efficiency.

    .. note::
        We only use the closed-form expression if `num_levels` is `d + 1`, which
        corresponds to exact computation. When truncated to fewer levels, we
        must use the parent class implementation to ensure consistency with
        feature map approximations.
    """

    def __init__(
        self,
        space: Union[HammingGraph, HypercubeGraph],
        num_levels: int,
        normalize: bool = True,
        eigenvalues_laplacian: Optional[B.Numeric] = None,
        eigenfunctions: Optional[Eigenfunctions] = None,
    ):
        if not isinstance(space, (HammingGraph, HypercubeGraph)):
            raise ValueError(
                f"`space` must be an instance of HammingGraph or HypercubeGraph, but got {type(space)}"
            )

        super().__init__(
            space,
            num_levels,
            normalize,
            eigenvalues_laplacian,
            eigenfunctions,
        )

    def K(
        self,
        params: Dict[str, B.Numeric],
        X: B.Numeric,
        X2: Optional[B.Numeric] = None,
        **kwargs,
    ) -> B.Numeric:
        _check_field_in_params(params, "lengthscale")
        _check_1_vector(params["lengthscale"], 'params["lengthscale"]')

        _check_field_in_params(params, "nu")
        _check_1_vector(params["nu"], 'params["nu"]')

        if B.all(params["nu"] == np.inf):
            d = X.shape[-1]

            # Only use fast path when we have all levels (exact computation)
            if self.num_levels == d + 1:
                # Get q from space (HammingGraph has n_cat, HypercubeGraph is binary q=2)
                q = getattr(self.space, "n_cat", 2)

                return hamming_graph_heat_kernel(params["lengthscale"], X, X2, q=q)

        return super().K(params, X, X2, **kwargs)

    def K_diag(self, params: Dict[str, B.Numeric], X: B.Numeric, **kwargs) -> B.Numeric:
        _check_field_in_params(params, "lengthscale")
        _check_1_vector(params["lengthscale"], 'params["lengthscale"]')

        _check_field_in_params(params, "nu")
        _check_1_vector(params["nu"], 'params["nu"]')

        if B.all(params["nu"] == np.inf):
            d = X.shape[-1]

            # Only use fast path when we have all levels (exact computation)
            if self.num_levels == d + 1:
                return B.ones(B.dtype(params["nu"]), X.shape[0])

        return super().K_diag(params, X, **kwargs)
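
For context, here is a minimal usage sketch showing how the closed-form fast path is reached. It is not part of the module above: the `HypercubeGraph(d)` constructor signature, the binary `[N, d]` input convention, and the sampled data are assumptions made for illustration; only `MaternKernelHammingGraph`, its `num_levels == d + 1` condition, and the `params` keys it reads come from the listing.

# Illustrative sketch only -- the HypercubeGraph constructor signature and
# input format are assumptions, not taken from the listing above.
import numpy as np

from geometric_kernels.kernels.matern_kernel_hamming_graph import MaternKernelHammingGraph
from geometric_kernels.spaces.hypercube_graph import HypercubeGraph

d = 6
space = HypercubeGraph(d)  # assumed: constructed from the dimension d

# num_levels = d + 1 keeps all levels, so K() takes the closed-form path.
kernel = MaternKernelHammingGraph(space, num_levels=d + 1)

# K() and K_diag() only read "lengthscale" and "nu" from the params dict.
params = {
    "nu": np.array([np.inf]),        # heat kernel regime
    "lengthscale": np.array([0.5]),
}

# Assumed input convention: binary vectors of shape [N, d].
X = np.random.default_rng(0).integers(0, 2, size=(10, d)).astype(bool)

K = kernel.K(params, X)            # delegates to hamming_graph_heat_kernel
K_diag = kernel.K_diag(params, X)  # a vector of ones on this fast path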