Coverage for geometric_kernels/utils/special_functions.py: 86%
29 statements
coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1"""
2Special mathematical functions used in the library.
3"""
5import lab as B
6from beartype.typing import List, Optional
8from geometric_kernels.lab_extras import (
9 count_nonzero,
10 from_numpy,
11 int_like,
12 take_along_axis,
13)
14from geometric_kernels.utils.utils import _check_matrix


def walsh_function(d: int, combination: List[int], x: B.Bool) -> B.Float:
    r"""
    This function returns the value of the Walsh function

    .. math:: w_T(x_0, .., x_{d-1}) = (-1)^{\sum_{j \in T} x_j}

    where $d$ is `d`, $T$ is `combination`, and $x = (x_0, .., x_{d-1})$ is `x`.

    :param d:
        The degree of the Walsh function and the dimension of its inputs.
        An integer d > 0.
    :param combination:
        A subset of the set $\{0, .., d-1\}$ determining the particular Walsh
        function. A list of integers.
    :param x:
        A batch of binary vectors $x = (x_0, .., x_{d-1})$ of shape [N, d].

    :return:
        The value of the Walsh function $w_T(x)$ evaluated for every $x$ in the
        batch. An array of shape [N].
    """
    _check_matrix(x, "x")
    if B.shape(x)[-1] != d:
        raise ValueError("`x` must live in `d`-dimensional space.")

    indices = B.cast(int_like(x), from_numpy(x, combination))[None, :]

    return (-1) ** count_nonzero(take_along_axis(x, indices, axis=-1), axis=-1)
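
# A minimal usage sketch (illustration only, not part of the library): evaluating
# the Walsh function with T = {0, 2} on a small batch of binary vectors. The
# `numpy` import and the concrete inputs are assumptions made for this example.
#
#   import numpy as np
#
#   x = np.array([[0, 1, 0], [1, 0, 0]])
#   walsh_function(3, [0, 2], x)
#   # x_0 + x_2 equals 0 for the first row and 1 for the second,
#   # so the expected values are [1, -1].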


def kravchuk_normalized(
    d: int,
    j: int,
    m: B.Int,
    kravchuk_normalized_j_minus_1: Optional[B.Float] = None,
    kravchuk_normalized_j_minus_2: Optional[B.Float] = None,
) -> B.Float:
    r"""
    This function returns $G_{d, j, m}/G_{d, j, 0}$ where $G_{d, j, m}$ is the
    Kravchuk polynomial defined below.

    Define the Kravchuk polynomial of degree d > 0 and order 0 <= j <= d as the
    function $G_{d, j, m}$ of the independent variable 0 <= m <= d given by

    .. math:: G_{d, j, m} = \sum_{T \subseteq \{0, .., d-1\}, |T| = j} w_T(x).

    Here $w_T$ are the Walsh functions on the hypercube graph $C^d$ and
    $x \in C^d$ is an arbitrary binary vector with $m$ ones (the right-hand side
    does not depend on the choice of a particular vector of this kind).

    .. note::
        We use the three-term recurrence relation to compute the Kravchuk
        polynomials. Cf. Equation (60) of Chapter 5 in MacWilliams and Sloane
        "The Theory of Error-Correcting Codes", 1977. The parameters q and
        $\gamma$ from :cite:t:`macwilliams1977` are set to q = 2 and
        $\gamma = q - 1 = 1$.

    .. note::
        We use the fact that $G_{d, j, 0} = \binom{d}{j}$.

    :param d:
        The degree of the Kravchuk polynomial, an integer d > 0.
        Maps to n in :cite:t:`macwilliams1977`.
    :param j:
        The order of the Kravchuk polynomial, an integer 0 <= j <= d.
        Maps to k in :cite:t:`macwilliams1977`.
    :param m:
        The independent variable, an integer 0 <= m <= d.
        Maps to x in :cite:t:`macwilliams1977`.
    :param kravchuk_normalized_j_minus_1:
        The optional precomputed value of $G_{d, j-1, m}/G_{d, j-1, 0}$, which
        helps to avoid exponential complexity growth due to the recursion.
    :param kravchuk_normalized_j_minus_2:
        The optional precomputed value of $G_{d, j-2, m}/G_{d, j-2, 0}$, which
        helps to avoid exponential complexity growth due to the recursion.

    :return:
        $G_{d, j, m}/G_{d, j, 0}$, where $G_{d, j, m}$ is the Kravchuk polynomial.
    """
    if d <= 0:
        raise ValueError("`d` must be positive.")
    if not (0 <= j and j <= d):
        raise ValueError("`j` must lie in the interval [0, d].")
    if not (B.all(0 <= m) and B.all(m <= d)):
        raise ValueError("`m` must lie in the interval [0, d].")

    m = B.cast(B.dtype_float(m), m)

    if j == 0:
        return B.ones(m)
    elif j == 1:
        return 1 - 2 * m / d
    else:
        if kravchuk_normalized_j_minus_1 is None:
            kravchuk_normalized_j_minus_1 = kravchuk_normalized(d, j - 1, m)
        if kravchuk_normalized_j_minus_2 is None:
            kravchuk_normalized_j_minus_2 = kravchuk_normalized(d, j - 2, m)
        rhs_1 = (d - 2 * m) * kravchuk_normalized_j_minus_1
        rhs_2 = -(j - 1) * kravchuk_normalized_j_minus_2
        return (rhs_1 + rhs_2) / (d - j + 1)
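
# A minimal usage sketch (illustration only, not part of the library): computing
# $G_{d, j, m}/G_{d, j, 0}$ for all orders j = 0, .., d by feeding the two
# previously computed values back into the recursion, which keeps the cost
# linear in j. The `numpy` import and the concrete `d` and `m` are assumptions
# made for this example.
#
#   import numpy as np
#
#   d = 5
#   m = np.arange(d + 1)  # all admissible values 0 <= m <= d
#   values = []  # values[j] holds G_{d, j, m} / G_{d, j, 0}
#   for j in range(d + 1):
#       prev_1 = values[-1] if j >= 1 else None
#       prev_2 = values[-2] if j >= 2 else None
#       values.append(kravchuk_normalized(d, j, m, prev_1, prev_2))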