Coverage for tests/sampling/test_samplers.py: 100%
17 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1import lab as B
2import numpy as np
3import pytest
5from geometric_kernels.sampling import sampler
7from ..feature_maps.test_feature_maps import feature_map_and_friends # noqa: F401
8from ..helper import check_function_with_backend, create_random_state
10_NUM_SAMPLES = 2  # Number of samples requested from `sampler` (passed as `s=`); also the expected size of the output's second axis.
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_output_shape_and_backend(backend, feature_map_and_friends):
    """Check that sampling runs for `backend` and yields the expected shape.

    The sampler built from the fixture's feature map is evaluated on 50
    random points of the fixture's space. The result must have shape
    ``(number_of_points, _NUM_SAMPLES)``; backend handling is delegated to
    ``check_function_with_backend``.
    """
    feature_map, kernel, space = feature_map_and_friends

    initial_params = kernel.init_params()
    draw_paths = sampler(feature_map, s=_NUM_SAMPLES)

    # Fixed NumPy seed for the inputs; the random state handed back by
    # `space.random` is not needed afterwards.
    _, inputs = space.random(np.random.RandomState(0), 50)

    expected_shape = (inputs.shape[0], _NUM_SAMPLES)

    def second_output(params, X):
        # Only the second element of the sampler's return value is checked
        # (presumably the samples themselves -- confirm against `sampler`).
        backend_key = create_random_state(backend)
        outputs = draw_paths(X, params, key=backend_key)
        return outputs[1]

    def shape_matches(res, f_out):
        return B.shape(f_out) == res

    # Check that the sampler runs and the output is a tensor of the right
    # backend and shape.
    check_function_with_backend(
        backend,
        expected_shape,
        second_output,
        initial_params,
        inputs,
        compare_to_result=shape_matches,
    )