geometric_kernels.frontends.gpjax

GPJax kernel wrapper.

A tutorial on how to use this wrapper to run Gaussian process regression on a geometric space is available in the frontends/GPJax.ipynb notebook.

Module Contents

class geometric_kernels.frontends.gpjax.GPJaxGeometricKernel(base_kernel, lengthscale=None, nu=None, variance=1.0, trainable_nu=False)

Bases: gpjax.kernels.AbstractKernel

GPJax wrapper for BaseGeometricKernel.


Note

The base_kernel itself does not store any of its hyperparameters (such as lengthscale and nu). If you do not set them manually, either when initializing the object or afterwards by setting the corresponding properties, this wrapper uses the values provided by base_kernel.init_params.
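The fallback described in this note can be sketched in plain Python. ToyBaseKernel and resolve_hyperparameters are hypothetical names, and the defaults returned by init_params here (a lengthscale of 1.0 and nu = inf) are assumptions chosen for illustration, not necessarily what a particular geometric_kernels kernel returns.

```python
from typing import Optional

import numpy as np


class ToyBaseKernel:
    """Hypothetical stand-in for a BaseGeometricKernel.

    It stores no hyperparameters; init_params() only proposes defaults.
    """

    def init_params(self) -> dict:
        # Assumed default values, for illustration only.
        return {"lengthscale": np.array([1.0]), "nu": np.array([np.inf])}


def resolve_hyperparameters(
    base_kernel: ToyBaseKernel,
    lengthscale: Optional[float] = None,
    nu: Optional[float] = None,
) -> dict:
    """Use each explicitly given value; fall back to init_params() for any left as None."""
    defaults = base_kernel.init_params()
    return {
        "lengthscale": defaults["lengthscale"] if lengthscale is None else np.asarray(lengthscale),
        "nu": defaults["nu"] if nu is None else np.asarray(nu),
    }


# lengthscale comes from the defaults, nu from the explicit argument.
params = resolve_hyperparameters(ToyBaseKernel(), nu=2.5)
```

Calling resolve_hyperparameters(ToyBaseKernel()) with no overrides would return the defaults for both hyperparameters.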

Parameters:
  • base_kernel (geometric_kernels.kernels.BaseGeometricKernel) – The kernel to wrap.

  • lengthscale (beartype.typing.Union[beartype.typing.Union[gpjax.typing.ScalarFloat, jaxtyping.Float[gpjax.typing.Array, D]], flax.nnx.Variable[beartype.typing.Union[gpjax.typing.ScalarFloat, jaxtyping.Float[gpjax.typing.Array, D]]], None]) –

    Initial value of the length scale.

    If not given or set to None, the default value provided by the base_kernel's init_params method is used.

  • nu (beartype.typing.Union[gpjax.typing.ScalarFloat, flax.nnx.Variable[gpjax.typing.ScalarFloat], None]) –

    Initial value of the smoothness parameter nu.

    If not given or set to None, the default value provided by the base_kernel's init_params method is used.

  • variance (beartype.typing.Union[gpjax.typing.ScalarFloat, flax.nnx.Variable[gpjax.typing.ScalarFloat]]) –

    Initial value of the variance (outputscale) parameter.

    Defaults to 1.0.

  • trainable_nu (bool) –

    Whether the smoothness parameter nu is optimized during training.

    You must not change this parameter after constructing the object. Defaults to False.
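One way to see why trainable_nu must not change after construction: if the flag decides, once, how nu is wrapped, then flipping it later has no effect on what an optimizer sees. The sketch below uses hypothetical Trainable and Fixed containers and a hypothetical ToyWrapper class; it is not how GPJax or this wrapper actually represent parameters, only an illustration of the constraint under that assumption.

```python
from dataclasses import dataclass


@dataclass
class Trainable:
    """Hypothetical container for a value an optimizer may update."""
    value: float


@dataclass
class Fixed:
    """Hypothetical container for a value an optimizer must leave alone."""
    value: float


class ToyWrapper:
    def __init__(self, nu: float, trainable_nu: bool = False):
        self.trainable_nu = trainable_nu
        # The wrapping is decided once, here, from the flag's value at construction.
        self.nu = Trainable(nu) if trainable_nu else Fixed(nu)

    def optimizable(self) -> dict:
        # An optimizer would only see attributes wrapped as Trainable.
        return {name: v for name, v in vars(self).items() if isinstance(v, Trainable)}


k = ToyWrapper(nu=1.5, trainable_nu=False)
k.trainable_nu = True  # flipping the flag afterwards does not re-wrap nu
```

After the last line, k.optimizable() is still empty: nu stays fixed even though the flag now reads True, which is exactly the inconsistent state the documentation warns against.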

__call__(x, y)

Compute the cross-covariance matrix between two batches of vectors (or batches of matrices) of inputs.

Parameters:
  • x (jaxtyping.Num[gpjax.typing.Array, N #D1 D2]) – A batch of N inputs, each of which is a matrix of size D1 x D2, or a vector of size D2 if D1 is absent.

  • y (jaxtyping.Num[gpjax.typing.Array, M #D1 D2]) – A batch of M inputs, each of which is a matrix of size D1 x D2, or a vector of size D2 if D1 is absent.

Returns:

The N x M cross-covariance matrix.

Return type:

jaxtyping.Float[gpjax.typing.Array, N M]
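The shape contract of __call__ can be checked with a small NumPy sketch. toy_base_K is a hypothetical Euclidean squared-exponential kernel standing in for a geometric kernel, and scaling its output by variance is an assumption based on variance being described as the outputscale; this is not the wrapper's actual code.

```python
import numpy as np


def toy_base_K(params, x, y):
    """Hypothetical stand-in for a base kernel's covariance: a squared-exponential
    kernel on Euclidean arrays, used here only to illustrate the shapes."""
    ell = params["lengthscale"]
    # Broadcast to all pairs: (N, 1, ...) - (1, M, ...) -> (N, M, ...).
    diff = x[:, None] - y[None, :]
    # Sum squared differences over every input axis (D2, or D1 and D2 for matrices).
    sq_dist = np.sum(diff**2, axis=tuple(range(2, diff.ndim)))
    return np.exp(-0.5 * sq_dist / ell**2)


def cross_covariance(x, y, lengthscale=1.0, variance=1.0):
    """N x M cross-covariance, assuming the base kernel is scaled by the variance."""
    return variance * toy_base_K({"lengthscale": lengthscale}, x, y)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))  # N = 4 inputs, each a vector of size D2 = 3
y = rng.normal(size=(5, 3))  # M = 5 inputs of the same size
K = cross_covariance(x, y, variance=2.0)
print(K.shape)  # (4, 5)
```

For this toy kernel every diagonal entry of cross_covariance(x, x, variance=v) equals v, and matrix-valued inputs of shape (N, D1, D2) also yield an N x M result because the squared differences are summed over all trailing axes.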

__init__(base_kernel, lengthscale=None, nu=None, variance=1.0, trainable_nu=False)

Initialise the AbstractKernel class. This docstring is inherited from gpjax.kernels.AbstractKernel; the constructor of GPJaxGeometricKernel accepts only the parameters listed under Parameters below.

Args:
  • active_dims – the indices of the input dimensions that are active in the kernel's evaluation, represented by a list of integers or a slice object. Defaults to a full slice.

  • n_dims – the number of input dimensions of the kernel.

  • compute_engine – the computation engine used to compute the kernel's cross-covariance and Gram matrices. Defaults to DenseKernelComputation.

Parameters:
  • base_kernel (geometric_kernels.kernels.BaseGeometricKernel)

  • lengthscale (beartype.typing.Union[beartype.typing.Union[gpjax.typing.ScalarFloat, jaxtyping.Float[gpjax.typing.Array, D]], flax.nnx.Variable[beartype.typing.Union[gpjax.typing.ScalarFloat, jaxtyping.Float[gpjax.typing.Array, D]]], None])

  • nu (beartype.typing.Union[gpjax.typing.ScalarFloat, flax.nnx.Variable[gpjax.typing.ScalarFloat], None])

  • variance (beartype.typing.Union[gpjax.typing.ScalarFloat, flax.nnx.Variable[gpjax.typing.ScalarFloat]])

  • trainable_nu (bool)

property space: beartype.typing.Union[geometric_kernels.spaces.Space, beartype.typing.List[geometric_kernels.spaces.Space]]

Alias for the base_kernel's space property.

Return type:

beartype.typing.Union[geometric_kernels.spaces.Space, beartype.typing.List[geometric_kernels.spaces.Space]]
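The space property forwards to the wrapped kernel rather than storing its own copy. A minimal sketch with hypothetical ToySpace, ToyBaseKernel and ToyWrapper classes; the list-valued case stands in for a kernel defined on several spaces at once, which is an assumption about when the List[Space] branch of the return type applies.

```python
from typing import List, Union


class ToySpace:
    """Hypothetical stand-in for geometric_kernels.spaces.Space."""

    def __init__(self, name: str):
        self.name = name


class ToyBaseKernel:
    """Hypothetical base kernel that knows which space(s) it lives on."""

    def __init__(self, space: Union[ToySpace, List[ToySpace]]):
        self.space = space


class ToyWrapper:
    def __init__(self, base_kernel: ToyBaseKernel):
        self.base_kernel = base_kernel

    @property
    def space(self) -> Union[ToySpace, List[ToySpace]]:
        # Read-only alias: always reflects the wrapped kernel's current space.
        return self.base_kernel.space


single = ToyWrapper(ToyBaseKernel(ToySpace("sphere")))
several = ToyWrapper(ToyBaseKernel([ToySpace("circle"), ToySpace("sphere")]))
```

Because the property is an alias, the wrapper never needs to be updated separately when it is handed a different base kernel.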