Coverage for tests / helper.py: 96%
74 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-15 13:41 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-15 13:41 +0000
1import lab as B
2import numpy as np
3from beartype.door import die_if_unbearable, is_bearable
4from beartype.typing import Any, Callable, List, Optional, Union
5from plum import ModuleType, resolve_type_hint
7from geometric_kernels.lab_extras import SparseArray
8from geometric_kernels.spaces import (
9 Circle,
10 CompactMatrixLieGroup,
11 DiscreteSpectrumSpace,
12 Graph,
13 GraphEdges,
14 HammingGraph,
15 HodgeDiscreteSpectrumSpace,
16 Hyperbolic,
17 HypercubeGraph,
18 Hypersphere,
19 Mesh,
20 NoncompactSymmetricSpace,
21 ProductDiscreteSpectrumSpace,
22 Space,
23 SpecialOrthogonal,
24 SpecialUnitary,
25 SymmetricPositiveDefiniteMatrices,
26)
28from .data import (
29 TEST_GRAPH_ADJACENCY,
30 TEST_GRAPH_EDGES_NUM_NODES,
31 TEST_GRAPH_EDGES_ORIENTED_EDGES,
32 TEST_GRAPH_EDGES_ORIENTED_TRIANGLES,
33 TEST_MESH_PATH,
34)
# Type handle for TensorFlow's ``EagerTensor`` class, referenced by
# ``array_type`` when checking outputs of the "tensorflow" backend.
# NOTE(review): presumably plum resolves this lazily so TensorFlow need not be
# importable at module load time — confirm.
EagerTensor = ModuleType(
    "tensorflow.python.framework.ops",
    "EagerTensor",
)
def compact_matrix_lie_groups() -> List[CompactMatrixLieGroup]:
    """
    Build a fresh list of compact matrix Lie groups used in tests.

    :return:
        SO(3), SO(8), SU(2) and SU(5), in that order.
    """
    orthogonal_groups = [SpecialOrthogonal(n) for n in (3, 8)]
    unitary_groups = [SpecialUnitary(n) for n in (2, 5)]
    return orthogonal_groups + unitary_groups
def product_discrete_spectrum_spaces() -> List[ProductDiscreteSpectrumSpace]:
    """
    Build a fresh list of product discrete-spectrum spaces used in tests.

    :return:
        Three products: circle x 3-sphere x circle, circle x graph, and
        mesh x 2-sphere.
    """
    circle_sphere_circle = ProductDiscreteSpectrumSpace(
        Circle(), Hypersphere(3), Circle()
    )
    # TEST_GRAPH_ADJACENCY is too small for the default parameters of
    # ProductDiscreteSpectrumSpace, so the graph factor is built from the
    # Kronecker product of the adjacency matrix with itself.
    circle_graph = ProductDiscreteSpectrumSpace(
        Circle(), Graph(np.kron(TEST_GRAPH_ADJACENCY, TEST_GRAPH_ADJACENCY))
    )
    mesh_sphere = ProductDiscreteSpectrumSpace(
        Mesh.load_mesh(TEST_MESH_PATH), Hypersphere(2)
    )
    return [circle_sphere_circle, circle_graph, mesh_sphere]
def hodge_discrete_spectrum_spaces() -> List[HodgeDiscreteSpectrumSpace]:
    """
    Build a fresh list of Hodge discrete-spectrum spaces used in tests.

    :return:
        A single-element list holding the `GraphEdges` space built from the
        test graph's node count, oriented edges and oriented triangles.
    """
    test_graph_edges = GraphEdges(
        TEST_GRAPH_EDGES_NUM_NODES,
        TEST_GRAPH_EDGES_ORIENTED_EDGES,
        TEST_GRAPH_EDGES_ORIENTED_TRIANGLES,
    )
    return [test_graph_edges]
