Coverage for geometric_kernels/utils/special_functions.py: 86%

29 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2Special mathematical functions used in the library. 

3""" 

4 

5import lab as B 

6from beartype.typing import List, Optional 

7 

8from geometric_kernels.lab_extras import ( 

9 count_nonzero, 

10 from_numpy, 

11 int_like, 

12 take_along_axis, 

13) 

14from geometric_kernels.utils.utils import _check_matrix 

15 

16 

def walsh_function(d: int, combination: List[int], x: B.Bool) -> B.Float:
    r"""
    Evaluate a Walsh function on a batch of binary vectors.

    The Walsh function indexed by a subset $T$ of $\{0, .., d-1\}$ is

    .. math:: w_T(x_0, .., x_{d-1}) = (-1)^{\sum_{j \in T} x_j}

    where $d$ is `d`, $T$ is `combination`, and $x = (x_0, .., x_{d-1})$ is `x`.

    :param d:
        The degree of the Walsh function and the dimension of its inputs.
        An integer d > 0.
    :param combination:
        A subset of the set $\{0, .., d-1\}$ determining the particular Walsh
        function. A list of integers.
    :param x:
        A batch of binary vectors $x = (x_0, .., x_{d-1})$ of shape [N, d].

    :return:
        The value of the Walsh function $w_T(x)$ evaluated for every $x$ in the
        batch. An array of shape [N].

    :raises ValueError:
        If the last dimension of `x` is not equal to `d`.
    """
    _check_matrix(x, "x")

    ambient_dim = B.shape(x)[-1]
    if ambient_dim != d:
        raise ValueError("`x` must live in `d`-dimensional space.")

    # Move the subset T onto the same backend as `x`, as an integer index
    # array with a leading singleton axis so it can be applied to every row.
    index_dtype = int_like(x)
    subset = B.cast(index_dtype, from_numpy(x, combination))
    subset = subset[None, :]

    # Gather the coordinates x_j with j in T and count how many are set; the
    # parity of that count determines the sign.
    selected_bits = take_along_axis(x, subset, axis=-1)
    num_ones = count_nonzero(selected_bits, axis=-1)

    return (-1) ** num_ones

46 

47 

def _kravchuk_normalized_bottom_up(d: int, j: int, m: B.Float) -> B.Float:
    """
    Compute $G_{d, j, m}/G_{d, j, 0}$ by running the three term recurrence
    upwards from orders 0 and 1, in time linear in `j`.

    Expects `m` already cast to a floating point type and 0 <= j <= d.
    """
    older = B.ones(m)  # order 0
    if j == 0:
        return older
    newer = 1 - 2 * m / d  # order 1
    for k in range(2, j + 1):
        # Same arithmetic form as the single recurrence step in
        # `kravchuk_normalized`, so results match the recursive definition.
        rhs_1 = (d - 2 * m) * newer
        rhs_2 = -(k - 1) * older
        older, newer = newer, (rhs_1 + rhs_2) / (d - k + 1)
    return newer


def kravchuk_normalized(
    d: int,
    j: int,
    m: B.Int,
    kravchuk_normalized_j_minus_1: Optional[B.Float] = None,
    kravchuk_normalized_j_minus_2: Optional[B.Float] = None,
) -> B.Float:
    r"""
    This function returns $G_{d, j, m}/G_{d, j, 0}$ where $G_{d, j, m}$ is the
    Kravchuk polynomial defined below.

    Define the Kravchuk polynomial of degree d > 0 and order 0 <= j <= d as the
    function $G_{d, j, m}$ of the independent variable 0 <= m <= d given by

    .. math:: G_{d, j, m} = \sum_{T \subseteq \{0, .., d-1\}, |T| = j} w_T(x).

    Here $w_T$ are the Walsh functions on the hypercube graph $C^d$ and
    $x \in C^d$ is an arbitrary binary vector with $m$ ones (the right-hand side
    does not depend on the choice of a particular vector of the kind).

    .. note::
        We are using the three term recurrence relation to compute the Kravchuk
        polynomials. Cf. Equation (60) of Chapter 5 in MacWilliams and Sloane "The
        Theory of Error-Correcting Codes", 1977. The parameters q and $\gamma$
        from :cite:t:`macwilliams1977` are set to be q = 2; $\gamma = q - 1 = 1$.

    .. note::
        We use the fact that $G_{d, j, 0} = \binom{d}{j}$.

    .. note::
        Any lower-order value that is not supplied by the caller is computed
        bottom-up from orders 0 and 1, which takes a number of steps linear
        in `j`.

    :param d:
        The degree of Kravchuk polynomial, an integer d > 0.
        Maps to n in :cite:t:`macwilliams1977`.
    :param j:
        The order of Kravchuk polynomial, an integer 0 <= j <= d.
        Maps to k in :cite:t:`macwilliams1977`.
    :param m:
        The independent variable, an integer 0 <= m <= d. The bounds are
        checked elementwise, so an array of such integers is accepted too.
        Maps to x in :cite:t:`macwilliams1977`.
    :param kravchuk_normalized_j_minus_1:
        The optional precomputed value of $G_{d, j-1, m}/G_{d, j-1, 0}$, saves
        recomputing it when several consecutive orders are evaluated.
    :param kravchuk_normalized_j_minus_2:
        The optional precomputed value of $G_{d, j-2, m}/G_{d, j-2, 0}$, saves
        recomputing it when several consecutive orders are evaluated.

    :return:
        $G_{d, j, m}/G_{d, j, 0}$ where $G_{d, j, m}$ is the Kravchuk polynomial.

    :raises ValueError:
        If `d` is not positive, or if `j` or `m` lies outside of [0, d].
    """
    if d <= 0:
        raise ValueError("`d` must be positive.")
    if not (0 <= j and j <= d):
        raise ValueError("`j` must lie in the interval [0, d].")
    if not (B.all(0 <= m) and B.all(m <= d)):
        raise ValueError("`m` must lie in the interval [0, d].")

    # Work in floating point: the normalized polynomial is not integer-valued.
    m = B.cast(B.dtype_float(m), m)

    if j == 0:
        return B.ones(m)
    if j == 1:
        return 1 - 2 * m / d

    # Fill in whichever predecessors the caller did not provide. This is done
    # iteratively rather than by calling this function recursively, since two
    # recursive calls per level would take time exponential in `j`.
    if kravchuk_normalized_j_minus_1 is None:
        kravchuk_normalized_j_minus_1 = _kravchuk_normalized_bottom_up(d, j - 1, m)
    if kravchuk_normalized_j_minus_2 is None:
        kravchuk_normalized_j_minus_2 = _kravchuk_normalized_bottom_up(d, j - 2, m)

    # One step of the normalized three term recurrence:
    # (d - j + 1) p_j = (d - 2m) p_{j-1} - (j - 1) p_{j-2}.
    rhs_1 = (d - 2 * m) * kravchuk_normalized_j_minus_1
    rhs_2 = -(j - 1) * kravchuk_normalized_j_minus_2
    return (rhs_1 + rhs_2) / (d - j + 1)

116 return (rhs_1 + rhs_2) / (d - j + 1)