Coverage for tests/sampling/test_samplers.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import lab as B 

2import numpy as np 

3import pytest 

4 

5from geometric_kernels.sampling import sampler 

6 

7from ..feature_maps.test_feature_maps import feature_map_and_friends # noqa: F401 

8from ..helper import check_function_with_backend, create_random_state 

9 

# Number of sample paths requested from the sampler in this module's tests.
_NUM_SAMPLES = 2


@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_output_shape_and_backend(backend, feature_map_and_friends):
    """Check that sampled paths come back in the right backend and shape.

    A sampler is built from the fixture's feature map, evaluated at 50
    points drawn from the fixture's space with a fixed NumPy seed, and the
    result is expected to have shape ``(number of points, _NUM_SAMPLES)``.

    :param backend:
        Name of the backend under test, supplied by the parametrization.
    :param feature_map_and_friends:
        Fixture unpacked as ``(feature_map, kernel, space)``.
    """
    feature_map, kernel, space = feature_map_and_friends

    kernel_params = kernel.init_params()
    draw_paths = sampler(feature_map, s=_NUM_SAMPLES)

    # Fixed seed so the evaluation points are reproducible; the state
    # returned alongside the points is not needed afterwards.
    _, points = space.random(np.random.RandomState(0), 50)

    expected_shape = (points.shape[0], _NUM_SAMPLES)

    def draw(params, X):
        # Only the second element of the sampler's return value is checked
        # (the first is presumably the updated random state -- confirm
        # against the sampler's documentation).
        rng = create_random_state(backend)
        outputs = draw_paths(X, params, key=rng)
        return outputs[1]

    def shape_matches(res, f_out):
        return B.shape(f_out) == res

    # Check that the sampler runs and the output is a tensor of the right
    # backend and shape.
    check_function_with_backend(
        backend,
        expected_shape,
        draw,
        kernel_params,
        points,
        compare_to_result=shape_matches,
    )