def discrete_spectrum_spaces() -> List[DiscreteSpectrumSpace]:
    """
    Build a fresh list of every discrete-spectrum space used in tests.

    :return:
        The basic spaces (circle, hypercube graphs, Hamming graphs,
        hyperspheres, the test mesh, and the test graph with and without a
        normalized Laplacian), followed by the compact matrix Lie groups,
        the product spaces and the Hodge spaces.
    """
    basic = [Circle()]
    basic += [HypercubeGraph(n) for n in (1, 3, 6)]
    basic += [
        HammingGraph(*hamming_args)
        for hamming_args in ((1, 2), (3, 3), (2, 6), (6, 4))
    ]
    basic += [Hypersphere(n) for n in (2, 3, 10)]
    basic.append(Mesh.load_mesh(TEST_MESH_PATH))
    basic += [
        Graph(TEST_GRAPH_ADJACENCY, normalize_laplacian=normalize)
        for normalize in (False, True)
    ]
    return [
        *basic,
        *compact_matrix_lie_groups(),
        *product_discrete_spectrum_spaces(),
        *hodge_discrete_spectrum_spaces(),
    ]
def noncompact_symmetric_spaces() -> List[NoncompactSymmetricSpace]:
    """
    Build a fresh list of noncompact symmetric spaces used in tests.

    :return:
        Hyperbolic spaces of dimension 2, 3, 8, 9, followed by SPD matrix
        spaces for n = 2, 3, 6, 7.
    """
    hyperbolic_spaces = [Hyperbolic(n) for n in (2, 3, 8, 9)]
    spd_spaces = [SymmetricPositiveDefiniteMatrices(n) for n in (2, 3, 6, 7)]
    return hyperbolic_spaces + spd_spaces
def spaces() -> List[Space]:
    """
    Build a fresh list of all test spaces.

    :return:
        All discrete-spectrum spaces followed by all noncompact symmetric
        spaces.
    """
    every_space = discrete_spectrum_spaces()
    every_space.extend(noncompact_symmetric_spaces())
    return every_space
def np_to_backend(value: B.NPNumeric, backend: str):
    """
    Converts a numpy array to the desired backend.

    Backend libraries are imported lazily, only when their backend is
    requested, so unused backends need not be installed.

    :param value:
        A numpy array.
    :param backend:
        The backend to use, one of the strings "tensorflow", "torch" (or its
        alias "pytorch"), "numpy", "jax", "scipy_sparse".

    :raises ValueError:
        If the backend is not recognized.

    :return:
        The array `value` converted to the desired backend. For "numpy" the
        input object itself is returned unchanged; for "scipy_sparse" it is
        converted to a `scipy.sparse.csr_array`.
    """
    if backend == "tensorflow":
        import tensorflow as tf

        return tf.convert_to_tensor(value)
    elif backend in ["torch", "pytorch"]:
        import torch

        return torch.tensor(value)
    elif backend == "numpy":
        return value
    elif backend == "jax":
        import jax.numpy as jnp

        return jnp.array(value)
    elif backend == "scipy_sparse":
        import scipy.sparse as sp

        return sp.csr_array(value)
    else:
        raise ValueError(f"Unknown backend: {backend}")
def create_random_state(backend: str, seed: int = 0):
    """
    Create a LAB random state matching the given backend.

    The dtype passed to `B.create_random_state` is taken from a 1x1 float
    numpy array converted to `backend`.

    :param backend:
        The backend name, as accepted by `np_to_backend`.
    :param seed:
        The seed for the random state.

    :return:
        The random state returned by `B.create_random_state`.
    """
    probe = np_to_backend(np.array([[1.0]]), backend)
    probe_dtype = B.dtype(probe)
    return B.create_random_state(probe_dtype, seed=seed)
def array_type(backend: str):
    """
    Resolve the array type hint that outputs of the given backend must satisfy.

    :param backend:
        The backend to use, one of the strings "tensorflow", "torch" (or its
        alias "pytorch"), "numpy", "jax", "scipy_sparse".

    :raises ValueError:
        If the backend is not recognized.

    :return:
        The resolved type hint for the given backend.
    """
    # Each entry is a zero-argument factory so that only the hint for the
    # requested backend is built and resolved.
    hint_factories = {
        "tensorflow": lambda: Union[B.TFNumeric, EagerTensor],
        "torch": lambda: B.TorchNumeric,
        "pytorch": lambda: B.TorchNumeric,
        "numpy": lambda: B.NPNumeric,
        "jax": lambda: B.JAXNumeric,
        "scipy_sparse": lambda: SparseArray,
    }
    make_hint = hint_factories.get(backend)
    if make_hint is None:
        raise ValueError(f"Unknown backend: {backend}")
    return resolve_type_hint(make_hint())
