Coverage for geometric_kernels/utils/product.py: 71%

24 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1"""Utilities for dealing with product spaces and product kernels.""" 

2 

3import lab as B 

4from beartype.typing import Dict, List 

5 

6from geometric_kernels.lab_extras import smart_cast 

7from geometric_kernels.utils.utils import _check_rank_1_array 

8 

9 

def params_to_params_list(
    number_of_factors: int, params: Dict[str, B.Numeric]
) -> List[Dict[str, B.Numeric]]:
    """
    Takes a dictionary of parameters of a product kernel and returns a list of
    dictionaries of parameters for the factor kernels. The shape of
    "lengthscale" should be the same as the shape of "nu", and the length of
    both should be either 1 or equal to `number_of_factors`.

    :param number_of_factors:
        Number of factors in the product kernel.
    :param params:
        Parameters of the product kernel. Must contain the keys "lengthscale"
        and "nu", whose values have identical shapes; "nu" must pass
        `_check_rank_1_array`.

    :return:
        A list with one parameter dictionary per factor kernel.

        If "lengthscale" and "nu" have length 1, each factor receives its
        own shallow copy of `params`. Mutating one returned dictionary
        therefore affects neither the others nor the caller's `params`.

        Otherwise, the i-th dictionary contains only the keys "lengthscale"
        and "nu", holding the length-1 slices `params["lengthscale"][i:i+1]`
        and `params["nu"][i:i+1]`.

    :raises ValueError:
        If the shapes of `params["lengthscale"]` and `params["nu"]` differ,
        or if their length is neither 1 nor `number_of_factors`. Whatever
        `_check_rank_1_array` raises for a non-rank-1 `params["nu"]` is
        propagated unchanged.
    """
    lengthscale, nu = params["lengthscale"], params["nu"]

    if B.shape(lengthscale) != B.shape(nu):
        raise ValueError(
            'Shape mismatch between `params["lengthscale"]` and `params["nu"]`.'
        )

    _check_rank_1_array(nu, 'params["nu"]')

    num_values = B.shape(nu)[0]

    if num_values == 1:
        # Shared parameters: give every factor an independent shallow copy.
        # `[params] * number_of_factors` would hand out the very same dict
        # object (the caller's) N times, so mutating one factor's params
        # would silently change all of them.
        return [dict(params) for _ in range(number_of_factors)]

    if num_values != number_of_factors:
        raise ValueError(
            "Shapes of the kernel parameters `lengthscale`, `nu` must be [`number_of_factors`]."
        )

    # Slice with `i : i + 1` rather than index with `i`, so that each factor's
    # parameters stay rank-1 arrays of length 1.
    return [
        {"lengthscale": lengthscale[i : i + 1], "nu": nu[i : i + 1]}
        for i in range(number_of_factors)
    ]

49 

50 

def make_product(xs: List[B.Numeric]) -> B.Numeric:
    """
    Embed a list of elements of factor spaces into the product space.

    Every batch in `xs` is flattened to two dimensions, keeping its leading
    (batch) axis. All batches are then cast to the dtype obtained by
    promoting the dtypes of all inputs. Finally they are concatenated along
    the last axis.

    :param xs:
        List of the batches of elements, each of the shape [N, <axes(space)>],
        where `<axes(space)>` is the shape of the elements of the respective
        space.

    :return:
        An [N, D]-shaped array, a batch of product space elements, where `D` is
        the sum, over all factor spaces, of `prod(<axes(space)>)`.
    """
    target_dtype = B.promote_dtypes(*(B.dtype(batch) for batch in xs))

    flattened = []
    for batch in xs:
        batch_size = B.shape(batch)[0]
        # Collapse all element axes into one so that factors with matrix-valued
        # (or higher-rank) elements can be concatenated with vector-valued ones.
        as_matrix = B.reshape(batch, batch_size, -1)
        flattened.append(B.cast(target_dtype, as_matrix))

    return B.concat(*flattened, axis=-1)

69 

70 

def project_product(
    x: B.Numeric,
    dimension_indices: List[List[int]],
    element_shapes: List[List[int]],
    element_dtypes: List[B.DType],
) -> List[B.Numeric]:
    """
    Project an element of the product space onto each factor.
    Assumes that elements are batched along the first dimension.

    :param x:
        An [N, D]-shaped array, a batch of N product space elements.
    :param dimension_indices:
        Determines how a product space element `x` is to be mapped to inputs
        `xi` of the factor kernels. `xi` is obtained by taking the columns
        `dimension_indices[i]` of `x` along its last axis, i.e.
        `x[:, dimension_indices[i]]`, and then reshaping. The reshape
        accommodates spaces whose elements are matrices rather than vectors,
        as determined by `element_shapes`.
    :param element_shapes:
        Shapes of the elements in each factor. Can be obtained as properties
        `space.element_shape` of any given factor `space`.
    :param element_dtypes:
        Abstract lab data types of the elements in each factor. Can be obtained
        as properties `space.element_dtype` of any given factor `space`.

    :return:
        A list of the batches of elements `xi` in factor spaces, each of the
        shape `[N, *element_shapes[i]]`, cast to `element_dtypes[i]`.

    :raises ValueError:
        If `dimension_indices`, `element_shapes` and `element_dtypes` do not
        all have the same length (one entry per factor).
    """
    num_factors = len(dimension_indices)
    # `zip` would silently truncate to the shortest list and drop factors,
    # so reject mismatched per-factor descriptions explicitly.
    if len(element_shapes) != num_factors or len(element_dtypes) != num_factors:
        raise ValueError(
            "`dimension_indices`, `element_shapes` and `element_dtypes` must "
            "have the same length (one entry per factor), got "
            f"{num_factors}, {len(element_shapes)} and {len(element_dtypes)}."
        )

    N = B.shape(x)[0]
    return [
        smart_cast(dtype, B.reshape(B.take(x, inds, axis=-1), N, *shape))
        for inds, shape, dtype in zip(dimension_indices, element_shapes, element_dtypes)
    ]