Coverage for tests / utils / test_special_functions.py: 100%

43 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-15 13:41 +0000

1from math import comb 

2 

3import lab as B 

4import numpy as np 

5import pytest 

6 

7from geometric_kernels.utils.special_functions import ( 

8 generalized_kravchuk_normalized, 

9 walsh_function, 

10) 

11from geometric_kernels.utils.utils import binary_vectors_and_subsets, hamming_distance 

12 

13from ..helper import check_function_with_backend 

14 

15 

@pytest.fixture(params=[1, 2, 3, 5, 10])
def all_xs_and_combs(request):
    """
    Provide the hypercube data shared by the tests in this module.

    Returns a tuple ``(d, X, combs)`` where:
    - ``d`` is the dimension, taken from ``request.param``;
    - ``X`` is a ``2**d x d`` boolean matrix listing every binary vector of
      length ``d``;
    - ``combs`` lists the index subsets of ``{0, ..., d - 1}`` used to index
      the Walsh functions (one entry per Walsh function).
    """
    dim = request.param
    vectors, subsets = binary_vectors_and_subsets(dim)
    return dim, vectors, subsets

29 

30 

def walsh_matrix(d, combs, X):
    """
    Stack the Walsh functions indexed by ``combs``, evaluated at the rows of ``X``.

    Row ``i`` of the result is ``walsh_function(d, combs[i], X)``.
    """
    rows = [walsh_function(d, subset, X) for subset in combs]
    return B.stack(*rows)

33 

34 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_walsh_functions(all_xs_and_combs, backend):
    """
    Walsh functions on the d-dimensional hypercube are orthogonal and +-1-valued.

    With ``W`` the ``2**d x 2**d`` Walsh matrix, ``W @ W.T`` must equal
    ``2**d * I`` and every entry of ``W`` must have absolute value 1.
    """
    d, X, combs = all_xs_and_combs
    n = 2**d

    def gram(X):
        # Evaluate the Walsh matrix once and reuse it for both factors.
        W = walsh_matrix(d, combs, X)
        return B.matmul(W, B.T(W))

    # Check that Walsh functions are orthogonal.
    check_function_with_backend(
        backend,
        n * np.eye(n),
        gram,
        X,
    )

    # Check that Walsh functions only take values in the set {-1, 1}.
    check_function_with_backend(
        backend,
        np.ones((n, n)),
        lambda X: B.abs(walsh_matrix(d, combs, X)),
        X,
    )

54 

55 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_kravchuk_polynomials(all_xs_and_combs, backend):
    """
    Kravchuk polynomials coincide with sums of groups of Walsh functions.

    For each ``j`` in ``0..d``, ``comb(d, j) * generalized_kravchuk_normalized``
    evaluated at the Hamming distance from the origin ``x0`` must equal the
    column-wise sum of the next ``comb(d, j)`` rows of the Walsh matrix.
    """
    d, X, combs = all_xs_and_combs

    x0 = np.zeros((1, d), dtype=bool)

    # The Walsh matrix does not depend on j, so evaluate it once.
    # NOTE: the slicing below assumes combs is grouped by subset size in
    # increasing order (comb(d, j) subsets of size j per group), as returned
    # by binary_vectors_and_subsets -- confirm if that helper changes.
    walsh = walsh_matrix(d, combs, X)

    cur_ind = 0
    for j in range(d + 1):
        num_walsh = comb(d, j)

        result = np.sum(
            walsh[cur_ind : cur_ind + num_walsh, :],
            axis=0,
            keepdims=True,
        )

        # krav closes over j and num_walsh; this is safe because it is
        # consumed by check_function_with_backend within this iteration.
        def krav(x0, X):
            return num_walsh * generalized_kravchuk_normalized(
                d, j, hamming_distance(x0, X), 2
            )

        # Checks that Kravchuk polynomials coincide with certain sums of
        # the Walsh functions.
        check_function_with_backend(
            backend,
            result,
            krav,
            x0,
            X,
        )

        cur_ind += num_walsh

88 

89 

@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_kravchuk_precomputed(all_xs_and_combs, backend):
    """
    Supplying precomputed lower-order values leaves Kravchuk values unchanged.

    ``generalized_kravchuk_normalized`` optionally accepts the values for
    ``j - 1`` and ``j - 2``; the result obtained with them must match the
    value computed from scratch. For ``j = 0`` (and ``j = 1`` for the
    ``j - 2`` slot) no earlier value exists, so ``None`` is passed.
    """
    d, X, _ = all_xs_and_combs

    x0 = np.zeros((1, d), dtype=bool)
    # The reference values all use the same distances, so compute them once.
    distances = hamming_distance(x0, X)

    kravchuk_normalized_j_minus_1, kravchuk_normalized_j_minus_2 = None, None
    for j in range(d + 1):
        # Reference value computed without any precomputed inputs.
        cur_kravchuk_normalized = generalized_kravchuk_normalized(
            d, j, distances, 2
        )

        # krav closes over j; safe because it is consumed in this iteration.
        def krav(x0, X, kn1, kn2):
            return generalized_kravchuk_normalized(
                d, j, hamming_distance(x0, X), 2, kn1, kn2
            )

        # Checks that passing the precomputed j - 1 and j - 2 values
        # reproduces the directly computed Kravchuk polynomial.
        check_function_with_backend(
            backend,
            cur_kravchuk_normalized,
            krav,
            x0,
            X,
            kravchuk_normalized_j_minus_1,
            kravchuk_normalized_j_minus_2,
        )

        kravchuk_normalized_j_minus_2 = kravchuk_normalized_j_minus_1
        kravchuk_normalized_j_minus_1 = cur_kravchuk_normalized