Coverage for tests/helper.py: 96%
74 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1import lab as B
2import numpy as np
3from beartype.door import die_if_unbearable, is_bearable
4from beartype.typing import Any, Callable, List, Optional, Union
5from plum import ModuleType, resolve_type_hint
7from geometric_kernels.lab_extras import SparseArray
8from geometric_kernels.spaces import (
9 Circle,
10 CompactMatrixLieGroup,
11 DiscreteSpectrumSpace,
12 Graph,
13 GraphEdges,
14 HodgeDiscreteSpectrumSpace,
15 Hyperbolic,
16 HypercubeGraph,
17 Hypersphere,
18 Mesh,
19 NoncompactSymmetricSpace,
20 ProductDiscreteSpectrumSpace,
21 Space,
22 SpecialOrthogonal,
23 SpecialUnitary,
24 SymmetricPositiveDefiniteMatrices,
25)
27from .data import (
28 TEST_GRAPH_ADJACENCY,
29 TEST_GRAPH_EDGES_NUM_NODES,
30 TEST_GRAPH_EDGES_ORIENTED_EDGES,
31 TEST_GRAPH_EDGES_ORIENTED_TRIANGLES,
32 TEST_MESH_PATH,
33)
# TensorFlow's eager tensor class, referenced through plum's `ModuleType` by
# module path and class name rather than imported here (TensorFlow is only
# imported inside `np_to_backend`). Presumably this keeps the module importable
# without TensorFlow installed — confirm. Used by `array_type` for the
# "tensorflow" backend.
EagerTensor = ModuleType("tensorflow.python.framework.ops", "EagerTensor")
def compact_matrix_lie_groups() -> List[CompactMatrixLieGroup]:
    """
    Build the compact matrix Lie group fixtures used by the tests.

    :return:
        A new list: SpecialOrthogonal(3), SpecialOrthogonal(8),
        SpecialUnitary(2), SpecialUnitary(5), in that order.
    """
    orthogonal_groups = [SpecialOrthogonal(size) for size in (3, 8)]
    unitary_groups = [SpecialUnitary(size) for size in (2, 5)]
    return orthogonal_groups + unitary_groups
def product_discrete_spectrum_spaces() -> List[ProductDiscreteSpectrumSpace]:
    """
    Build the product-space fixtures used by the tests.

    :return:
        A new list of three ProductDiscreteSpectrumSpace instances, in order:
        Circle x Hypersphere(3) x Circle, Circle x (enlarged test graph),
        and (test mesh) x Hypersphere(2).
    """
    products = [ProductDiscreteSpectrumSpace(Circle(), Hypersphere(3), Circle())]

    # TEST_GRAPH_ADJACENCY is too small for the default parameters of
    # ProductDiscreteSpectrumSpace, so the graph is enlarged by taking the
    # Kronecker product of its adjacency matrix with itself.
    circle = Circle()
    enlarged_graph = Graph(np.kron(TEST_GRAPH_ADJACENCY, TEST_GRAPH_ADJACENCY))
    products.append(ProductDiscreteSpectrumSpace(circle, enlarged_graph))

    products.append(
        ProductDiscreteSpectrumSpace(Mesh.load_mesh(TEST_MESH_PATH), Hypersphere(2))
    )
    return products
def hodge_discrete_spectrum_spaces() -> List[HodgeDiscreteSpectrumSpace]:
    """
    Build the Hodge discrete-spectrum fixtures used by the tests.

    :return:
        A new single-element list holding a GraphEdges space built from the
        TEST_GRAPH_EDGES_* test data (node count, oriented edges, oriented
        triangles).
    """
    edge_space = GraphEdges(
        TEST_GRAPH_EDGES_NUM_NODES,
        TEST_GRAPH_EDGES_ORIENTED_EDGES,
        TEST_GRAPH_EDGES_ORIENTED_TRIANGLES,
    )
    return [edge_space]
