Coverage for geometric_kernels/utils/product.py: 71%
24 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1"""Utilities for dealing with product spaces and product kernels."""
3import lab as B
4from beartype.typing import Dict, List
6from geometric_kernels.lab_extras import smart_cast
7from geometric_kernels.utils.utils import _check_rank_1_array
def params_to_params_list(
    number_of_factors: int, params: Dict[str, B.Numeric]
) -> List[Dict[str, B.Numeric]]:
    """
    Takes a dictionary of parameters of a product kernel and returns a list of
    dictionaries of parameters for the factor kernels. The shape of "lengthscale"
    should be the same as the shape of "nu", and the length of both should be
    either 1 or equal to `number_of_factors`.

    :param number_of_factors:
        Number of factors in the product kernel.
    :param params:
        Parameters of the product kernel. Must contain the keys "lengthscale"
        and "nu", which must have identical shapes; "nu" is validated to be a
        rank-1 array by `_check_rank_1_array`.

    :return:
        A list of `number_of_factors` dictionaries, one per factor kernel.
        If the parameters have length 1, every entry is a shallow copy of
        `params` (all factors share the same parameter values). Otherwise the
        i-th entry contains only the keys "lengthscale" and "nu", sliced to
        `[i : i + 1]` so that they remain rank-1 arrays of length 1.

    :raises ValueError:
        If the shapes of `params["lengthscale"]` and `params["nu"]` differ, or
        if their length is neither 1 nor `number_of_factors`.
    """
    if B.shape(params["lengthscale"]) != B.shape(params["nu"]):
        raise ValueError(
            'Shape mismatch between `params["lengthscale"]` and `params["nu"]`.'
        )

    _check_rank_1_array(params["nu"], 'params["nu"]')

    num_values = B.shape(params["nu"])[0]

    if num_values == 1:
        # A single shared value for all factors. Return independent (shallow)
        # dict copies rather than N references to the caller's dict, so that
        # mutating one factor's parameters cannot silently affect the others
        # or the caller's `params`.
        return [dict(params) for _ in range(number_of_factors)]

    if num_values != number_of_factors:
        raise ValueError(
            "Shapes of the kernel parameters `lengthscale`, `nu` must be [`number_of_factors`]."
        )

    # Slice with `i : i + 1` instead of indexing with `i` so that each factor's
    # parameters keep their rank-1 shape.
    return [
        {
            "lengthscale": params["lengthscale"][i : i + 1],
            "nu": params["nu"][i : i + 1],
        }
        for i in range(number_of_factors)
    ]
def make_product(xs: List[B.Numeric]) -> B.Numeric:
    """
    Embed a list of elements of factor spaces into the product space.
    Assumes that elements are batched along the first dimension.

    Every factor batch is cast to a single promoted dtype (computed from the
    dtypes of all inputs), flattened to two dimensions while keeping the
    batch axis, and the flattened batches are concatenated along the last axis.

    :param xs:
        List of the batches of elements, each of the shape [N, <axes(space)>],
        where `<axes(space)>` is the shape of the elements of the respective
        space.

    :return:
        An [N, D]-shaped array, a batch of product space elements, where `D` is
        the sum, over all factor spaces, of `prod(<axes(space)>)`.
    """
    factor_dtypes = [B.dtype(factor_batch) for factor_batch in xs]
    target_dtype = B.promote_dtypes(*factor_dtypes)

    flattened = []
    for factor_batch in xs:
        batch_size = B.shape(factor_batch)[0]
        # Collapse all element axes into one so every factor becomes [N, d_i].
        as_matrix = B.reshape(factor_batch, batch_size, -1)
        flattened.append(B.cast(target_dtype, as_matrix))

    return B.concat(*flattened, axis=-1)
def project_product(
    x: B.Numeric,
    dimension_indices: List[List[int]],
    element_shapes: List[List[int]],
    element_dtypes: List[B.DType],
) -> List[B.Numeric]:
    """
    Project an element of the product space onto each factor.
    Assumes that elements are batched along the first dimension.

    :param x:
        An [N, D]-shaped array, a batch of N product space elements.
    :param dimension_indices:
        Determines how a product space element `x` is to be mapped to inputs
        `xi` of the factor kernels. `xi` are assumed to be equal to
        `x[:, dimension_indices[i]]` (the selected columns of `x`), possibly
        up to a reshape. Such a reshape might be necessary to accommodate the
        spaces whose elements are matrices rather than vectors, as determined
        by `element_shapes`.
    :param element_shapes:
        Shapes of the elements in each factor. Can be obtained as properties
        `space.element_shape` of any given factor `space`.
    :param element_dtypes:
        Abstract lab data types of the elements in each factor. Can be obtained
        as properties `space.element_dtype` of any given factor `space`.

    :return:
        A list of the batches of elements `xi` in factor spaces, each of the
        shape `[N, *element_shapes[i]]` and converted with `smart_cast` to
        `element_dtypes[i]`.

    :raises ValueError:
        If `dimension_indices`, `element_shapes` and `element_dtypes` do not
        all have the same length.
    """
    # `zip` would silently drop trailing factors if these disagree in length,
    # producing fewer projections than expected; fail loudly instead.
    num_factors = len(dimension_indices)
    if len(element_shapes) != num_factors or len(element_dtypes) != num_factors:
        raise ValueError(
            "`dimension_indices`, `element_shapes` and `element_dtypes` must "
            f"have the same length, got {num_factors}, {len(element_shapes)} "
            f"and {len(element_dtypes)}."
        )

    # Use the backend-agnostic `B.shape`, consistent with `make_product`.
    N = B.shape(x)[0]
    return [
        # Take the factor's columns, restore its element shape, then cast.
        smart_cast(dtype, B.reshape(B.take(x, inds, axis=-1), N, *shape))
        for inds, shape, dtype in zip(dimension_indices, element_shapes, element_dtypes)
    ]