Source code for geometric_kernels.utils.product
"""Utilities for dealing with product spaces and product kernels."""
import lab as B
from beartype.typing import Dict, List
from geometric_kernels.lab_extras import smart_cast
from geometric_kernels.utils.utils import _check_rank_1_array
[docs]
def params_to_params_list(
number_of_factors: int, params: Dict[str, B.Numeric]
) -> List[Dict[str, B.Numeric]]:
"""
Takes a dictionary of parameters of a product kernel and returns a list of
dictionaries of parameters for the factor kernels. The shape of "lengthscale"
should be the same as the shame of "nu", and the length of both should be
either 1 or equal to `number_of_factors`.
:param number_of_factors:
Number of factors in the product kernel.
:param params:
Parameters of the product kernel.
"""
if B.shape(params["lengthscale"]) != B.shape(params["nu"]):
raise ValueError(
'Shape mismatch between `params["lengthscale"]` and `params["nu"].`'
)
_check_rank_1_array(params["nu"], 'params["nu"]')
if params["nu"].shape[0] == 1:
return [params] * number_of_factors
if B.shape(params["nu"])[0] != number_of_factors:
raise ValueError(
"Shapes of the kernel parameters `lengthscale`, `nu` must be [`number_of_factors`]."
)
list_of_params: List[Dict[str, B.Numeric]] = []
for i in range(number_of_factors):
list_of_params.append(
{
"lengthscale": params["lengthscale"][i : i + 1],
"nu": params["nu"][i : i + 1],
}
)
return list_of_params
[docs]
def make_product(xs: List[B.Numeric]) -> B.Numeric:
"""
Embed a list of elements of factor spaces into the product space.
Assumes that elements are batched along the first dimension.
:param xs:
List of the batches of elements, each of the shape [N, <axes(space)>],
where `<axes(space)>` is the shape of the elements of the respective
space.
:return:
An [N, D]-shaped array, a batch of product space elements, where `D` is
the sum, over all factor spaces, of `prod(<axes(space)>)`.
"""
common_dtype = B.promote_dtypes(*[B.dtype(x) for x in xs])
flat_xs = [B.cast(common_dtype, B.reshape(x, B.shape(x)[0], -1)) for x in xs]
return B.concat(*flat_xs, axis=-1)
[docs]
def project_product(
x: B.Numeric,
dimension_indices: List[List[int]],
element_shapes: List[List[int]],
element_dtypes: List[B.DType],
) -> List[B.Numeric]:
"""
Project an element of the product space onto each factor.
Assumes that elements are batched along the first dimension.
:param x:
An [N, D]-shaped array, a batch of N product space elements.
:param dimension_indices:
Determines how a product space element `x` is to be mapped to inputs
`xi` of the factor kernels. `xi` are assumed to be equal to
`x[dimension_indices[i]]`, possibly up to a reshape. Such a reshape
might be necessary to accommodate the spaces whose elements are matrices
rather than vectors, as determined by `element_shapes`.
:param element_shapes:
Shapes of the elements in each factor. Can be obtained as properties
`space.element_shape` of any given factor `space`.
:param element_dtypes:
Abstract lab data types of the elements in each factor. Can be obtained
as properties `space.element_dtype` of any given factor `space`.
:return:
A list of the batches of elements `xi` in factor spaces, each of the
shape `[N, *element_shapes[i]]`.
"""
N = x.shape[0]
xs = [
smart_cast(dtype, B.reshape(B.take(x, inds, axis=-1), N, *shape))
for inds, shape, dtype in zip(dimension_indices, element_shapes, element_dtypes)
]
return xs