Coverage for geometric_kernels / utils / special_functions.py: 86%
29 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-15 13:41 +0000
1"""
2Special mathematical functions used in the library.
3"""
5import lab as B
6from beartype.typing import List, Optional
8from geometric_kernels.lab_extras import (
9 count_nonzero,
10 from_numpy,
11 int_like,
12 take_along_axis,
13)
14from geometric_kernels.utils.utils import _check_matrix
def walsh_function(d: int, combination: List[int], x: B.Bool) -> B.Float:
    r"""
    Evaluate the Walsh function indexed by `combination` on a batch of inputs.

    For a subset $T \subseteq \{0, .., d-1\}$ the Walsh function is

    .. math:: w_T(x_0, .., x_{d-1}) = (-1)^{\sum_{j \in T} x_j}

    where $d$ is `d`, $T$ is `combination`, and $x = (x_0, .., x_{d-1})$ is `x`.
    Because every $x_j$ is binary, the exponent is just the number of ones of
    $x$ at the positions listed in $T$.

    :param d:
        The dimension of the inputs (the length of each binary vector).
        An integer d > 0.
    :param combination:
        A list of integers forming a subset of $\{0, .., d-1\}$; it selects
        which Walsh function is evaluated.
    :param x:
        A batch of binary vectors $x = (x_0, .., x_{d-1})$ of shape [N, d].

    :return:
        An array of shape [N] holding $w_T(x)$ (i.e. $\pm 1$) for every $x$ in
        the batch.

    :raises ValueError:
        If the last dimension of `x` differs from `d`.
    """
    _check_matrix(x, "x")
    if d != B.shape(x)[-1]:
        raise ValueError("`x` must live in `d`-dimensional space.")

    # Turn the subset T into an index array on the same backend as `x`, with an
    # integer dtype matching `x`, and prepend a singleton axis so the same
    # column selection is applied to every row of the batch.
    selected_columns = from_numpy(x, combination)
    selected_columns = B.cast(int_like(x), selected_columns)
    selected_columns = selected_columns[None, :]

    # Keep only the coordinates of each x that belong to T.
    restricted = take_along_axis(x, selected_columns, axis=-1)
    # For binary x, the number of nonzero entries equals sum_{j in T} x_j.
    num_ones = count_nonzero(restricted, axis=-1)
    return (-1) ** num_ones
def _kravchuk_recurrence_step(
    d: int,
    j: int,
    m: B.Float,
    q: int,
    kravchuk_normalized_j_minus_1: B.Float,
    kravchuk_normalized_j_minus_2: B.Float,
) -> B.Float:
    r"""
    Apply one step of the normalized three-term recurrence.

    Returns $\widetilde{G}_{d, q, j}(m)$ computed from
    $\widetilde{G}_{d, q, j-1}(m)$ and $\widetilde{G}_{d, q, j-2}(m)$.
    Intended for $2 \leq j \leq d$ and $q \geq 2$ (otherwise the denominator
    $(d - j + 1)(q - 1)$ vanishes).
    """
    rhs_1 = (
        (d - j + 1) * (q - 1) + (j - 1) - q * m
    ) * kravchuk_normalized_j_minus_1
    rhs_2 = -(j - 1) * kravchuk_normalized_j_minus_2
    return (rhs_1 + rhs_2) / ((d - j + 1) * (q - 1))