def discrete_spectrum_spaces() -> List[DiscreteSpectrumSpace]:
    """
    Collect every discrete-spectrum space fixture.

    :return:
        A new list containing, in order: Circle(); HypercubeGraph(1), (3), (6);
        Hypersphere(2), (3), (10); the test mesh; the test graph with an
        unnormalized and then a normalized Laplacian; followed by the results
        of `compact_matrix_lie_groups`, `product_discrete_spectrum_spaces` and
        `hodge_discrete_spectrum_spaces`.
    """
    collected: List[DiscreteSpectrumSpace] = [Circle()]
    collected.extend(HypercubeGraph(dim) for dim in (1, 3, 6))
    collected.extend(Hypersphere(dim) for dim in (2, 3, 10))
    collected.append(Mesh.load_mesh(TEST_MESH_PATH))
    collected.extend(
        Graph(TEST_GRAPH_ADJACENCY, normalize_laplacian=normalize)
        for normalize in (False, True)
    )
    collected.extend(compact_matrix_lie_groups())
    collected.extend(product_discrete_spectrum_spaces())
    collected.extend(hodge_discrete_spectrum_spaces())
    return collected
def noncompact_symmetric_spaces() -> List[NoncompactSymmetricSpace]:
    """
    Build the noncompact symmetric space fixtures used by the tests.

    :return:
        A new list: Hyperbolic(2), (3), (8), (9), followed by
        SymmetricPositiveDefiniteMatrices(2), (3), (6), (7). Both families
        mix even and odd sizes.
    """
    hyperbolic_spaces = [Hyperbolic(dim) for dim in (2, 3, 8, 9)]
    spd_spaces = [SymmetricPositiveDefiniteMatrices(size) for size in (2, 3, 6, 7)]
    return hyperbolic_spaces + spd_spaces
def spaces() -> List[Space]:
    """
    Collect every space fixture.

    :return:
        A new list: all discrete-spectrum spaces followed by all noncompact
        symmetric spaces.
    """
    everything = discrete_spectrum_spaces()
    everything.extend(noncompact_symmetric_spaces())
    return everything
def np_to_backend(value: B.NPNumeric, backend: str):
    """
    Converts a numpy array to the desired backend.

    :param value:
        A numpy array.
    :param backend:
        The backend to use, one of the strings "tensorflow", "torch" (also
        accepted as "pytorch"), "numpy", "jax", "scipy_sparse".

    :raises ValueError:
        If the backend is not recognized.

    :return:
        The array `value` converted to the desired backend. For "numpy" the
        input object itself is returned; for "scipy_sparse" the array is
        wrapped in a `scipy.sparse.csr_array`.
    """
    # Backend libraries are imported inside their branch, so a library is only
    # loaded when that backend is actually requested.
    if backend == "tensorflow":
        import tensorflow as tf

        return tf.convert_to_tensor(value)
    elif backend in ("torch", "pytorch"):
        import torch

        return torch.tensor(value)
    elif backend == "numpy":
        return value
    elif backend == "jax":
        import jax.numpy as jnp

        return jnp.array(value)
    elif backend == "scipy_sparse":
        import scipy.sparse as sp

        return sp.csr_array(value)
    else:
        raise ValueError(f"Unknown backend: {backend}")
def create_random_state(backend: str, seed: int = 0):
    """
    Create a `lab` random state matching the given backend.

    The dtype is taken from a 1x1 float numpy array converted to `backend`
    with `np_to_backend`, so the state belongs to that backend.

    :param backend:
        The backend name, as accepted by `np_to_backend`.
    :param seed:
        Seed for the random state. Defaults to 0.

    :return:
        The result of `B.create_random_state` for that dtype and seed.
    """
    probe = np_to_backend(np.array([[1.0]]), backend)
    probe_dtype = B.dtype(probe)
    return B.create_random_state(probe_dtype, seed=seed)
def array_type(backend: str):
    """
    Look up the array type produced by a backend.

    :param backend:
        The backend name, one of the strings "tensorflow", "torch" (also
        accepted as "pytorch"), "numpy", "jax", "scipy_sparse".

    :raises ValueError:
        If the backend is not recognized.

    :return:
        The type hint for that backend's arrays, passed through plum's
        `resolve_type_hint`.
    """
    if backend == "tensorflow":
        # TensorFlow outputs may also be eager tensors, so both are accepted.
        hint = Union[B.TFNumeric, EagerTensor]
    elif backend in ("torch", "pytorch"):
        hint = B.TorchNumeric
    elif backend == "numpy":
        hint = B.NPNumeric
    elif backend == "jax":
        hint = B.JAXNumeric
    elif backend == "scipy_sparse":
        hint = SparseArray
    else:
        raise ValueError(f"Unknown backend: {backend}")
    return resolve_type_hint(hint)
