Coverage for tests/kernels/test_feature_map_kernel.py: 100%

51 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import lab as B 

2import numpy as np 

3import pytest 

4 

5from geometric_kernels.kernels import MaternFeatureMapKernel 

6from geometric_kernels.kernels.matern_kernel import default_feature_map, default_num 

7 

8from ..helper import ( 

9 check_function_with_backend, 

10 create_random_state, 

11 noncompact_symmetric_spaces, 

12) 

13 

14 

@pytest.fixture(
    params=noncompact_symmetric_spaces(),
    ids=str,
    scope="module",
)
def inputs(request):
    """
    Build the shared inputs for one space under test.

    Returns a tuple ``(space, num_features, feature_map, X, X2)`` where:
    - ``space`` is ``request.param``;
    - ``num_features`` is the smaller of ``default_num(space)`` and 15;
    - ``feature_map`` is ``default_feature_map(space=space, num=num_features)``;
    - ``X`` and ``X2`` are random samples from ``space``, each of a size drawn
      from 1..100 (inclusive).
    """
    space = request.param
    n_feat = min(default_num(space), 15)
    fmap = default_feature_map(space=space, num=n_feat)

    # Fixed seed so that sample sizes and points are reproducible across runs.
    rng = np.random.RandomState(0)
    # ``randint``'s upper bound is exclusive, hence ``100 + 1``.
    size_x, size_x2 = rng.randint(low=1, high=100 + 1, size=2)
    rng, X = space.random(rng, size_x)
    rng, X2 = space.random(rng, size_x2)

    return space, n_feat, fmap, X, X2

39 

40 

@pytest.fixture
def normalize():
    """
    Default value of the ``normalize`` flag consumed by the ``kernel`` fixture.

    Tests decorated with ``@pytest.mark.parametrize("normalize", ...)`` override
    this fixture with the parametrized value; every other test gets ``True``.
    """
    return True


@pytest.fixture
def kernel(inputs, backend, normalize):
    """
    Returns a ``MaternFeatureMapKernel`` on the space from ``inputs``, using the
    feature map from ``inputs``, a random state created for ``backend``, and
    the requested ``normalize`` flag.
    """
    # NOTE: ``normalize`` is requested as a fixture instead of being given a
    # default value in the signature. pytest ignores fixture arguments that
    # have defaults, so ``normalize=True`` there would silently discard the
    # test-level parametrization and always build a normalized kernel.
    space, _, feature_map, _, _ = inputs

    key = create_random_state(backend)

    return MaternFeatureMapKernel(space, feature_map, key, normalize=normalize)

48 

49 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_params(inputs, backend, kernel):
    """
    ``kernel.init_params()`` must provide ``lengthscale`` and ``nu``, each with
    shape ``(1,)``.
    """
    initial = kernel.init_params()

    for name in ("lengthscale", "nu"):
        assert name in initial
        assert initial[name].shape == (1,)

58 

59 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
@pytest.mark.parametrize("normalize", [True, False], ids=["normalize", "no_normalize"])
def test_K(inputs, backend, normalize, kernel):
    """
    ``kernel.K(params, X, X2)`` must run and return a tensor of the right
    backend with shape ``(len(X), len(X2))``.
    """
    *_, X, X2 = inputs
    params = kernel.init_params()
    expected_shape = (X.shape[0], X2.shape[0])

    def has_expected_shape(res, f_out):
        return B.shape(f_out) == res

    check_function_with_backend(
        backend,
        expected_shape,
        kernel.K,
        params,
        X,
        X2,
        compare_to_result=has_expected_shape,
    )

76 

77 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
@pytest.mark.parametrize("normalize", [True, False], ids=["normalize", "no_normalize"])
def test_K_one_param(inputs, backend, normalize, kernel):
    """
    The single-argument form ``kernel.K(params, X)`` must agree with
    ``kernel.K(params, X, X)``.
    """
    X = inputs[3]
    params = kernel.init_params()
    n = X.shape[0]

    def difference(params, X):
        return kernel.K(params, X) - kernel.K(params, X, X)

    check_function_with_backend(backend, np.zeros((n, n)), difference, params, X)

92 

93 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
@pytest.mark.parametrize("normalize", [True, False], ids=["normalize", "no_normalize"])
def test_K_diag(inputs, backend, normalize, kernel):
    """
    ``kernel.K_diag(params, X)`` must equal the diagonal of
    ``kernel.K(params, X)``.
    """
    X = inputs[3]
    params = kernel.init_params()

    def diag_mismatch(params, X):
        diagonal = kernel.K_diag(params, X)
        full = kernel.K(params, X)
        return diagonal - B.diag(full)

    check_function_with_backend(
        backend, np.zeros((X.shape[0],)), diag_mismatch, params, X
    )

108 

109 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_normalize(inputs, backend, kernel):
    """
    With the fixture's default (normalized) kernel, the variance
    ``kernel.K_diag(params, X)`` must be identically 1.
    """
    X = inputs[3]
    params = kernel.init_params()
    unit_variance = np.ones((X.shape[0],))

    check_function_with_backend(backend, unit_variance, kernel.K_diag, params, X)