Coverage for geometric_kernels/kernels/product.py: 55%

60 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

"""
This module provides the :class:`ProductGeometricKernel` kernel for
constructing product kernels from a sequence of kernels.

See :doc:`this page </theory/product_kernels>` for a brief account of the
theory behind product kernels and the :doc:`Torus.ipynb </examples/Torus>`
notebook for a tutorial on how to use them.
"""

9 

10import math 

11 

12import lab as B 

13from beartype.typing import Dict, List, Optional 

14 

15from geometric_kernels.kernels.base import BaseGeometricKernel 

16from geometric_kernels.spaces import Space 

17from geometric_kernels.utils.product import params_to_params_list, project_product 

18 

19 

class ProductGeometricKernel(BaseGeometricKernel):
    r"""
    Product kernel, defined as the product of a sequence of kernels.

    See :doc:`this page </theory/product_kernels>` for a brief account of the
    theory behind product kernels and the :doc:`Torus.ipynb </examples/Torus>`
    notebook for a tutorial on how to use them.

    :param ``*kernels``:
        A sequence of kernels to compute the product of. Cannot contain another
        instance of :class:`ProductGeometricKernel`. We denote the number of
        factors, i.e. the length of the "sequence", by s.
    :param dimension_indices:
        Determines how a product kernel input vector `x` is to be mapped into
        the inputs `xi` for the factor kernels. `xi` are assumed to be equal to
        `x[dimension_indices[i]]`, possibly up to a reshape. Such a reshape
        might be necessary to accommodate the spaces whose elements are matrices
        rather than vectors, as determined by `element_shapes`. The
        transformation of `x` into the list of `xi`\ s is performed
        by :func:`~.project_product`.

        If None, assumes that each input is laid out flattened and
        concatenated, in the same order as the factor spaces. In this case,
        the inverse to :func:`~.project_product` is :func:`~.make_product`.

        Defaults to None.

    :raises NotImplementedError:
        If the `space` of one of the `kernels` is not a single
        :class:`~.Space` instance (e.g. the kernel is itself a product kernel,
        whose `space` is a list of spaces).
    :raises ValueError:
        If `dimension_indices` is given but its length differs from the number
        of kernels, or if any of its entries is negative.

    .. note::
        `params` of a :class:`ProductGeometricKernel` are such that
        `params["lengthscale"]` and `params["nu"]` are (s,)-shaped arrays, where
        `s` is the number of factors.

        Basically, `params["lengthscale"][i]` stores the length scale parameter
        for the `i`-th factor kernel. Same goes for `params["nu"]`. Importantly,
        this enables *automatic relevance determination*-like behavior.
    """

    def __init__(
        self,
        *kernels: BaseGeometricKernel,
        dimension_indices: Optional[List[List[int]]] = None,
    ):
        self.kernels = kernels
        self.spaces: List[Space] = []
        for kernel in self.kernels:
            # Nested products are not supported: a product kernel's `space`
            # is a list of spaces rather than a single `Space` instance.
            if not isinstance(kernel.space, Space):
                raise NotImplementedError(
                    "One of the provided kernels is a product kernel itself."
                )
            self.spaces.append(kernel.space)
        self.element_shapes = [space.element_shape for space in self.spaces]
        self.element_dtypes = [space.element_dtype for space in self.spaces]

        if dimension_indices is None:
            # Default layout: each factor's element is flattened and the
            # results are concatenated in factor order, so factor i owns the
            # next contiguous run of prod(element_shapes[i]) coordinates.
            self.dimension_indices: List[List[int]] = []
            start = 0
            for shape in self.element_shapes:
                dim = math.prod(shape)
                self.dimension_indices.append(list(range(start, start + dim)))
                start += dim
        else:
            if len(dimension_indices) != len(self.kernels):
                raise ValueError(
                    f"`dimension_indices` must correspond to `kernels`, but got {len(kernels)} kernels and {len(dimension_indices)} dimension indices."
                )
            if any(idx < 0 for idx_list in dimension_indices for idx in idx_list):
                raise ValueError(
                    "Expected all `dimension_indices` to be non-negative."
                )

            self.dimension_indices = dimension_indices

    @property
    def space(self) -> List[Space]:
        """
        The list of spaces upon which the factor kernels are defined.
        """
        return self.spaces

    def init_params(self) -> Dict[str, B.NPNumeric]:
        r"""
        Returns a dict `params` where `params["lengthscale"]` is the
        concatenation of all `self.kernels[i].init_params()["lengthscale"]` and
        same for `params["nu"]`.

        :raises ValueError:
            If some factor kernel's initial `lengthscale` or `nu` does not
            have shape (1,).
        """
        nu_list: List[B.NPNumeric] = []
        lengthscale_list: List[B.NPNumeric] = []

        for kernel_idx, kernel in enumerate(self.kernels):
            cur_params = kernel.init_params()
            # Each factor must contribute exactly one entry so that the
            # concatenated arrays are (s,)-shaped, one entry per factor.
            if B.shape(cur_params["lengthscale"]) != (1,):
                raise ValueError(
                    f"All kernels' `lengthscale`s must have shape [1,], but {kernel_idx}th kernel "
                    f"({kernel}) violates this with shape {B.shape(cur_params['lengthscale'])}."
                )
            if B.shape(cur_params["nu"]) != (1,):
                raise ValueError(
                    f"All kernels' `nu`s must have shape [1,], but {kernel_idx}th kernel "
                    f"({kernel}) violates this with shape {B.shape(cur_params['nu'])}."
                )

            nu_list.append(cur_params["nu"])
            lengthscale_list.append(cur_params["lengthscale"])

        params: Dict[str, B.NPNumeric] = {}
        params["nu"] = B.concat(*nu_list)
        params["lengthscale"] = B.concat(*lengthscale_list)
        return params

    def K(self, params: Dict[str, B.Numeric], X, X2=None, **kwargs) -> B.Numeric:
        """
        Compute the product kernel matrix between `X` and `X2`.

        Both inputs are split into per-factor inputs by
        :func:`~.project_product`, `params` is split into per-factor
        parameter dicts by :func:`~.params_to_params_list`, and the factor
        kernel matrices are multiplied elementwise.

        :param params: Parameters with (s,)-shaped `"lengthscale"` and `"nu"`.
        :param X: Inputs laid out according to `self.dimension_indices`.
        :param X2: Second set of inputs, laid out like `X`. If None, `X` is
            used.
        :param ``**kwargs``: Accepted but ignored; they are not forwarded to
            the factor kernels.
        :return: The elementwise product of the factor kernel matrices.
        """
        if X2 is None:
            X2 = X

        Xs = project_product(
            X, self.dimension_indices, self.element_shapes, self.element_dtypes
        )
        X2s = project_product(
            X2, self.dimension_indices, self.element_shapes, self.element_dtypes
        )
        params_list = params_to_params_list(len(self.kernels), params)

        factor_matrices = [
            kernel.K(factor_params, X_factor, X2_factor)
            for kernel, X_factor, X2_factor, factor_params in zip(
                self.kernels, Xs, X2s, params_list
            )
        ]
        # Stack the factor matrices along a new trailing axis and reduce that
        # axis by multiplication to get the product kernel matrix.
        return B.prod(B.stack(*factor_matrices, axis=-1), axis=-1)

    def K_diag(self, params: Dict[str, B.Numeric], X) -> B.Numeric:
        """
        Compute the diagonal of the product kernel matrix `K(params, X, X)`.

        :param params: Parameters with (s,)-shaped `"lengthscale"` and `"nu"`.
        :param X: Inputs laid out according to `self.dimension_indices`.
        :return: The elementwise product of the factor kernels' diagonals.
        """
        Xs = project_product(
            X, self.dimension_indices, self.element_shapes, self.element_dtypes
        )
        params_list = params_to_params_list(len(self.kernels), params)

        factor_diagonals = [
            kernel.K_diag(factor_params, X_factor)
            for kernel, X_factor, factor_params in zip(self.kernels, Xs, params_list)
        ]
        # Same stack-and-multiply reduction as in `K`, applied to diagonals.
        return B.prod(B.stack(*factor_diagonals, axis=-1), axis=-1)