def apply_recursive(data: Any, func: Callable[[Any], Any]) -> Any:
    """
    Map `func` over the leaves of a nested structure of dicts and lists.

    Dicts are rebuilt with the same keys and lists are rebuilt element by
    element. Any other value, including tuples and sets, is treated as a leaf
    and passed to `func` as a whole.

    :param data:
        The (possibly nested) structure to transform.
    :param func:
        The function applied to every leaf.

    :return:
        A new structure of the same shape, with each leaf replaced by
        `func(leaf)`.
    """

    def descend(node: Any) -> Any:
        return apply_recursive(node, func)

    if isinstance(data, dict):
        return {key: descend(item) for key, item in data.items()}
    if isinstance(data, list):
        return list(map(descend, data))
    return func(data)
def check_function_with_backend(
    backend: str,
    result: Any,
    f: Callable,
    *args: Any,
    compare_to_result: Optional[Callable] = None,
    atol=1e-4,
):
    """
    1. Casts the arguments `*args` to the backend `backend`.
    2. Runs the function `f` on the casted arguments.
    3. Checks that the result is of the backend `backend`.
    4. If no `compare_to_result` kwarg is provided, checks that the result,
       casted back to a dense numpy array, coincides with the given `result`.
       If `compare_to_result` is provided, checks if
       `compare_to_result(result, f_output)` is True.

    :param backend:
        The backend to use, one of the strings "tensorflow", "torch" (also
        accepted as "pytorch"), "numpy", "jax", "scipy_sparse".
    :param result:
        The expected result of the function. If no `compare_to_result` kwarg
        is provided, this is expected to be a numpy array (scalars and other
        array-likes accepted by `np.shape` also work). Otherwise, it can be
        anything.
    :param f:
        The backend-independent function to run.
    :param args:
        The arguments to pass to the function `f`, expected to be numpy arrays
        or non-array arguments. Numpy arrays nested inside lists and dicts are
        cast as well.
    :param compare_to_result:
        A function that takes two arguments, the expected result and the
        computed result (in that order), and returns a boolean.
    :param atol:
        The absolute tolerance to use when comparing the computed result with
        the expected result.

    :raises AssertionError:
        If the output has the wrong type, the wrong shape, differs from
        `result` beyond `atol`, or fails `compare_to_result`.
    """

    def cast(arg):
        if is_bearable(arg, B.Numeric):
            # We only expect numpy arrays here; any other numeric value makes
            # `die_if_unbearable` raise instead of being passed through.
            die_if_unbearable(arg, B.NPNumeric)
            return np_to_backend(arg, backend)
        # Non-numeric leaves are passed to `f` unchanged.
        return arg

    args_casted = [apply_recursive(arg, cast) for arg in args]

    f_output = f(*args_casted)

    expected_type = array_type(backend)
    assert is_bearable(
        f_output, expected_type
    ), f"The output is not of the expected type. Expected: {expected_type}, got: {type(f_output)}"

    if compare_to_result is None:
        # Convert `f_output` to a dense numpy array to compare with `result`.
        if is_bearable(f_output, SparseArray):
            f_output = f_output.toarray()
        else:
            f_output = B.to_numpy(f_output)
        # `np.shape` matches `.shape` for arrays and also handles scalars and
        # array-likes, which have no `.shape` attribute.
        result_shape = np.shape(result)
        output_shape = np.shape(f_output)
        assert (
            result_shape == output_shape
        ), f"Shapes do not match: {result_shape} (for result) vs {output_shape} (for f_output)"
        np.testing.assert_allclose(f_output, result, atol=atol)
    else:
        assert compare_to_result(
            result, f_output
        ), f"compare_to_result(result, f_output) failed with\n result:\n{result}\n\nf_output:\n{f_output}"