Coverage for tests/utils/test_special_functions.py: 100% (43 statements)

from math import comb

import lab as B
import numpy as np
import pytest

from geometric_kernels.utils.special_functions import (
    generalized_kravchuk_normalized,
    walsh_function,
)
from geometric_kernels.utils.utils import binary_vectors_and_subsets, hamming_distance

from ..helper import check_function_with_backend


@pytest.fixture(params=[1, 2, 3, 5, 10])
def all_xs_and_combs(request):
    """
    Returns a tuple (d, X, combs) where:
    - d is an integer equal to request.param,
    - X is a 2**d x d boolean matrix whose rows are all possible binary vectors of length d,
    - combs is a list of all 2**d possible subsets of the column indices {0, ..., d-1} of X.
    """
    d = request.param

    X, combs = binary_vectors_and_subsets(d)

    return d, X, combs


def walsh_matrix(d, combs, X):
    return B.stack(*[walsh_function(d, combination, X) for combination in combs])
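

# For reference, a minimal NumPy sketch of the Walsh function convention these
# tests rely on (an assumption about walsh_function, not the geometric_kernels
# implementation itself): w_T(x) = (-1) ** sum_{i in T} x_i, i.e. the parity of
# the bits of x restricted to the index subset T.
def _walsh_reference(d, combination, X):
    # d is kept only to mirror the walsh_function signature; it is unused here.
    X_int = np.asarray(X, dtype=int)  # (n, d) binary matrix
    # Parity of the selected columns; an empty combination gives the constant +1.
    return (-1.0) ** X_int[:, list(combination)].sum(axis=1)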


@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_walsh_functions(all_xs_and_combs, backend):
    d, X, combs = all_xs_and_combs

    # Check that Walsh functions are orthogonal.
    check_function_with_backend(
        backend,
        2**d * np.eye(2**d),
        lambda X: B.matmul(walsh_matrix(d, combs, X), B.T(walsh_matrix(d, combs, X))),
        X,
    )
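
    # Minimal worked example (d = 1, assuming combs = [(), (0,)]): the stacked
    # Walsh matrix is, up to column order, [[1, 1], [1, -1]], the 2 x 2 Hadamard
    # matrix, and its Gram matrix is 2 * np.eye(2), matching the check above.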

    # Check that Walsh functions only take values in the set {-1, 1}.
    check_function_with_backend(
        backend,
        np.ones((2**d, 2**d)),
        lambda X: B.abs(walsh_matrix(d, combs, X)),
        X,
    )


@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_kravchuk_polynomials(all_xs_and_combs, backend):
    d, X, combs = all_xs_and_combs

    x0 = np.zeros((1, d), dtype=bool)

    cur_ind = 0
    for j in range(d + 1):
        num_walsh = comb(d, j)
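
        # The slicing below assumes that combs is ordered by increasing subset
        # size, so that the comb(d, j) Walsh functions of level j occupy a
        # contiguous block of rows in walsh_matrix(d, combs, X).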
        result = np.sum(
            walsh_matrix(d, combs, X)[cur_ind : cur_ind + num_walsh, :],
            axis=0,
            keepdims=True,
        )

        def krav(x0, X):
            return comb(d, j) * generalized_kravchuk_normalized(
                d, j, hamming_distance(x0, X), 2
            )

        # Check that the level-j Kravchuk polynomial, evaluated at the Hamming
        # weight of the points in X and scaled by comb(d, j), coincides with
        # the sum of all Walsh functions of level j.
        check_function_with_backend(
            backend,
            result,
            krav,
            x0,
            X,
        )

        cur_ind += num_walsh
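

# For reference, a minimal sketch of the unnormalized binary Kravchuk polynomial
# through its explicit sum, K_j(m) = sum_i (-1)**i * C(m, i) * C(d - m, j - i),
# where m is the Hamming distance. The tests above assume that the normalized
# variant returned by generalized_kravchuk_normalized equals K_j(m) / K_j(0)
# with K_j(0) = C(d, j); this helper is illustrative only, not the library code.
def _kravchuk_reference(d, j, m):
    # m is a scalar non-negative integer; math.comb returns 0 when i > m.
    return sum((-1) ** i * comb(m, i) * comb(d - m, j - i) for i in range(j + 1))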


@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_kravchuk_precomputed(all_xs_and_combs, backend):
    d, X, _ = all_xs_and_combs

    x0 = np.zeros((1, d), dtype=bool)

    kravchuk_normalized_j_minus_1, kravchuk_normalized_j_minus_2 = None, None
    for j in range(d + 1):
        cur_kravchuk_normalized = generalized_kravchuk_normalized(
            d, j, hamming_distance(x0, X), 2
        )

        def krav(x0, X, kn1, kn2):
            return generalized_kravchuk_normalized(
                d, j, hamming_distance(x0, X), 2, kn1, kn2
            )

        # Check that passing the precomputed Kravchuk polynomials of orders
        # j-1 and j-2 yields the same result as computing from scratch.
        check_function_with_backend(
            backend,
            cur_kravchuk_normalized,
            krav,
            x0,
            X,
            kravchuk_normalized_j_minus_1,
            kravchuk_normalized_j_minus_2,
        )

        kravchuk_normalized_j_minus_2 = kravchuk_normalized_j_minus_1
        kravchuk_normalized_j_minus_1 = cur_kravchuk_normalized
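
# Note: Kravchuk polynomials satisfy a three-term recurrence in the order j; one
# standard unnormalized binary form is
#     (j + 1) * K_{j+1}(m) = (d - 2 * m) * K_j(m) - (d - j + 1) * K_{j-1}(m).
# This is presumably why generalized_kravchuk_normalized accepts the two
# previous polynomials as optional precomputed arguments; the exact form used
# internally is not assumed here.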