Coverage for geometric_kernels/feature_maps/random_phase.py: 70%

50 statements  


1r""" 

2This module provides the random phase-based feature maps. 

3 

4Specifically, it provides a random phase-based feature map for 

5:class:`~.spaces.DiscreteSpectrumSpace`\ s for which the 

6:doc:`addition theorem </theory/addition_theorem>`-like basis functions 

7are explicitly available while the actual eigenpairs may remain implicit. 

8 

9It also provides a basic random phase-based feature map for 

10:class:`~.spaces.NoncompactSymmetricSpace`\ s. It should be used unless a more 

11specialized per-space implementation is available, like the ones in the module 

12:mod:`geometric_kernels.feature_maps.rejection_sampling`. 

13""" 


import lab as B
from beartype.typing import Dict, Tuple

from geometric_kernels.feature_maps.base import FeatureMap
from geometric_kernels.feature_maps.probability_densities import base_density_sample
from geometric_kernels.lab_extras import complex_like, from_numpy, is_complex
from geometric_kernels.spaces import DiscreteSpectrumSpace, NoncompactSymmetricSpace


class RandomPhaseFeatureMapCompact(FeatureMap):
    r"""
    Random phase feature map for :class:`~.spaces.DiscreteSpectrumSpace`\ s for
    which the :doc:`addition theorem </theory/addition_theorem>`-like basis
    functions are explicitly available while actual eigenpairs may be implicit.

    :param space:
        A :class:`~.spaces.DiscreteSpectrumSpace` space.
    :param num_levels:
        Number of levels in the kernel approximation.
    :param num_random_phases:
        Number of random phases used in the generalized
        random phase Fourier features technique.
    """

    def __init__(
        self,
        space: DiscreteSpectrumSpace,
        num_levels: int,
        num_random_phases: int = 3000,
    ):
        self.space = space
        self.num_levels = num_levels
        self.num_random_phases = num_random_phases
        self.eigenfunctions = space.get_eigenfunctions(num_levels)

    def __call__(
        self,
        X: B.Numeric,
        params: Dict[str, B.Numeric],
        *,
        key: B.RandomState,
        normalize: bool = True,
        **kwargs,
    ) -> Tuple[B.RandomState, B.Numeric]:
        """
        :param X:
            [N, ...] points in the space to evaluate the map on.

        :param params:
            Parameters of the kernel (length scale and smoothness).

        :param key:
            Random state, either `np.random.RandomState`,
            `tf.random.Generator`, `torch.Generator` or `jax.tensor` (which
            represents a random state).

            .. note::
                For any backend other than `jax`, passing the same `key` twice
                does not guarantee that the feature map will be the same each
                time. This is because these backends' random state has... a
                state. To evaluate the same (including randomness) feature map
                on different inputs, you can either save/restore state manually
                each time or use the helper function
                :func:`~.utils.make_deterministic` which does this for you.


        :param normalize:
            Normalize to have unit average variance (`True` by default).


        :param ``**kwargs``:
            Unused.

        :return:
            `Tuple(key, features)` where `features` is an [N, O] array, N
            is the number of inputs and O is the dimension of the feature map;
            `key` is the updated random key for `jax`, or the similar random
            state (generator) for any other backends.
        """
        # Imported locally, likely to avoid a circular import with the kernels package.
        from geometric_kernels.kernels.karhunen_loeve import MaternKarhunenLoeveKernel

        key, random_phases = self.space.random(key, self.num_random_phases)  # [O, D]
        eigenvalues = self.space.get_eigenvalues(self.num_levels)

        spectrum = MaternKarhunenLoeveKernel.spectrum(
            eigenvalues,
            nu=params["nu"],
            lengthscale=params["lengthscale"],
            dimension=self.space.dimension,
        )

        if is_complex(X):
            dtype = complex_like(params["lengthscale"])
        else:
            dtype = B.dtype(params["lengthscale"])

        weights = B.power(spectrum, 0.5)  # [L, 1]

        random_phases_b = B.cast(dtype, from_numpy(X, random_phases))

        phi_product = self.eigenfunctions.phi_product(
            X, random_phases_b, **params
        )  # [N, O, L]

        embedding = B.cast(dtype, phi_product)  # [N, O, L]
        weights_t = B.cast(dtype, B.transpose(weights))

        features = B.reshape(embedding * weights_t, B.shape(X)[0], -1)  # [N, O*L]
        if is_complex(features):
            # Complex features are split into twice as many real ones.
            features = B.concat(B.real(features), B.imag(features), axis=1)

        if normalize:
            normalizer = B.sqrt(B.sum(features**2, axis=-1, squeeze=False))
            features = features / normalizer

        return key, features
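A minimal usage sketch for the compact map, under the NumPy backend. `Hypersphere` is assumed to be available from `geometric_kernels.spaces` (as in the package documentation); the points, sizes, and parameter values are illustrative only. The `params` dict follows the `"nu"`/`"lengthscale"` convention that `__call__` reads above.

```python
import numpy as np

from geometric_kernels.spaces import Hypersphere
from geometric_kernels.feature_maps.random_phase import RandomPhaseFeatureMapCompact

# The two-sphere S^2 is a DiscreteSpectrumSpace with addition-theorem eigenfunctions.
space = Hypersphere(dim=2)
feature_map = RandomPhaseFeatureMapCompact(space, num_levels=10, num_random_phases=100)

key = np.random.RandomState(0)  # NumPy-backend random state

# Five arbitrary points on S^2, i.e. unit vectors in R^3.
X = np.random.randn(5, 3)
X = X / np.linalg.norm(X, axis=1, keepdims=True)

# Matern-3/2 parameters, illustrative values.
params = {"nu": np.array([1.5]), "lengthscale": np.array([0.5])}

key, features = feature_map(X, params, key=key)  # features: [5, 100 * 10]

# Inner products of the features approximate the kernel matrix; with the
# default normalize=True, each row of `features` has unit norm.
K_approx = features @ features.T
```

For identical features across repeated calls with a non-JAX backend, the docstring's note points to :func:`~.utils.make_deterministic`, which handles the save/restore of the random state for you.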



class RandomPhaseFeatureMapNoncompact(FeatureMap):
    r"""
    Basic random phase feature map for
    :class:`~.spaces.NoncompactSymmetricSpace`\ s (importance sampling based).

    This feature map should not be used if a space-specific alternative exists.

    :param space:
        A :class:`~.spaces.NoncompactSymmetricSpace` space.
    :param num_random_phases:
        Number of random phases to use.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.
    """

    def __init__(
        self,
        space: NoncompactSymmetricSpace,
        num_random_phases: int = 3000,
        shifted_laplacian: bool = True,
    ):
        self.space = space
        self.num_random_phases = num_random_phases
        self.shifted_laplacian = shifted_laplacian

    def __call__(
        self,
        X: B.Numeric,
        params: Dict[str, B.Numeric],
        *,
        key: B.RandomState,
        normalize: bool = True,
        **kwargs,
    ) -> Tuple[B.RandomState, B.Numeric]:
        """
        :param X:
            [N, ...] points in the space to evaluate the map on.
        :param params:
            Parameters of the feature map (length scale and smoothness).
        :param key:
            Random state, either `np.random.RandomState`,
            `tf.random.Generator`, `torch.Generator` or `jax.tensor` (which
            represents a random state).

            .. note::
                For any backend other than `jax`, passing the same `key` twice
                does not guarantee that the feature map will be the same each
                time. This is because these backends' random state has... a
                state. To evaluate the same (including randomness) feature map
                on different inputs, you can either save/restore state manually
                each time or use the helper function
                :func:`~.utils.make_deterministic` which does this for you.

        :param normalize:
            Normalize to have unit average variance (`True` by default).
        :param ``**kwargs``:
            Unused.

        :return:
            `Tuple(key, features)` where `features` is an [N, O] array, N
            is the number of inputs and O is the dimension of the feature map;
            `key` is the updated random key for `jax`, or the similar random
            state (generator) for any other backends.
        """
        key, random_phases = self.space.random_phases(
            key, self.num_random_phases
        )  # [O, <axes_p>]

        key, random_lambda = base_density_sample(
            key,
            (self.num_random_phases,),  # [O, 1]
            params,
            self.space.dimension,
            self.space.rho,
            self.shifted_laplacian,
        )  # [O, P]

        random_phases_b = B.expand_dims(
            B.cast(B.dtype(params["lengthscale"]), from_numpy(X, random_phases))
        )  # [1, O, <axes_p>]
        random_lambda_b = B.expand_dims(
            B.cast(B.dtype(params["lengthscale"]), from_numpy(X, random_lambda))
        )  # [1, O, P]
        X_b = B.expand_dims(X, axis=-1 - self.space.num_axes)  # [N, 1, <axes>]

        p = self.space.power_function(random_lambda_b, X_b, random_phases_b)  # [N, O]
        c = self.space.inv_harish_chandra(random_lambda_b)  # [1, O]

        # Real and imaginary parts are stacked to keep the features real-valued.
        features = B.concat(B.real(p) * c, B.imag(p) * c, axis=-1)  # [N, 2*O]
        if normalize:
            normalizer = B.sqrt(B.sum(features**2, axis=-1, squeeze=False))
            features = features / normalizer

        return key, features
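The same pattern for the noncompact map, on the hyperbolic plane. `Hyperbolic` is assumed to be available from `geometric_kernels.spaces`, with points given in the hyperboloid model (time-like first coordinate); both that convention and the parameter values are illustrative assumptions rather than part of this module.

```python
import numpy as np

from geometric_kernels.spaces import Hyperbolic
from geometric_kernels.feature_maps.random_phase import RandomPhaseFeatureMapNoncompact

# The hyperbolic plane is a NoncompactSymmetricSpace. The basic map is used
# here for illustration; per the module docstring, prefer a specialized
# per-space map (e.g. from rejection_sampling) when one exists.
space = Hyperbolic(dim=2)
feature_map = RandomPhaseFeatureMapNoncompact(
    space,
    num_random_phases=100,
    shifted_laplacian=True,  # default; kernels defined via the shifted Laplacian
)

key = np.random.RandomState(0)

# Five points in the hyperboloid model: x_0 = sqrt(1 + x_1^2 + x_2^2).
xy = np.random.randn(5, 2)
x0 = np.sqrt(1.0 + np.sum(xy**2, axis=1, keepdims=True))
X = np.concatenate([x0, xy], axis=1)  # [5, 3]

params = {"nu": np.array([1.5]), "lengthscale": np.array([1.0])}

# Real and imaginary parts of the power function are stacked, so the feature
# dimension is 2 * num_random_phases, matching the [N, 2*O] shape comment above.
key, features = feature_map(X, params, key=key)  # features: [5, 200]
```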