def generalized_kravchuk_normalized(
    d: int,
    j: int,
    m: B.Int,
    q: int,
    kravchuk_normalized_j_minus_1: Optional[B.Float] = None,
    kravchuk_normalized_j_minus_2: Optional[B.Float] = None,
) -> B.Float:
    r"""
    This function returns $\widetilde{G}_{d, q, j}(m)$ where $\widetilde{G}_{d, q, j}(m)$ is the
    normalized generalized Kravchuk polynomial defined below.

    Define the generalized Kravchuk polynomial of degree $d > 0$ and order $0 \leq j \leq d$
    as the function $G_{d, q, j}(m)$ of the independent variable $0 \leq m \leq d$ given by

    .. math:: G_{d, q, j}(m) = \sum_{\substack{T \subseteq \{0, .., d-1\} \\ |T| = j}}
              \sum_{\alpha \in \{1,..,q-1\}^T} \psi_{T,\alpha}(x) \overline{\psi_{T,\alpha}(y)}.

    Here $\psi_{T,\alpha}$ are the Vilenkin functions on the q-ary Hamming graph $H(d,q)$ and
    $x, y \in \{0, 1, ..., q-1\}^d$ are arbitrary categorical vectors at Hamming distance $m$
    (the right-hand side does not depend on the choice of particular vectors of the kind).

    The normalized polynomial is $\widetilde{G}_{d, q, j}(m) = G_{d, q, j}(m) / G_{d, q, j}(0)$.

    .. note::
        We are using the three term recurrence relation to compute the Kravchuk
        polynomials. Cf. Equation (60) of Chapter 5 in MacWilliams and Sloane "The
        Theory of Error-Correcting Codes", 1977. The parameter $\gamma$ from
        :cite:t:`macwilliams1977` is set to $\gamma = q - 1$.

    .. note::
        We use the fact that $G_{d, q, j}(0) = \binom{d}{j}(q-1)^j$.

    .. note::
        For $q = 2$, this reduces to the classical Kravchuk polynomials.

    :param d:
        The degree of Kravchuk polynomial, an integer $d > 0$.
        Maps to $n$ in :cite:t:`macwilliams1977`.

    :param j:
        The order of Kravchuk polynomial, an integer $0 \leq j \leq d$.
        Maps to $k$ in :cite:t:`macwilliams1977`.

    :param m:
        The independent variable (Hamming distance), an integer or array with
        $0 \leq m \leq d$. Maps to $x$ in :cite:t:`macwilliams1977`.

    :param q:
        The alphabet size of the q-ary Hamming graph, an integer $q \geq 2$.

    :param kravchuk_normalized_j_minus_1:
        The optional precomputed value of $\widetilde{G}_{d, q, j-1}(m)$. If
        omitted, it is obtained by iterating the recurrence up from order 0,
        which takes $j - 2$ steps.

    :param kravchuk_normalized_j_minus_2:
        The optional precomputed value of $\widetilde{G}_{d, q, j-2}(m)$. If
        omitted, it is obtained by iterating the recurrence up from order 0,
        which takes $j - 2$ steps.

    :return:
        $\widetilde{G}_{d, q, j}(m) = G_{d, q, j}(m) / G_{d, q, j}(0)$ where
        $G_{d, q, j}(m)$ is the generalized Kravchuk polynomial.

    :raises ValueError:
        If `d` is not positive, `j` is outside [0, d], some entry of `m` is
        outside [0, d], or `q` is smaller than 2.
    """
    if d <= 0:
        raise ValueError("`d` must be positive.")
    if not 0 <= j <= d:
        raise ValueError("`j` must lie in the interval [0, d].")
    if not (B.all(0 <= m) and B.all(m <= d)):
        raise ValueError("`m` must lie in the interval [0, d].")
    # q = 1 makes both the order-1 formula and the recurrence divide by
    # (q - 1) = 0, which would silently yield inf/nan on float arrays.
    if q < 2:
        raise ValueError("`q` must be at least 2.")

    m = B.cast(B.dtype_float(m), m)

    if j == 0:
        return B.ones(m)
    if j == 1:
        return 1 - q * m / (d * (q - 1))

    if kravchuk_normalized_j_minus_1 is None or kravchuk_normalized_j_minus_2 is None:
        # Iterate the recurrence bottom-up from orders 0 and 1. Recursing
        # separately on j - 1 and j - 2 would recompute the same lower orders
        # repeatedly (exponential in j); this loop performs j - 2 steps with
        # exactly the same arithmetic per order.
        order_older = B.ones(m)  # order 0
        order_newer = 1 - q * m / (d * (q - 1))  # order 1
        for k in range(2, j):
            order_older, order_newer = order_newer, _kravchuk_recurrence_step(
                d, k, m, q, order_newer, order_older
            )
        # Here `order_older` is order j - 2 and `order_newer` is order j - 1.
        if kravchuk_normalized_j_minus_1 is None:
            kravchuk_normalized_j_minus_1 = order_newer
        if kravchuk_normalized_j_minus_2 is None:
            kravchuk_normalized_j_minus_2 = order_older

    return _kravchuk_recurrence_step(
        d, j, m, q, kravchuk_normalized_j_minus_1, kravchuk_normalized_j_minus_2
    )