geometric_kernels.frontends.gpjax

GPJax kernel wrapper.

A tutorial on how to use this wrapper to run Gaussian process regression on a geometric space is available in the frontends/GPJax.ipynb notebook.

Module Contents

class geometric_kernels.frontends.gpjax.GPJaxGeometricKernel

Bases: gpjax.kernels.AbstractKernel

GPJax wrapper for BaseGeometricKernel.

A tutorial on how to use this wrapper to run Gaussian process regression on a geometric space is available in the frontends/GPJax.ipynb notebook.

Note

Remember that the base_kernel itself does not store any of its hyperparameters (like lengthscale and nu). If you do not set them manually, either when initializing the object or afterwards by setting the corresponding properties, this wrapper will use the values provided by base_kernel.init_params.
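The fallback described in this note can be modelled in a few lines of plain Python. The sketch below is a toy stand-in, not the wrapper's actual implementation: ToyBaseKernel, ToyWrapper and the default values returned by init_params are all hypothetical. It shows only the resolution rule: a hyperparameter left as None is filled in from init_params, while an explicitly provided value is kept.

```python
from dataclasses import dataclass
from typing import Dict, Optional


class ToyBaseKernel:
    """Hypothetical stand-in for a BaseGeometricKernel; it stores no hyperparameters."""

    def init_params(self) -> Dict[str, float]:
        # Hypothetical defaults; real kernels define their own values.
        return {"lengthscale": 1.0, "nu": 2.5}


@dataclass
class ToyWrapper:
    """Toy model of the parameter resolution rule, not the real GPJax wrapper."""

    base_kernel: ToyBaseKernel
    name: str = "Geometric Kernel"
    lengthscale: Optional[float] = None
    nu: Optional[float] = None
    variance: float = 1.0

    def __post_init__(self) -> None:
        defaults = self.base_kernel.init_params()
        # Only values left as None fall back to base_kernel.init_params().
        if self.lengthscale is None:
            self.lengthscale = defaults["lengthscale"]
        if self.nu is None:
            self.nu = defaults["nu"]


default = ToyWrapper(ToyBaseKernel())
custom = ToyWrapper(ToyBaseKernel(), lengthscale=0.3)
print(default.lengthscale, default.nu, default.variance)  # 1.0 2.5 1.0
print(custom.lengthscale, custom.nu)  # 0.3 2.5
```

Note that variance does not consult init_params in this sketch, matching its documented default of 1.0.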

Note

Unlike the frontends for GPflow and GPyTorch, GPJaxGeometricKernel does not have the trainable_nu parameter, which determines whether or not the smoothness parameter nu is optimized. By default, nu is not trainable. If you want to make it trainable, call kernel = kernel.replace_trainable(nu=True) on an instance of GPJaxGeometricKernel.
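What "not trainable" means in practice can be sketched as a toy gradient step in which frozen parameters are left untouched. This is an illustration under assumptions, not a description of how GPJax implements trainability internally; the function masked_gradient_step, the parameter values and the gradients are all hypothetical.

```python
from typing import Dict


def masked_gradient_step(
    params: Dict[str, float],
    grads: Dict[str, float],
    trainable: Dict[str, bool],
    learning_rate: float,
) -> Dict[str, float]:
    """One gradient-descent step that leaves non-trainable parameters unchanged."""
    return {
        name: value - learning_rate * grads[name] if trainable[name] else value
        for name, value in params.items()
    }


params = {"lengthscale": 1.0, "nu": 2.5, "variance": 1.0}
grads = {"lengthscale": 0.5, "nu": 0.2, "variance": -1.0}

# Default behaviour described in the note: nu is frozen, the rest are optimized.
trainable = {"lengthscale": True, "nu": False, "variance": True}
step_default = masked_gradient_step(params, grads, trainable, learning_rate=0.1)

# Toy analogue of switching nu's trainability on: flip only the nu flag.
trainable_nu = {**trainable, "nu": True}
step_nu = masked_gradient_step(params, grads, trainable_nu, learning_rate=0.1)

print(step_default["nu"], step_nu["nu"])
```

Only the trainability flag differs between the two calls, so any change in nu comes solely from that flag.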

Parameters:
  • base_kernel (geometric_kernels.kernels.BaseGeometricKernel) – The kernel to wrap.

  • name (str) –

    Optional kernel name (inherited from gpjax.kernels.AbstractKernel).

    Defaults to “Geometric Kernel”.

  • lengthscale (Union[ScalarFloat, Float[Array, " D"]]) –

    Initial value of the length scale.

    If not given or set to None, uses the default value of the base_kernel, as provided by its init_params method.

  • nu (ScalarFloat) –

    Initial value of the smoothness parameter nu.

    If not given or set to None, uses the default value of the base_kernel, as provided by its init_params method.

  • variance (ScalarFloat) –

    Initial value of the variance (outputscale) parameter.

    Defaults to 1.0.

__call__(x, y)

Compute the cross-covariance matrix between two batches of vectors (or batches of matrices) of inputs.

Parameters:
  • x (jaxtyping.Num[gpjax.typing.Array, "N #D1 D2"]) – A batch of N inputs, each of which is a matrix of size D1 x D2, or a vector of size D2 if D1 is absent.

  • y (jaxtyping.Num[gpjax.typing.Array, "M #D1 D2"]) – A batch of M inputs, each of which is a matrix of size D1 x D2, or a vector of size D2 if D1 is absent.

Returns:

The N x M cross-covariance matrix.

Return type:

jaxtyping.Float[gpjax.typing.Array, "N M"]
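The shape contract of __call__ can be illustrated for vector inputs with a stand-in kernel. The sketch below assumes, as the description of variance as an outputscale suggests, that each entry is the base kernel's value scaled by variance. toy_kernel is a hypothetical Euclidean squared-exponential kernel used only to fill the N x M matrix; it is not a geometric kernel from this library.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def toy_kernel(x: Vector, y: Vector, lengthscale: float) -> float:
    """Euclidean squared-exponential kernel, a hypothetical stand-in for the base kernel."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * lengthscale**2))


def cross_covariance(
    xs: Sequence[Vector],
    ys: Sequence[Vector],
    lengthscale: float,
    variance: float = 1.0,
) -> List[List[float]]:
    """Return the N x M matrix with entries variance * k(xs[i], ys[j])."""
    return [[variance * toy_kernel(x, y, lengthscale) for y in ys] for x in xs]


xs = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]  # N = 3 inputs, each a vector of size D2 = 2
ys = [[0.0, 0.0], [1.0, 1.0]]  # M = 2 inputs, each a vector of size D2 = 2
K = cross_covariance(xs, ys, lengthscale=1.0, variance=2.0)
print(len(K), len(K[0]))  # 3 2
```

Rows index the first batch and columns the second, so swapping x and y transposes the result.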

property space: beartype.typing.Union[geometric_kernels.spaces.Space, beartype.typing.List[geometric_kernels.spaces.Space]]

Alias to the base_kernel's space property.

Return type:

beartype.typing.Union[geometric_kernels.spaces.Space, beartype.typing.List[geometric_kernels.spaces.Space]]