Coverage for geometric_kernels / kernels / matern_kernel_hamming_graph.py: 71%
35 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-15 13:41 +0000
r"""
This module provides the :class:`MaternKernelHammingGraph` kernel, a subclass of
:class:`MaternKarhunenLoeveKernel` for :class:`HammingGraph` and
:class:`HypercubeGraph` spaces implementing the closed-form formula for the
heat kernel when $\nu = \infty$.

The closed form is only used when the kernel keeps all levels (i.e. exact
computation); otherwise the generic Karhunen-Loève implementation inherited
from the parent class is used.
"""
8import lab as B
9import numpy as np
10from beartype.typing import Dict, Optional, Union
12from geometric_kernels.kernels.karhunen_loeve import MaternKarhunenLoeveKernel
13from geometric_kernels.spaces.eigenfunctions import Eigenfunctions
14from geometric_kernels.spaces.hamming_graph import HammingGraph
15from geometric_kernels.spaces.hypercube_graph import HypercubeGraph
16from geometric_kernels.utils.kernel_formulas.hamming_graph import (
17 hamming_graph_heat_kernel,
18)
19from geometric_kernels.utils.utils import _check_1_vector, _check_field_in_params
class MaternKernelHammingGraph(MaternKarhunenLoeveKernel):
    r"""
    For $\nu = \infty$, there exists a closed-form formula for the heat kernel
    on Hamming graphs :class:`HammingGraph` (including the binary hypercube case
    :class:`HypercubeGraph`). This class extends :class:`MaternKarhunenLoeveKernel`
    to implement this formula in the case of $\nu = \infty$ for efficiency.

    .. note::
        We only use the closed-form expression if `num_levels` is `d + 1`,
        which corresponds to exact computation, and if `normalize` is `True`,
        because the closed form has unit diagonal. When truncated to fewer
        levels, or when the kernel is not normalized, we must use the parent
        class implementation to ensure consistency with it and with feature
        map approximations.
    """

    def __init__(
        self,
        space: Union[HammingGraph, HypercubeGraph],
        num_levels: int,
        normalize: bool = True,
        eigenvalues_laplacian: Optional[B.Numeric] = None,
        eigenfunctions: Optional[Eigenfunctions] = None,
    ):
        if not isinstance(space, (HammingGraph, HypercubeGraph)):
            raise ValueError(
                f"`space` must be an instance of HammingGraph or HypercubeGraph, but got {type(space)}"
            )

        super().__init__(
            space,
            num_levels,
            normalize,
            eigenvalues_laplacian,
            eigenfunctions,
        )

    def _closed_form_applies(self, params: Dict[str, B.Numeric], X: B.Numeric) -> bool:
        """
        Decide whether the closed-form heat kernel can replace the parent computation.

        The closed form is used only when all of the following hold:
        - every entry of ``params["nu"]`` is infinite (heat kernel);
        - the kernel is normalized, since the closed form has unit diagonal
          while the unnormalized Karhunen-Loève sum generally does not;
        - all ``d + 1`` levels are kept, with ``d = X.shape[-1]``, so the
          parent computation is exact rather than truncated.

        :param params: Parameter dict; "nu" is expected to be already validated.
        :param X: Input points; only the size of the last axis is read.
        :return: True if the closed-form fast path should be taken.
        """
        if not self.normalize:
            return False
        if not B.all(params["nu"] == np.inf):
            return False
        d = X.shape[-1]
        return self.num_levels == d + 1

    def K(
        self,
        params: Dict[str, B.Numeric],
        X: B.Numeric,
        X2: Optional[B.Numeric] = None,
        **kwargs,
    ) -> B.Numeric:
        """
        Compute the kernel matrix between `X` and `X2` (or `X` and itself).

        Uses the closed-form Hamming-graph heat kernel when it is exact and
        consistent with the parent class, otherwise defers to
        :meth:`MaternKarhunenLoeveKernel.K`.

        :param params: Must contain 1-vector entries "lengthscale" and "nu".
        :param X: First batch of inputs.
        :param X2: Optional second batch of inputs; ``None`` means ``X``.
        :param kwargs: Forwarded to the parent implementation only.
        :return: The kernel matrix.
        """
        _check_field_in_params(params, "lengthscale")
        _check_1_vector(params["lengthscale"], 'params["lengthscale"]')

        _check_field_in_params(params, "nu")
        _check_1_vector(params["nu"], 'params["nu"]')

        if self._closed_form_applies(params, X):
            # HammingGraph exposes its number of categories as `n_cat`;
            # HypercubeGraph has no such attribute and is the binary case q = 2.
            q = getattr(self.space, "n_cat", 2)
            return hamming_graph_heat_kernel(params["lengthscale"], X, X2, q=q)

        return super().K(params, X, X2, **kwargs)

    def K_diag(self, params: Dict[str, B.Numeric], X: B.Numeric, **kwargs) -> B.Numeric:
        """
        Compute the diagonal of the kernel matrix of `X` with itself.

        In the closed-form regime the diagonal is identically one; otherwise
        defers to :meth:`MaternKarhunenLoeveKernel.K_diag`.

        :param params: Must contain 1-vector entries "lengthscale" and "nu".
        :param X: Inputs; ``X.shape[0]`` gives the length of the result.
        :param kwargs: Forwarded to the parent implementation only.
        :return: The kernel diagonal.
        """
        _check_field_in_params(params, "lengthscale")
        _check_1_vector(params["lengthscale"], 'params["lengthscale"]')

        _check_field_in_params(params, "nu")
        _check_1_vector(params["nu"], 'params["nu"]')

        if self._closed_form_applies(params, X):
            # The closed-form (normalized) heat kernel has unit diagonal.
            return B.ones(B.dtype(params["nu"]), X.shape[0])

        return super().K_diag(params, X, **kwargs)