Coverage for geometric_kernels/sampling/samplers.py: 92%

25 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2Samplers. 

3""" 

4 

5from __future__ import annotations # By https://stackoverflow.com/a/62136491 

6 

7from functools import partial 

8 

9import lab as B 

10from beartype.typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple 

11 

12# By https://stackoverflow.com/a/62136491 

13if TYPE_CHECKING: 

14 from geometric_kernels.feature_maps import FeatureMap 

15 

16 

def sample_at(
    feature_map: FeatureMap,
    s: int,
    X: B.Numeric,
    params: Dict[str, B.Numeric],
    key: Optional[B.RandomState] = None,
    normalize: Optional[bool] = None,
    hodge_type: Optional[str] = None,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Given a `feature_map` $\phi_{\nu, \kappa}: X \to \mathbb{R}^n$, where
    $\nu, \kappa$ are determined by `params["nu"]` and `params["lengthscale"]`,
    respectively, compute `s` samples of the Gaussian process with kernel

    .. math :: k_{\nu, \kappa}(x, y) = \langle \phi_{\nu, \kappa}(x), \phi_{\nu, \kappa}(y) \rangle_{\mathbb{R}^n}

    at input locations `X` and using the random state `key`.

    Generating a sample from $GP(0, k_{\nu, \kappa})$ is as simple as computing

    .. math :: \sum_{j=1}^n w_j \cdot (\phi_{\nu, \kappa}(x))_j \qquad w_j \stackrel{IID}{\sim} N(0, 1).

    .. note::
        Fixing $w_j$, and treating $x \to (\phi_{\nu, \kappa}(x))_j$ as basis
        functions, while letting $x$ vary, you get an actual *function* as a
        sample, meaning something that can be evaluated at any $x \in X$.
        The way to fix $w_j$ in code is to apply the
        :func:`~.make_deterministic` utility function.

    :param feature_map:
        The feature map $\phi_{\nu, \kappa}$ that defines the Gaussian process
        $GP(0, k_{\nu, \kappa})$ to sample from.
    :param s:
        The number of samples to generate.
    :param X:
        An [N, <axis>]-shaped array containing N elements of the space
        `feature_map` is defined on. <axis> is the shape of these elements.
        These are the points to evaluate the samples at.
    :param params:
        Parameters of the kernel (length scale and smoothness). The dtype of
        the random weights is taken from `params["lengthscale"]`. If that key
        is absent, it is taken from `params["gradient"]["lengthscale"]`
        instead. This is a nested parameter layout, presumably the one used
        by Hodge-compositional kernels.
    :param key:
        Random state: either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator`, or a JAX PRNG key array.

        If None, `B.global_random_state(B.dtype(X))` is used. Defaults to None.
    :param normalize:
        Passed down to `feature_map` directly. Controls whether to force the
        average variance of the Gaussian process to be around 1 or not. If None,
        follows the standard behavior, typically same as normalize=True.

        Defaults to None.
    :param hodge_type:
        The type of Hodge component to sample. Only used when using
        Hodge-compositional edge kernels. It is forwarded to `feature_map`
        as a keyword argument only when it is not None.

        Defaults to None.

    :return:
        A tuple `(key, samples)`:

        - `key` is the random state after the random weights have been drawn.
        - `samples` is an [N, s]-shaped array containing s samples of
          $GP(0, k_{\nu, \kappa})$ evaluated at X.
    """

    if key is None:
        key = B.global_random_state(B.dtype(X))

    # Forward `hodge_type` only when it is set. This way, feature maps whose
    # call signature does not accept that keyword are never handed it.
    extra_kwargs = {} if hodge_type is None else {"hodge_type": hodge_type}
    new_key, features = feature_map(
        X, params, key=key, normalize=normalize, **extra_kwargs
    )  # [N, M]

    # The feature map may hand back an updated random state or None. In the
    # latter case, keep using the incoming key.
    if new_key is not None:
        key = new_key

    num_features = B.shape(features)[-1]

    # Draw the weights in the lengthscale's dtype. Some parameter dicts
    # (presumably Hodge-compositional ones) nest it under "gradient".
    if "lengthscale" in params:
        dtype = B.dtype(params["lengthscale"])
    else:
        dtype = B.dtype(params["gradient"]["lengthscale"])

    key, random_weights = B.randn(key, dtype, num_features, s)  # [M, s]

    random_sample = B.matmul(features, random_weights)  # [N, s]

    return key, random_sample

103 

104 

def sampler(
    feature_map: FeatureMap, s: Optional[int] = 1, **kwargs
) -> Callable[[Any], Any]:
    """
    Build a version of :func:`sample_at` with `feature_map`, `s` and the
    keyword arguments in ``**kwargs`` fixed. `X`, `params` and any remaining
    keyword arguments are left free to vary.

    :param feature_map:
        The feature map to fix.
    :param s:
        The number of samples parameter to fix.

        Defaults to 1.
    :param ``**kwargs``:
        Keyword arguments to fix.

    :return:
        The version of :func:`sample_at` with parameters `feature_map` and `s`
        fixed, together with the keyword arguments in ``**kwargs``. Its
        ``__name__`` is that of :func:`sample_at`. Its ``__doc__`` lists the
        fixed arguments, followed by the docstring of :func:`sample_at`.
    """
    base_name = sample_at.__name__
    fixed_sampler = partial(sample_at, feature_map, s, **kwargs)

    # `partial` objects carry a `__dict__`, so both attributes can be set.
    fixed_sampler.__name__ = base_name  # type: ignore[attr-defined]
    fixed_sampler.__doc__ = f"""
    This is a version of the `{base_name}` function with
    - feature_map={feature_map},
    - s={s},
    - and additional keyword arguments {kwargs}.

    The original docstring follows.

    {sample_at.__doc__}
    """

    return fixed_sampler