Custom Spaces/Kernels

GeometricKernels ships with general-purpose kernel classes that cover most use cases, so you will rarely need to write a new kernel from scratch. In practice, you typically implement a new space and then reuse the existing kernels.

This notebook therefore focuses on implementing custom spaces. At the end, we briefly show how to define a custom kernel for advanced cases (e.g. a non-Matérn kernel or a Matérn kernel with space‑specific optimisation). For an overview of the built-in kernels, see the documentation: https://geometric-kernels.github.io/GeometricKernels/autoapi/geometric_kernels/kernels/index.html.

When to implement what:

  • Custom space: you have a new geometry and can provide the appropriate machinery (e.g. evaluating eigenpairs of the Laplacian) for it.

  • Custom kernel: you need a non-Matérn kernel or a more efficient implementation tailored to your space.

Spaces

We recommend implementing a new GeometricKernels space by subclassing one of the existing abstract space classes:

  • DiscreteSpectrumSpace

  • NoncompactSymmetricSpace

  • HodgeDiscreteSpectrumSpace

Provide concrete implementations for all abstract methods required by the chosen base class. Doing so will make one of the standard kernel classes directly applicable to your new space:

  • MaternKarhunenLoeveKernel (for DiscreteSpectrumSpace)

  • MaternFeatureMapKernel (for NoncompactSymmetricSpace)

  • MaternHodgeCompositionalKernel (for HodgeDiscreteSpectrumSpace)
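
For orientation, here is a rough skeleton of the DiscreteSpectrumSpace route taken in this notebook. It is only a sketch (stub bodies, not a complete implementation); the methods mirror the ones implemented for the CompleteGraph example later on.

from geometric_kernels.spaces.base import DiscreteSpectrumSpace

class MyNewSpace(DiscreteSpectrumSpace):
    @property
    def dimension(self) -> int: ...               # intrinsic dimension of the space
    def get_eigenfunctions(self, num): ...        # first `num` levels of Laplacian eigenfunctions
    def get_eigenvalues(self, num): ...           # one eigenvalue per level
    def get_repeated_eigenvalues(self, num): ...  # eigenvalues repeated according to level multiplicities
    def random(self, key, number): ...            # backend-aware random points on the space
    @property
    def element_shape(self): ...                  # shape of a single point, e.g. [1] for node indices
    @property
    def element_dtype(self): ...                  # dtype of points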

If you do not want to use the existing kernel machinery, you can instead subclass the most general Space class and implement a new kernel from scratch by subclassing BaseGeometricKernel.

Notes:

  • Product spaces: If your target space is a Cartesian product of two or more spaces, you typically do not need to implement a new space class. Instead, consider ProductGeometricKernel/ProductDiscreteSpectrumSpace. See the documentation on product spaces: https://geometric-kernels.github.io/GeometricKernels/theory/product_spaces.html

  • Specialising ``MaternGeometricKernel``: If you want to change how MaternGeometricKernel specialises to your space (e.g. the default number of levels used in the approximation), consult that class’s documentation and pass the relevant arguments explicitly when constructing the kernel, or modify the kernels.matern_kernel module.

  • Contributing a new space: If you plan to submit a PR adding a new space—thank you! Please integrate it into the test suite and add space-specific tests. Run make lint and make test before opening a PR. GitHub CI will run these as well (on multiple platforms) and block merging until they pass, but running them locally first is faster.

As a toy example, we implement a space representing a complete graph with \(n\) nodes. In practice, this space could be represented using the existing general Graph space by providing it the adjacency matrix of the complete graph. However, here we build this space from scratch as if the general Graph space did not exist, to demonstrate the mechanics of implementing a new space.

Note: we do not strictly follow the project’s linting/style guidelines in this notebook for a more focused exposition.

Imports and writing backend‑agnostic code

GeometricKernels works with multiple backends: NumPy, PyTorch, TensorFlow, and JAX. To write code once and have it work across these backends, we use the lab library. lab provides a thin dispatch layer: when you give it, say, PyTorch tensors as inputs, it returns PyTorch tensors; when you give it NumPy arrays, it returns NumPy arrays, etc.

In the next cell we import lab.

[1]:
import lab as B
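
For example (a small illustration that is not part of the original notebook), the same call accepts inputs from any backend and returns outputs of the matching type:

import numpy as np

print(B.exp(np.array([0.0, 1.0])))  # a NumPy array; a torch.Tensor input would return a torch.Tensor, etc.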

Sometimes, lab on its own does not provide all the functions you need. In that case, check the lab_extras subpackage in GeometricKernels, which contains additional helpers that follow the same backend-dispatch pattern. If you still cannot find a suitable function, consider adding it to lab_extras.

Below we import some utilities we will use:

  • dtype_integer: Returns a backend-appropriate integer dtype.

[2]:
from geometric_kernels.lab_extras import (
    dtype_integer,
)
INFO (geometric_kernels): Numpy backend is enabled. To enable other backends, don't forget to `import geometric_kernels.*backend name*`.
INFO (geometric_kernels): We may be suppressing some logging of external libraries. To override the logging policy, call `logging.basicConfig`.

Next, we import the abstract base class DiscreteSpectrumSpace, because the Laplacian on a finite graph (the graph Laplacian) has a discrete spectrum. See more about different kinds of spaces in the theory docs: https://geometric-kernels.github.io/GeometricKernels/theory/index.html.

[3]:
from geometric_kernels.spaces.base import DiscreteSpectrumSpace

We also import Eigenfunctions because one of the abstract methods of DiscreteSpectrumSpace must return an Eigenfunctions object.

In many cases, you will subclass Eigenfunctions to provide the machinery for working with the Laplacian’s eigenfunctions. In our case, we can avoid that by using the already implemented subclass EigenfunctionsFromEigenvectors, which wraps a matrix of eigenvectors.

[4]:
from geometric_kernels.spaces.eigenfunctions import (
    Eigenfunctions,
    EigenfunctionsFromEigenvectors,
)

Finally, a few standard imports. We’ll use SciPy to compute the full eigendecomposition for this small example.

[5]:
import numpy as np
import scipy.linalg

Defining the new Space Class

Because it is awkward to split a single class definition across multiple Jupyter cells, the complete class definition is given in the next cell.

Note on contributing: If you intend to contribute your new space to GeometricKernels via a PR, please add proper docstrings in the project’s style (these get compiled into the documentation). You can use the implementation of the Graph space as a reference. To keep this notebook concise, we include explanatory comments rather than full docstrings.

[6]:
class CompleteGraph(DiscreteSpectrumSpace):

    def __init__(self, num_vertices: int):
        # Store the number of vertices.
        self.num_vertices = num_vertices

        # Build the adjacency matrix A for the complete graph K_n:
        # A_ij = 1 for i != j, and 0 on the diagonal.
        adjacency_matrix = np.ones((num_vertices, num_vertices), dtype=np.float64) - np.eye(num_vertices, dtype=np.float64)

        # Degree matrix D for K_n is (n-1) on the diagonal.
        degree_matrix = np.diag(num_vertices * [num_vertices - 1])

        # Normalized graph Laplacian: L = D^(-1/2) (D - A) D^(-1/2).
        # For K_n, D = (n - 1) I, so this reduces to (D - A) / (n - 1).
        self._laplacian = (degree_matrix - adjacency_matrix)/(num_vertices - 1)

        # Note: As a rule of thumb, store precomputed quantities in NumPy and cast to the
        # active backend only when needed (usually, kernels handle this casting for you).

        # For this toy example, we compute a full eigendecomposition here. In production code,
        # try to keep __init__ lightweight: compute lazily if possible.
        evals, evecs = scipy.linalg.eigh(self._laplacian)
        self._evals = evals

        # Rescale the eigenvectors so that they are orthonormal with respect to the uniform
        # (normalized counting) measure on the nodes: scipy returns them orthonormal w.r.t. the
        # standard dot product, so we multiply by sqrt(n). This is a convenient normalisation
        # for use in Mercer expansions.
        self._evecs = evecs * num_vertices ** 0.5  # change normalization

    def __str__(self):
        # Provide a concise, informative string representation.
        return f"CompleteGraph({self.num_vertices})"

    @property
    def dimension(self) -> int:
        # Return an intrinsic dimension for the space. For graphs, a pragmatic choice is 0.
        # The dimension influences the parameterisation of the Matérn smoothness parameter nu.
        return 0

    def get_eigenfunctions(self, num: int) -> Eigenfunctions:
        # Return the `Eigenfunctions` object representing the first `num` levels of eigenfunctions.
        # The notion of "levels" lets many spaces trade off computation for approximation quality.
        # For this space, any admissible `num` has the same one-off cost to precompute, though
        # subsequent kernel evaluations can be faster for smaller `num`.
        #
        # See more about levels in the docs for the `Eigenfunctions` class.
        return EigenfunctionsFromEigenvectors(self._evecs[:, :num])

    def get_eigenvalues(self, num: int) -> B.Numeric:
        # Return the first `num` eigenvalues.
        return self._evals[:num]

    def get_repeated_eigenvalues(self, num: int) -> B.Numeric:
        # Some DiscreteSpectrumSpace instances group eigenfunctions with multiplicities into "levels".
        # In that case, this should return each level’s eigenvalue repeated according to its multiplicity.
        # Here each level corresponds to a single eigenfunction, so we can just return the first `num`
        # values, exactly as in `get_eigenvalues` above.
        #
        # See more about levels in the docs for the `Eigenfunctions` class.
        return self.get_eigenvalues(num)

    def random(self, key, number):
        # Generate `number` random vertices (indices) from {0, ..., n-1} in a backend-aware way.
        # This method consumes a backend-specific random key (e.g. a torch.Generator or JAX PRNGKey)
        # and returns samples and the updated key in the same backend.
        #
        # Note: this is the first method that actually uses `lab`; the other methods delegate
        # backend handling to the kernels.
        key, random_vertices = B.randint(
            key, dtype_integer(key), number, 1, lower=0, upper=self.num_vertices
        )
        return key, random_vertices

    @property
    def element_shape(self):
        # Each point of this space is represented by a single integer node index, thus shape is [1].
        # For R^d you might return [d]; for matrix groups, [d, d], etc. This is used downstream
        # (e.g. for constructing product spaces/kernels).
        return [1]

    @property
    def element_dtype(self):
        # The dtype of elements in this space is integer. Also used downstream.
        return B.Int
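
As a quick sanity check (not part of the original notebook): the normalized Laplacian of the complete graph \(K_n\) has eigenvalue 0 once and \(n/(n-1)\) with multiplicity \(n-1\), so for \(n=5\) we expect approximately [0, 1.25, 1.25, 1.25, 1.25]:

print(CompleteGraph(5).get_eigenvalues(5))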

Using the new Space

We now demonstrate the space under the torch backend. For this, we enable the backend and import a convenience wrapper kernel MaternGeometricKernel.

[7]:
import torch

import geometric_kernels
import geometric_kernels.torch

from geometric_kernels.kernels import MaternGeometricKernel
INFO (geometric_kernels): Torch backend enabled.

Because we subclassed a core abstract space (DiscreteSpectrumSpace), we can instantiate MaternGeometricKernel on our new space without any additional work.

[8]:
mygraph = CompleteGraph(5)
kernel = MaternGeometricKernel(mygraph)

Initialise the kernel’s parameters. By default, these come out as NumPy arrays.

[9]:
params = kernel.init_params()
print("params:", params)
params: {'nu': array([inf]), 'lengthscale': array([1.])}

To have the kernel operate in PyTorch, provide parameters as PyTorch tensors; the kernel will then dispatch calculations to the torch backend.

[10]:
params["lengthscale"] = torch.tensor([2.0])
params["nu"]  = torch.tensor([3/2])
print("params:", params)
params: {'nu': tensor([1.5000]), 'lengthscale': tensor([2.])}

Generate random inputs for the kernel, also in torch. The random method consumes and returns the backend’s PRNG key (here torch.Generator).

[11]:
key = torch.Generator()
key.manual_seed(1234)

key, xs = mygraph.random(key, 6)

print("xs:", xs)
xs: tensor([[0],
        [1],
        [1],
        [0],
        [1],
        [4]], dtype=torch.int32)

Now evaluate the kernel.

[12]:
print("K:", kernel.K(params, xs, xs))
# Note: By default the kernel is normalised such that the variances sum to 1 across nodes.
# If you prefer marginal variances around 1 for each node, multiply by `num_vertices`.
K: tensor([[0.2000, 0.0803, 0.0803, 0.2000, 0.0803, 0.0803],
        [0.0803, 0.2000, 0.2000, 0.0803, 0.2000, 0.0803],
        [0.0803, 0.2000, 0.2000, 0.0803, 0.2000, 0.0803],
        [0.2000, 0.0803, 0.0803, 0.2000, 0.0803, 0.0803],
        [0.0803, 0.2000, 0.2000, 0.0803, 0.2000, 0.0803],
        [0.0803, 0.0803, 0.0803, 0.0803, 0.0803, 0.2000]])
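
As a small illustration of the note in the cell above (not part of the original notebook), rescaling by the number of vertices gives unit marginal variances:

print((mygraph.num_vertices * kernel.K(params, xs, xs)).diagonal())  # all entries 1.0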

Kernels

The existing kernel classes cover most use cases; see the kernel docs before rolling your own.

Two common scenarios for a custom kernel are:

  1. You want a different PSD kernel that is still representable via the Laplacian’s spectrum (i.e. you only change the spectral filter from the Matérn/heat form).

  2. You want to reimplement the Matérn kernel for your space more efficiently by exploiting finer structure (e.g. you know a closed-form expression for the kernel on your space).

Custom Kernel for Use Case 1

Below is an example of scenario (1). You can inherit from a default spectral kernel (e.g. MaternKarhunenLoeveKernel for graphs) and override its static spectrum method, which maps eigenvalues to the Mercer coefficients of the kernel. Here we define an inverse cosine kernel on graphs (Graph space).

[13]:
from geometric_kernels.kernels import MaternKarhunenLoeveKernel
from geometric_kernels.spaces import Graph
[14]:
class InverseCosineKernel(MaternKarhunenLoeveKernel):
    def __init__(self, space: Graph, num_levels: int, normalize: bool = True):
        # For graphs, `num_levels` can be at most the number of vertices (the graph Laplacian has
        # exactly that many eigenpairs); pass a smaller number to truncate the expansion.
        super().__init__(space, num_levels, normalize)

    @staticmethod
    def spectrum(
        s: B.Numeric, nu: B.Numeric, lengthscale: B.Numeric, dimension: int
    ) -> B.Numeric:
        # Map eigenvalues s to Mercer coefficients. Make sure these stay non-negative, so that the
        # kernel is PSD. Here this holds because the eigenvalues of a **normalized** graph Laplacian
        # lie in [0, 2], hence cos(pi * s / 4) >= 0.
        # Note: we ignore `nu` and `dimension`, and use `lengthscale` only for its dtype.
        return B.cos(B.cast(B.dtype(lengthscale), s) * np.pi / 4)

Using the new Kernel

[15]:
inv_cos_kernel = InverseCosineKernel(mygraph, num_levels=mygraph.num_vertices)
print("K:", inv_cos_kernel.K(params, xs, xs))
K: tensor([[0.2000, 0.0276, 0.0276, 0.2000, 0.0276, 0.0276],
        [0.0276, 0.2000, 0.2000, 0.0276, 0.2000, 0.0276],
        [0.0276, 0.2000, 0.2000, 0.0276, 0.2000, 0.0276],
        [0.2000, 0.0276, 0.0276, 0.2000, 0.0276, 0.0276],
        [0.0276, 0.2000, 0.2000, 0.0276, 0.2000, 0.0276],
        [0.0276, 0.0276, 0.0276, 0.0276, 0.0276, 0.2000]])
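
As a quick check (not in the original notebook) that the spectrum above indeed yields a PSD kernel, we can look at the eigenvalues of the resulting Gram matrix:

print(torch.linalg.eigvalsh(inv_cos_kernel.K(params, xs, xs)))  # non-negative up to floating-point error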

Custom Kernel for Use Case 2

For scenario (2), you can inherit from BaseGeometricKernel or from a default kernel (e.g. MaternKarhunenLoeveKernel) and implement K and K_diag directly for efficiency.

Below we show how one could, for example, leverage the closed-form expression for the heat kernel (Matérn with \(\nu=\infty\)) on the complete graph.

[16]:
from beartype.typing import Dict, Optional

class MaternKernelCompleteGraph(MaternKarhunenLoeveKernel):
    def __init__(
        self,
        space: CompleteGraph,
        num_levels: int,
        normalize: bool = True,
    ):
        # Call the parent initialiser; override K/K_diag behaviour here if needed.
        super().__init__(space, num_levels, normalize)

    def K(
        self,
        params: Dict[str, Dict[str, B.NPNumeric]],
        X: B.Numeric,
        X2: Optional[B.Numeric] = None,
        **kwargs,
    ) -> B.Numeric:
        # Use the exact heat kernel for nu = inf; otherwise defer to the parent. Assumes a normalized Laplacian.
        #
        # Derivation outline for the exact heat kernel (Matérn with nu=inf) on the complete graph.
        # The form is standard (see, e.g., Kondor and Lafferty, 2002), with adjustments for
        # GeometricKernels’ parameterization:
        #
        # Normalized Laplacian for K_n: L = I - D^{-1/2} A D^{-1/2} = (n/(n-1)) I - (1/(n-1)) J, where J is the all-ones matrix.
        # Spectrum: 0 (mult. 1, eigenvector (1, 1, ...)) and α := n/(n-1) (mult. n-1).
        # Heat kernel with length scale κ is exp(- (κ^2 / 2) L).
        # Let t := κ^2 / 2. Then exp(-t L) = (1/n) [J + e^{-α t} (n I - J)].
        # Up to the overall factor n (absorbed by the normalisation below, and consistent with the
        # sqrt(n)-rescaled eigenvectors of our space), the entries are:
        # k(x, x) = 1 + (n - 1) e^{-α t},  k(x, x') = 1 - e^{-α t} for x != x'.
        nu = params["nu"]
        if nu < np.inf:  # Note: this may cause problems with JAX autodiff (see how `MaternKarhunenLoeveKernel` handles conditions)
            return super().K(params, X, X2, **kwargs)

        if X2 is None:
            X2 = X

        kappa = params["lengthscale"]
        t = kappa**2 / 2
        n = self.space.num_vertices
        alpha = n/(n-1)

        k_same = 1 + (n - 1)*B.exp(-alpha*t)
        k_diff = 1 - B.exp(-alpha*t)
        eq = B.cast(B.dtype(kappa), B.eq(X, X2[None, :, 0]))  # indicator [x == x']

        K = eq*k_same + (1-eq)*k_diff

        if self.normalize:
            # Scale so that trace(K) == 1 (sum of marginal variances is 1).
            K = K / k_same * (1/n)

        return K

    def K_diag(
        self,
        params: Dict[str, Dict[str, B.NPNumeric]],
        X: B.Numeric,
        **kwargs,
    ) -> B.Numeric:
        # For brevity, we just extract the diagonal of the full kernel matrix. Normally, you would
        # implement K_diag directly to avoid forming the N x N matrix (see the sketch after this class).
        return B.diag(self.K(params, X, **kwargs))
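
If you also wanted the diagonal in \(O(N)\), a minimal sketch (under the same assumptions as `K` above, and not part of the original notebook) could look as follows:

class MaternKernelCompleteGraphWithDiag(MaternKernelCompleteGraph):
    def K_diag(self, params, X, **kwargs):
        # Constant diagonal computed in O(N), matching the normalisation used in `K` above.
        nu = params["nu"]
        if nu < np.inf:
            return super().K_diag(params, X, **kwargs)
        kappa = params["lengthscale"]
        t = kappa ** 2 / 2
        n = self.space.num_vertices
        k_same = 1 + (n - 1) * B.exp(-n / (n - 1) * t)  # unnormalised marginal variance
        diag = k_same * B.ones(B.dtype(kappa), B.shape(X)[0])
        if self.normalize:
            diag = diag / (n * k_same)  # constant 1/n per node, matching `K`
        return diag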

Checking the new Kernel
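
Before comparing against the original kernel, here is a quick numerical check (not part of the original notebook) of the closed-form expressions from the derivation comment above; it peeks at the private _laplacian attribute of our toy space and uses \(\kappa = 2\), as set earlier:

n = mygraph.num_vertices
alpha = n / (n - 1)
t = 2.0 ** 2 / 2  # t = kappa^2 / 2 with kappa = 2
H = scipy.linalg.expm(-t * mygraph._laplacian)  # exact heat kernel exp(-t L)
print(np.isclose(n * H[0, 0], 1 + (n - 1) * np.exp(-alpha * t)))  # True
print(np.isclose(n * H[0, 1], 1 - np.exp(-alpha * t)))            # True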

First, \(\nu < \infty\), effectively using the default kernel.

[17]:
alt_kernel = MaternKernelCompleteGraph(mygraph, mygraph.num_vertices)

print("K_alt:\n", alt_kernel.K(params, xs, xs))
print("K_orig:\n", kernel.K(params, xs, xs))
K_alt:
 tensor([[0.2000, 0.0803, 0.0803, 0.2000, 0.0803, 0.0803],
        [0.0803, 0.2000, 0.2000, 0.0803, 0.2000, 0.0803],
        [0.0803, 0.2000, 0.2000, 0.0803, 0.2000, 0.0803],
        [0.2000, 0.0803, 0.0803, 0.2000, 0.0803, 0.0803],
        [0.0803, 0.2000, 0.2000, 0.0803, 0.2000, 0.0803],
        [0.0803, 0.0803, 0.0803, 0.0803, 0.0803, 0.2000]])
K_orig:
 tensor([[0.2000, 0.0803, 0.0803, 0.2000, 0.0803, 0.0803],
        [0.0803, 0.2000, 0.2000, 0.0803, 0.2000, 0.0803],
        [0.0803, 0.2000, 0.2000, 0.0803, 0.2000, 0.0803],
        [0.2000, 0.0803, 0.0803, 0.2000, 0.0803, 0.0803],
        [0.0803, 0.2000, 0.2000, 0.0803, 0.2000, 0.0803],
        [0.0803, 0.0803, 0.0803, 0.0803, 0.0803, 0.2000]])

Next, \(\nu = \infty\), using the new code.

[18]:
alt_params = params.copy()
alt_params["nu"] =  torch.tensor([np.inf])

print("K_alt:\n", alt_kernel.K(alt_params, xs, xs))
print("K_orig:\n", kernel.K(alt_params, xs, xs))
K_alt:
 tensor([[0.2000, 0.1382, 0.1382, 0.2000, 0.1382, 0.1382],
        [0.1382, 0.2000, 0.2000, 0.1382, 0.2000, 0.1382],
        [0.1382, 0.2000, 0.2000, 0.1382, 0.2000, 0.1382],
        [0.2000, 0.1382, 0.1382, 0.2000, 0.1382, 0.1382],
        [0.1382, 0.2000, 0.2000, 0.1382, 0.2000, 0.1382],
        [0.1382, 0.1382, 0.1382, 0.1382, 0.1382, 0.2000]])
K_orig:
 tensor([[0.2000, 0.1382, 0.1382, 0.2000, 0.1382, 0.1382],
        [0.1382, 0.2000, 0.2000, 0.1382, 0.2000, 0.1382],
        [0.1382, 0.2000, 0.2000, 0.1382, 0.2000, 0.1382],
        [0.2000, 0.1382, 0.1382, 0.2000, 0.1382, 0.1382],
        [0.1382, 0.2000, 0.2000, 0.1382, 0.2000, 0.1382],
        [0.1382, 0.1382, 0.1382, 0.1382, 0.1382, 0.2000]])