Coverage for geometric_kernels/kernels/base.py: 79%

19 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2This module provides the :class:`BaseGeometricKernel` kernel, the base class 

3for all geometric kernels. 

4""" 

5 

6import abc 

7 

8import lab as B 

9from beartype.typing import Dict, List, Optional, Union 

10 

11from geometric_kernels.spaces import Space 

12 

13 

class BaseGeometricKernel(abc.ABC):
    """
    Abstract base class for geometric kernels.

    Concrete subclasses must implement :meth:`init_params`, :meth:`K` and
    :meth:`K_diag`; this base class only stores the space the kernel lives on.

    :param space:
        The space on which the kernel is defined.
    """

    def __init__(self, space: Space):
        # Stored as-is and exposed read-only through the `space` property.
        self._space = space

    @property
    def space(self) -> Union[Space, List[Space]]:
        """
        The space on which the kernel is defined.

        This base class returns exactly the object passed to `__init__`.
        The return annotation also admits a list of spaces; presumably some
        subclasses (e.g. kernels built over several factor spaces) store a
        list here — confirm against the concrete subclass in use.
        """
        return self._space

    @abc.abstractmethod
    def init_params(self) -> Dict[str, B.NPNumeric]:
        """
        Initializes the dict of the trainable parameters of the kernel.

        It typically contains only two keys: `"nu"` and `"lengthscale"`.

        This dict can be modified and is passed around into such methods as
        :meth:`~.K` or :meth:`~.K_diag`, as the `params` argument.

        .. note::
            The values in the returned dict are always of the NumPy array type.
            Thus, if you want to use some other backend for internal
            computations when calling :meth:`~.K` or :meth:`~.K_diag`, you
            need to replace the values with the analogs typed as arrays of
            the desired backend.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def K(
        self,
        params: Dict[str, B.Numeric],
        X: B.Numeric,
        X2: Optional[B.Numeric] = None,
        **kwargs,
    ) -> B.Numeric:
        """
        Compute the cross-covariance matrix between two batches of vectors of
        inputs, or batches of matrices of inputs, depending on the space.

        :param params:
            A dict of kernel parameters, typically containing two keys:
            `"lengthscale"` for length scale and `"nu"` for smoothness.

            The types of values in the params dict determine the output type
            and the backend used for the internal computations, see the
            warning below for more details.

            .. note::
                The values `params["lengthscale"]` and `params["nu"]` are
                typically (1,)-shaped arrays of the suitable backend. This
                serves to point at the backend to be used for internal
                computations.

                In some cases, for example, when the kernel is
                :class:`~.kernels.ProductGeometricKernel`, the values of
                `params` may be (s,)-shaped arrays instead, where `s` is the
                number of factors.

            .. note::
                Finite values of `params["nu"]` typically correspond to the
                generalized (geometric) Matérn kernels.

                Infinite `params["nu"]` typically corresponds to the heat
                kernel (a.k.a. diffusion kernel, generalized squared
                exponential kernel, generalized Gaussian kernel,
                generalized RBF kernel). Although it is often considered to be
                a separate entity, we treat the heat kernel as a member of
                the Matérn family, with smoothness parameter equal to infinity.

        :param X:
            A batch of N inputs, each of which is a vector or a matrix,
            depending on how the elements of the `self.space` are represented.
        :param X2:
            A batch of M inputs, each of which is a vector or a matrix,
            depending on how the elements of the `self.space` are represented.

            `X2=None` sets `X2=X`.

            Defaults to None.

        :return:
            The N x M cross-covariance matrix.

        .. warning::
            The types of values in the `params` dict determine the backend
            used for internal computations and the output type.

            Even if, say, `geometric_kernels.jax` is imported but the values in
            the `params` dict are NumPy arrays, the output type will be a NumPy
            array, and NumPy will be used for internal computations. To get a
            JAX array as an output and use JAX for internal computations, all
            the values in the `params` dict must be JAX arrays.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def K_diag(self, params: Dict[str, B.Numeric], X: B.Numeric, **kwargs) -> B.Numeric:
        """
        Returns the diagonal of the covariance matrix `self.K(params, X, X)`,
        typically in a more efficient way than actually computing the full
        covariance matrix with `self.K(params, X, X)` and then extracting its
        diagonal.

        :param params:
            Same as for :meth:`~.K`.

        :param X:
            A batch of N inputs, each of which is a vector or a matrix,
            depending on how the elements of the `self.space` are represented.

        :return:
            The N-dimensional vector representing the diagonal of the
            covariance matrix `self.K(params, X, X)`.
        """
        raise NotImplementedError