Coverage for tests/utils/test_special_functions.py: 100%

43 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1from math import comb 

2 

3import lab as B 

4import numpy as np 

5import pytest 

6 

7from geometric_kernels.utils.special_functions import ( 

8 kravchuk_normalized, 

9 walsh_function, 

10) 

11from geometric_kernels.utils.utils import binary_vectors_and_subsets, hamming_distance 

12 

13from ..helper import check_function_with_backend 

14 

15 

@pytest.fixture(params=[1, 2, 3, 5, 10])
def all_xs_and_combs(request):
    """
    Provide every binary vector of length ``d`` plus the index subsets that
    label the Walsh functions, for several dimensions ``d``.

    Returns a tuple (d, X, combs) where:
    - d is the integer dimension, equal to ``request.param``,
    - X is a 2**d x d boolean matrix listing all binary vectors of length d,
    - combs is a list of all possible combinations of the d coordinate indices,
      as returned by ``binary_vectors_and_subsets``.
    """
    dim = request.param
    vectors, subsets = binary_vectors_and_subsets(dim)
    return dim, vectors, subsets

29 

30 

def walsh_matrix(d, combs, X):
    """
    Evaluate the Walsh functions indexed by ``combs`` on the inputs ``X``.

    Row ``i`` of the result is ``walsh_function(d, combs[i], X)``. The rows are
    combined with ``B.stack``, so the leading axis follows the order of
    ``combs``.
    """
    rows = [walsh_function(d, subset, X) for subset in combs]
    return B.stack(*rows)

33 

34 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_walsh_functions(all_xs_and_combs, backend):
    """
    Check basic properties of the Walsh functions on all binary vectors.

    With ``W = walsh_matrix(d, combs, X)`` (rows indexed by ``combs``):
    - ``W @ W.T == 2**d * I``: distinct Walsh functions are orthogonal and
      each has squared norm ``2**d``;
    - ``|W| == 1`` entrywise: Walsh functions only take values in {-1, 1}.
    """
    d, X, combs = all_xs_and_combs

    def gram(X):
        # Build the Walsh matrix once and reuse it for both matmul factors,
        # instead of evaluating every Walsh function twice.
        W = walsh_matrix(d, combs, X)
        return B.matmul(W, B.T(W))

    # Check that Walsh functions are orthogonal.
    check_function_with_backend(
        backend,
        2**d * np.eye(2**d),
        gram,
        X,
    )

    # Check that Walsh functions only take values in the set {-1, 1}.
    check_function_with_backend(
        backend,
        np.ones((2**d, 2**d)),
        lambda X: B.abs(walsh_matrix(d, combs, X)),
        X,
    )

54 

55 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_kravchuk_polynomials(all_xs_and_combs, backend):
    """
    Check that ``comb(d, j) * kravchuk_normalized(d, j, dist)`` matches a sum of
    Walsh functions, where ``dist`` is the Hamming distance from the zero vector.

    For each ``j`` in ``0..d``, the Walsh functions summed are the next
    ``comb(d, j)`` consecutive entries of ``combs``. This assumes ``combs`` is
    grouped by subset size (size 0 first, then size 1, ...). TODO confirm
    against ``binary_vectors_and_subsets``.
    """
    d, X, combs = all_xs_and_combs

    x0 = np.zeros((1, d), dtype=bool)

    # The Walsh matrix does not depend on j; compute it once outside the loop.
    walsh = walsh_matrix(d, combs, X)

    cur_ind = 0
    for j in range(d + 1):
        num_walsh = comb(d, j)

        # Sum the Walsh functions of the current block of comb(d, j) subsets.
        result = np.sum(
            walsh[cur_ind : cur_ind + num_walsh, :],
            axis=0,
            keepdims=True,
        )

        # `krav` is consumed immediately below, so closing over the loop
        # variables `j` and `num_walsh` is safe.
        def krav(x0, X):
            return num_walsh * kravchuk_normalized(d, j, hamming_distance(x0, X))

        # Checks that Kravchuk polynomials coincide with certain sums of
        # the Walsh functions.
        check_function_with_backend(
            backend,
            result,
            krav,
            x0,
            X,
        )

        cur_ind += num_walsh

86 

87 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_kravchuk_precomputed(all_xs_and_combs, backend):
    """
    Check that ``kravchuk_normalized`` returns the same values whether or not
    the previously computed degree ``j - 1`` and ``j - 2`` values are supplied.

    For each degree ``j`` the reference is computed without the optional
    arguments. The function is then called again with the values kept from
    earlier iterations (``None`` where no earlier degree exists, i.e. both for
    ``j == 0`` and the ``j - 2`` slot for ``j == 1``).
    """
    d, X, _ = all_xs_and_combs

    x0 = np.zeros((1, d), dtype=bool)

    kravchuk_normalized_j_minus_1, kravchuk_normalized_j_minus_2 = None, None
    for j in range(d + 1):
        # Reference value for degree j, computed from scratch.
        cur_kravchuk_normalized = kravchuk_normalized(d, j, hamming_distance(x0, X))

        # `krav` is consumed immediately below, so closing over `j` is safe.
        def krav(x0, X, kn1, kn2):
            return kravchuk_normalized(d, j, hamming_distance(x0, X), kn1, kn2)

        # Checks that passing the precomputed lower-degree values does not
        # change the result.
        check_function_with_backend(
            backend,
            cur_kravchuk_normalized,
            krav,
            x0,
            X,
            kravchuk_normalized_j_minus_1,
            kravchuk_normalized_j_minus_2,
        )

        # Shift the window: degree j becomes "j - 1" and degree j - 1 becomes
        # "j - 2" for the next iteration.
        kravchuk_normalized_j_minus_2 = kravchuk_normalized_j_minus_1
        kravchuk_normalized_j_minus_1 = cur_kravchuk_normalized