Source code for geometric_kernels.kernels.matern_kernel_hamming_graph

r"""
This module provides the :class:`MaternKernelHammingGraph` kernel, a subclass of
:class:`MaternKarhunenLoeveKernel` for :class:`HammingGraph` and
:class:`HypercubeGraph` spaces implementing the closed-form formula for the
heat kernel when $\nu = \infty$.
"""

import lab as B
import numpy as np
from beartype.typing import Dict, Optional, Union

from geometric_kernels.kernels.karhunen_loeve import MaternKarhunenLoeveKernel
from geometric_kernels.spaces.eigenfunctions import Eigenfunctions
from geometric_kernels.spaces.hamming_graph import HammingGraph
from geometric_kernels.spaces.hypercube_graph import HypercubeGraph
from geometric_kernels.utils.kernel_formulas.hamming_graph import (
    hamming_graph_heat_kernel,
)
from geometric_kernels.utils.utils import _check_1_vector, _check_field_in_params


[docs] class MaternKernelHammingGraph(MaternKarhunenLoeveKernel): r""" For $\nu = \infty$, there exists a closed-form formula for the heat kernel on hamming graphs :class:`HammingGraph` (including the binary hypercube case :class:`HypercubeGraph`). This class extends :class:`MaternKarhunenLoeveKernel` to implement this formula in the case of $\nu = \infty$ for efficiency. .. note:: We only use the closed form expression if `num_levels` is `d + 1` which corresponds to exact computation. When truncated to fewer levels, we must use the parent class implementation to ensure consistency with feature map approximations. """ def __init__( self, space: Union[HammingGraph, HypercubeGraph], num_levels: int, normalize: bool = True, eigenvalues_laplacian: Optional[B.Numeric] = None, eigenfunctions: Optional[Eigenfunctions] = None, ): if not isinstance(space, (HammingGraph, HypercubeGraph)): raise ValueError( f"`space` must be an instance of HammingGraph or HypercubeGraph, but got {type(space)}" ) super().__init__( space, num_levels, normalize, eigenvalues_laplacian, eigenfunctions, )
[docs] def K( self, params: Dict[str, B.Numeric], X: B.Numeric, X2: Optional[B.Numeric] = None, **kwargs, ) -> B.Numeric: _check_field_in_params(params, "lengthscale") _check_1_vector(params["lengthscale"], 'params["lengthscale"]') _check_field_in_params(params, "nu") _check_1_vector(params["nu"], 'params["nu"]') if B.all(params["nu"] == np.inf): d = X.shape[-1] # Only use fast path when we have all levels (exact computation) if self.num_levels == d + 1: # Get q from space (HammingGraph has n_cat, HypercubeGraph is binary q=2) q = getattr(self.space, "n_cat", 2) return hamming_graph_heat_kernel(params["lengthscale"], X, X2, q=q) return super().K(params, X, X2, **kwargs)
[docs] def K_diag(self, params: Dict[str, B.Numeric], X: B.Numeric, **kwargs) -> B.Numeric: _check_field_in_params(params, "lengthscale") _check_1_vector(params["lengthscale"], 'params["lengthscale"]') _check_field_in_params(params, "nu") _check_1_vector(params["nu"], 'params["nu"]') if B.all(params["nu"] == np.inf): d = X.shape[-1] # Only use fast path when we have all levels (exact computation) if self.num_levels == d + 1: return B.ones(B.dtype(params["nu"]), X.shape[0]) return super().K_diag(params, X, **kwargs)