Coverage for geometric_kernels/feature_maps/rejection_sampling.py: 100%

40 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2This module provides the :class:`RejectionSamplingFeatureMapHyperbolic` and the 

3:class:`RejectionSamplingFeatureMapSPD`, rejection sampling-based feature maps 

4for :class:`~.spaces.Hyperbolic` and 

5:class:`~.spaces.SymmetricPositiveDefiniteMatrices`, respectively. 

6""" 

7 

8import lab as B 

9from beartype.typing import Dict, Tuple 

10 

11from geometric_kernels.feature_maps.base import FeatureMap 

12from geometric_kernels.feature_maps.probability_densities import ( 

13 hyperbolic_density_sample, 

14 spd_density_sample, 

15) 

16from geometric_kernels.lab_extras import from_numpy 

17from geometric_kernels.spaces import Hyperbolic, SymmetricPositiveDefiniteMatrices 

18 

19 

class RejectionSamplingFeatureMapHyperbolic(FeatureMap):
    """
    Random phase feature map for the :class:`~.spaces.Hyperbolic` space based
    on the rejection sampling algorithm.

    :param space:
        A :class:`~.spaces.Hyperbolic` space.
    :param num_random_phases:
        Number of random phases to use. The resulting feature map has
        dimension ``2 * num_random_phases`` (real and imaginary parts of the
        power function are stacked side by side).
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.
    """

    def __init__(
        self,
        space: Hyperbolic,
        num_random_phases: int = 3000,
        shifted_laplacian: bool = True,
    ):
        self.space = space
        self.num_random_phases = num_random_phases
        self.shifted_laplacian = shifted_laplacian

    def __call__(
        self,
        X: B.Numeric,
        params: Dict[str, B.Numeric],
        *,
        key: B.RandomState,
        normalize: bool = True,
        **kwargs,
    ) -> Tuple[B.RandomState, B.Numeric]:
        """
        :param X:
            [N, D] points in the space to evaluate the map on.
        :param params:
            Parameters of the feature map (length scale and smoothness).
            The dtype of ``params["lengthscale"]`` is also the dtype the
            sampled phases and spectral samples are cast to.
        :param key:
            Random state, either `np.random.RandomState`,
            `tf.random.Generator`, `torch.Generator` or `jax.tensor` (which
            represents a random state).

            .. note::
                For any backend other than `jax`, passing the same `key` twice
                does not guarantee that the feature map will be the same each
                time. This is because these backends' random state has... a
                state. To evaluate the same (including randomness) feature map
                on different inputs, you can either save/restore state manually
                each time or use the helper function
                :func:`~.utils.make_deterministic` which does this for you.

        :param normalize:
            If `True` (the default), rescale every row of the features to
            unit Euclidean norm, so that the approximate kernel satisfies
            k(x, x) = 1 at every input.
        :param ``**kwargs``:
            Unused.

        :return:
            `Tuple(key, features)` where `features` is an
            [N, 2 * num_random_phases] array (N is the number of inputs) and
            `key` is the updated random key for `jax`, or the similar random
            state (generator) for any other backends.
        """
        # Below, O denotes self.num_random_phases.
        key, random_phases = self.space.random_phases(
            key, self.num_random_phases
        )  # [O, D]

        # Spectral samples drawn by rejection sampling; the trailing axis size
        # is the rank of `space.rho`.
        key, random_lambda = hyperbolic_density_sample(
            key,
            (self.num_random_phases, B.rank(self.space.rho)),
            params,
            self.space.dimension,
            self.shifted_laplacian,
        )  # [O, 1]

        # Bring both samples onto X's backend with the parameters' dtype, and
        # add a leading axis so they broadcast against every input point.
        dtype = B.dtype(params["lengthscale"])
        random_phases_b = B.expand_dims(
            B.cast(dtype, from_numpy(X, random_phases)), axis=0
        )  # [1, O, D]
        random_lambda_b = B.expand_dims(
            B.cast(dtype, from_numpy(X, random_lambda)), axis=0
        )  # [1, O, 1]
        X_b = B.expand_dims(X, axis=-2)  # [N, 1, D]

        p = self.space.power_function(random_lambda_b, X_b, random_phases_b)  # [N, O]

        # The power function is complex-valued; stacking real and imaginary
        # parts gives a real feature vector of twice the length.
        features = B.concat(B.real(p), B.imag(p), axis=-1)  # [N, 2*O]
        if normalize:
            normalizer = B.sqrt(B.sum(features**2, axis=-1, squeeze=False))  # [N, 1]
            features = features / normalizer

        return key, features

115 

116 

class RejectionSamplingFeatureMapSPD(FeatureMap):
    """
    Random phase feature map for the
    :class:`~.spaces.SymmetricPositiveDefiniteMatrices` space based on the
    rejection sampling algorithm.

    :param space:
        A :class:`~.spaces.SymmetricPositiveDefiniteMatrices` space.
    :param num_random_phases:
        Number of random phases to use. The resulting feature map has
        dimension ``2 * num_random_phases`` (real and imaginary parts of the
        power function are stacked side by side).
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.
    """

    def __init__(
        self,
        space: SymmetricPositiveDefiniteMatrices,
        num_random_phases: int = 3000,
        shifted_laplacian: bool = True,
    ):
        self.space = space
        self.num_random_phases = num_random_phases
        self.shifted_laplacian = shifted_laplacian

    def __call__(
        self,
        X: B.Numeric,
        params: Dict[str, B.Numeric],
        *,
        key: B.RandomState,
        normalize: bool = True,
        **kwargs,
    ) -> Tuple[B.RandomState, B.Numeric]:
        """
        :param X:
            [N, D, D] points in the space to evaluate the map on.
        :param params:
            Parameters of the feature map (length scale and smoothness).
            The dtype of ``params["lengthscale"]`` is also the dtype the
            sampled phases and spectral samples are cast to.
        :param key:
            Random state, either `np.random.RandomState`,
            `tf.random.Generator`, `torch.Generator` or `jax.tensor` (which
            represents a random state).

            .. note::
                For any backend other than `jax`, passing the same `key` twice
                does not guarantee that the feature map will be the same each
                time. This is because these backends' random state has... a
                state. To evaluate the same (including randomness) feature map
                on different inputs, you can either save/restore state manually
                each time or use the helper function
                :func:`~.utils.make_deterministic` which does this for you.

        :param normalize:
            If `True` (the default), rescale every row of the features to
            unit Euclidean norm, so that the approximate kernel satisfies
            k(x, x) = 1 at every input.
        :param ``**kwargs``:
            Unused.

        :return:
            `Tuple(key, features)` where `features` is an
            [N, 2 * num_random_phases] array (N is the number of inputs) and
            `key` is the updated random key for `jax`, or the similar random
            state (generator) for any other backends.
        """
        # Below, O denotes self.num_random_phases.
        key, random_phases = self.space.random_phases(
            key, self.num_random_phases
        )  # [O, D, D]

        # Spectral samples drawn by rejection sampling, using the space's
        # degree and rho.
        key, random_lambda = spd_density_sample(
            key,
            (self.num_random_phases,),
            params,
            self.space.degree,
            self.space.rho,
            self.shifted_laplacian,
        )  # [O, D]

        # Bring both samples onto X's backend with the parameters' dtype, and
        # add a leading axis so they broadcast against every input matrix.
        dtype = B.dtype(params["lengthscale"])
        random_phases_b = B.expand_dims(
            B.cast(dtype, from_numpy(X, random_phases)), axis=0
        )  # [1, O, D, D]
        random_lambda_b = B.expand_dims(
            B.cast(dtype, from_numpy(X, random_lambda)), axis=0
        )  # [1, O, D]
        X_b = B.expand_dims(X, axis=-3)  # [N, 1, D, D]

        p = self.space.power_function(random_lambda_b, X_b, random_phases_b)  # [N, O]

        # The power function is complex-valued; stacking real and imaginary
        # parts gives a real feature vector of twice the length.
        features = B.concat(B.real(p), B.imag(p), axis=-1)  # [N, 2*O]
        if normalize:
            normalizer = B.sqrt(B.sum(features**2, axis=-1, squeeze=False))  # [N, 1]
            features = features / normalizer

        return key, features