def apply_recursive(data: Any, func: Callable[[Any], Any]) -> Any:
    """
    Recursively map `func` over the leaves of a nested structure.

    Lists and dictionaries are traversed and rebuilt as plain `list` / `dict`
    objects (dictionary keys are kept as-is). Every other value, including
    tuples, is treated as a leaf and replaced by `func(value)`.

    :param data:
        The (possibly nested) structure to transform.
    :param func:
        The function applied to every leaf.

    :return:
        A new structure of the same nesting with `func` applied to each leaf.
    """
    if isinstance(data, list):
        return [apply_recursive(item, func) for item in data]
    if isinstance(data, dict):
        return {k: apply_recursive(v, func) for k, v in data.items()}
    return func(data)
def check_function_with_backend(
    backend: str,
    result: Any,
    f: Callable,
    *args: Any,
    compare_to_result: Optional[Callable] = None,
    atol=1e-4,
):
    """
    1. Casts the arguments `*args` to the backend `backend`.
    2. Runs the function `f` on the cast arguments.
    3. Checks that the output is of the array type of `backend`.
    4. If no `compare_to_result` kwarg is provided, checks that the output,
       converted back to a numpy array, has the same shape as `result` and
       matches it up to `atol` (via `np.testing.assert_allclose` with its
       default `rtol`). If `compare_to_result` is provided, checks that
       `compare_to_result(result, f_output)` is truthy.

    :param backend:
        The backend to use, one of the strings "tensorflow", "torch" (or its
        alias "pytorch"), "numpy", "jax", "scipy_sparse".
    :param result:
        The expected result of the function. If no `compare_to_result` kwarg
        is provided, expected to be a numpy array. Otherwise, can be anything.
    :param f:
        The backend-independent function to run.
    :param args:
        The arguments to pass to the function `f`, expected to be numpy
        arrays or non-array arguments. Numeric values, including those nested
        inside lists and dicts, are cast to `backend`; any numeric value that
        is not a numpy array makes `die_if_unbearable` fail.
    :param compare_to_result:
        A function that takes two arguments, the expected `result` FIRST and
        the raw output of `f` (not converted to numpy) SECOND, and returns a
        boolean.
    :param atol:
        The absolute tolerance to use when comparing the computed result with
        the expected result.

    :raises AssertionError:
        If the output has the wrong type, the wrong shape, differs from
        `result` beyond the tolerance, or `compare_to_result` returns a falsy
        value.
    """

    def cast(arg):
        # Non-numeric arguments are passed to `f` unchanged.
        if not is_bearable(arg, B.Numeric):
            return arg
        # We only expect numpy arrays here.
        die_if_unbearable(arg, B.NPNumeric)
        return np_to_backend(arg, backend)

    args_casted = [apply_recursive(arg, cast) for arg in args]

    f_output = f(*args_casted)
    expected_type = array_type(backend)
    assert is_bearable(
        f_output, expected_type
    ), f"The output is not of the expected type. Expected: {expected_type}, got: {type(f_output)}"
    if compare_to_result is None:
        # Convert `f_output` to a dense numpy array to compare with `result`.
        if is_bearable(f_output, SparseArray):
            f_output = f_output.toarray()
        else:
            f_output = B.to_numpy(f_output)
        assert (
            result.shape == f_output.shape
        ), f"Shapes do not match: {result.shape} (for result) vs {f_output.shape} (for f_output)"
        np.testing.assert_allclose(f_output, result, atol=atol)
    else:
        assert compare_to_result(
            result, f_output
        ), f"compare_to_result(result, f_output) failed with\n result:\n{result}\n\nf_output:\n{f_output}"