Coverage for tests/utils/test_special_functions.py: 100%
43 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1from math import comb
3import lab as B
4import numpy as np
5import pytest
7from geometric_kernels.utils.special_functions import (
8 kravchuk_normalized,
9 walsh_function,
10)
11from geometric_kernels.utils.utils import binary_vectors_and_subsets, hamming_distance
13from ..helper import check_function_with_backend
@pytest.fixture(params=[1, 2, 3, 5, 10])
def all_xs_and_combs(request):
    """
    Provide, for each parametrized dimension, the data shared by the tests.

    The fixture is parametrized over several dimensions; each value is read
    from ``request.param`` and handed to ``binary_vectors_and_subsets``.

    :return:
        A tuple (d, X, combs) where:
        - d is an integer equal to request.param,
        - X is a 2**d x d boolean matrix with all possible binary vectors
          of length d,
        - combs is a list of all possible combinations of indices
          (column positions) of those vectors.
    """
    dim = request.param
    vectors, subsets = binary_vectors_and_subsets(dim)
    return dim, vectors, subsets
def walsh_matrix(d, combs, X):
    """
    Evaluate one Walsh function per entry of `combs` at `X` and stack them.

    :param d:
        Dimension, forwarded unchanged to ``walsh_function``.
    :param combs:
        Iterable of index subsets; each one selects a Walsh function.
    :param X:
        Inputs at which every Walsh function is evaluated.

    :return:
        The individual evaluations combined with ``B.stack`` (default axis),
        in the same order as `combs`.
    """
    # `subset` rather than `comb`, to avoid shadowing the `math.comb` import.
    evaluations = [walsh_function(d, subset, X) for subset in combs]
    return B.stack(*evaluations)
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_walsh_functions(all_xs_and_combs, backend):
    """
    Check two properties of the Walsh functions on every backend:
    orthogonality, and that all their values have absolute value one.
    """
    d, X, combs = all_xs_and_combs

    def gram_matrix(X):
        # Build the Walsh matrix once and reuse it for both factors of the
        # product instead of evaluating all the Walsh functions twice.
        walsh = walsh_matrix(d, combs, X)
        return B.matmul(walsh, B.T(walsh))

    # Check that Walsh functions are orthogonal: the Gram matrix of the
    # stacked functions must be 2**d times the identity.
    check_function_with_backend(
        backend,
        2**d * np.eye(2**d),
        gram_matrix,
        X,
    )

    # Check that Walsh functions only take values in the set {-1, 1}:
    # the element-wise absolute value must be all ones.
    check_function_with_backend(
        backend,
        np.ones((2**d, 2**d)),
        lambda X: B.abs(walsh_matrix(d, combs, X)),
        X,
    )
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_kravchuk_polynomials(all_xs_and_combs, backend):
    """
    Check that Kravchuk polynomials coincide with certain sums of the
    Walsh functions.

    For each degree j, the sum of a block of comb(d, j) consecutive rows of
    the Walsh matrix is compared with comb(d, j) times the normalized
    Kravchuk polynomial of degree j, evaluated at the Hamming distance
    between the all-False vector x0 and X.
    """
    d, X, combs = all_xs_and_combs

    x0 = np.zeros((1, d), dtype=bool)

    # The Walsh matrix does not depend on j, so compute it once rather than
    # rebuilding it in each of the d + 1 loop iterations.
    walsh = walsh_matrix(d, combs, X)

    # NOTE(review): the row slicing below assumes `combs` lists the subsets
    # grouped by size (all of size 0, then size 1, ...) - confirm against
    # binary_vectors_and_subsets.
    cur_ind = 0
    for j in range(d + 1):
        num_walsh = comb(d, j)

        # Expected value: sum of the block of `num_walsh` rows for degree j.
        result = np.sum(
            walsh[cur_ind : cur_ind + num_walsh, :],
            axis=0,
            keepdims=True,
        )

        # Called by the checker within this iteration, so capturing the loop
        # variable `j` in the closure is safe here.
        def krav(x0, X):
            return comb(d, j) * kravchuk_normalized(d, j, hamming_distance(x0, X))

        # Checks that Kravchuk polynomials coincide with certain sums of
        # the Walsh functions.
        check_function_with_backend(
            backend,
            result,
            krav,
            x0,
            X,
        )

        cur_ind += num_walsh
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_kravchuk_precomputed(all_xs_and_combs, backend):
    """
    Check that `kravchuk_normalized` gives the same values when it is
    supplied with the precomputed results for degrees j - 1 and j - 2 as
    when it is called without them.
    """
    d, X, _ = all_xs_and_combs

    x0 = np.zeros((1, d), dtype=bool)

    # The distances used for the expected values do not depend on j, so
    # compute them once outside the loop.
    distances = hamming_distance(x0, X)

    kravchuk_normalized_j_minus_1, kravchuk_normalized_j_minus_2 = None, None
    for j in range(d + 1):
        # Reference value: direct evaluation without precomputed inputs.
        cur_kravchuk_normalized = kravchuk_normalized(d, j, distances)

        # Called by the checker within this iteration, so capturing the loop
        # variable `j` in the closure is safe here.
        def krav(x0, X, kn1, kn2):
            return kravchuk_normalized(d, j, hamming_distance(x0, X), kn1, kn2)

        # Checks that passing the precomputed values for the two previous
        # degrees (both None on the first iteration, the older one still
        # None on the second) reproduces the directly computed polynomial.
        check_function_with_backend(
            backend,
            cur_kravchuk_normalized,
            krav,
            x0,
            X,
            kravchuk_normalized_j_minus_1,
            kravchuk_normalized_j_minus_2,
        )

        # Shift the window of previously computed degrees.
        kravchuk_normalized_j_minus_2 = kravchuk_normalized_j_minus_1
        kravchuk_normalized_j_minus_1 = cur_kravchuk_normalized