Coverage for tests/helper.py: 96%

74 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import lab as B 

2import numpy as np 

3from beartype.door import die_if_unbearable, is_bearable 

4from beartype.typing import Any, Callable, List, Optional, Union 

5from plum import ModuleType, resolve_type_hint 

6 

7from geometric_kernels.lab_extras import SparseArray 

8from geometric_kernels.spaces import ( 

9 Circle, 

10 CompactMatrixLieGroup, 

11 DiscreteSpectrumSpace, 

12 Graph, 

13 GraphEdges, 

14 HodgeDiscreteSpectrumSpace, 

15 Hyperbolic, 

16 HypercubeGraph, 

17 Hypersphere, 

18 Mesh, 

19 NoncompactSymmetricSpace, 

20 ProductDiscreteSpectrumSpace, 

21 Space, 

22 SpecialOrthogonal, 

23 SpecialUnitary, 

24 SymmetricPositiveDefiniteMatrices, 

25) 

26 

27from .data import ( 

28 TEST_GRAPH_ADJACENCY, 

29 TEST_GRAPH_EDGES_NUM_NODES, 

30 TEST_GRAPH_EDGES_ORIENTED_EDGES, 

31 TEST_GRAPH_EDGES_ORIENTED_TRIANGLES, 

32 TEST_MESH_PATH, 

33) 

34 

# plum reference to TensorFlow's `EagerTensor` class, given as a
# (module path, attribute name) pair rather than through a direct import of
# tensorflow -- presumably so this helper module can be imported without
# TensorFlow installed (TODO confirm). `array_type("tensorflow")` unions it
# with `B.TFNumeric` before passing it to `resolve_type_hint`.
EagerTensor = ModuleType(
    "tensorflow.python.framework.ops",
    "EagerTensor",
)

36 

37 

def compact_matrix_lie_groups() -> List[CompactMatrixLieGroup]:
    """
    Build the compact matrix Lie groups used as test spaces.

    A fresh list is constructed on every call, in this order:
    SpecialOrthogonal(3), SpecialOrthogonal(8), SpecialUnitary(2),
    SpecialUnitary(5).

    :return:
        A list of freshly constructed group instances.
    """
    orthogonal_groups = [SpecialOrthogonal(n) for n in (3, 8)]
    unitary_groups = [SpecialUnitary(n) for n in (2, 5)]
    return orthogonal_groups + unitary_groups

45 

46 

def product_discrete_spectrum_spaces() -> List[ProductDiscreteSpectrumSpace]:
    """
    Build the product spaces with discrete spectrum used as test spaces.

    :return:
        A list of three freshly constructed `ProductDiscreteSpectrumSpace`
        instances: Circle x Hypersphere(3) x Circle, Circle x (graph whose
        adjacency is the Kronecker square of `TEST_GRAPH_ADJACENCY`), and
        the test mesh x Hypersphere(2).
    """
    products = []
    products.append(ProductDiscreteSpectrumSpace(Circle(), Hypersphere(3), Circle()))
    # A graph built from TEST_GRAPH_ADJACENCY alone is too small for the
    # default parameters of ProductDiscreteSpectrumSpace, so its Kronecker
    # square (np.kron with itself) is used instead.
    products.append(
        ProductDiscreteSpectrumSpace(
            Circle(), Graph(np.kron(TEST_GRAPH_ADJACENCY, TEST_GRAPH_ADJACENCY))
        )
    )
    products.append(
        ProductDiscreteSpectrumSpace(Mesh.load_mesh(TEST_MESH_PATH), Hypersphere(2))
    )
    return products

55 

56 

def hodge_discrete_spectrum_spaces() -> List[HodgeDiscreteSpectrumSpace]:
    """
    Build the Hodge-type spaces with discrete spectrum used as test spaces.

    :return:
        A one-element list holding a `GraphEdges` space built from the
        `TEST_GRAPH_EDGES_*` fixtures (number of nodes, oriented edges,
        oriented triangles).
    """
    graph_edges = GraphEdges(
        TEST_GRAPH_EDGES_NUM_NODES,
        TEST_GRAPH_EDGES_ORIENTED_EDGES,
        TEST_GRAPH_EDGES_ORIENTED_TRIANGLES,
    )
    return [graph_edges]

65 

66 

def discrete_spectrum_spaces() -> List[DiscreteSpectrumSpace]:
    """
    Build every space with discrete spectrum used by the tests.

    The list contains, in order: Circle; HypercubeGraph of dimension 1, 3
    and 6; Hypersphere of dimension 2, 3 and 10; the test mesh; the test
    graph with unnormalized and then normalized Laplacian; followed by the
    results of `compact_matrix_lie_groups`,
    `product_discrete_spectrum_spaces` and `hodge_discrete_spectrum_spaces`.

    :return:
        A new list of freshly constructed space instances.
    """
    return [
        Circle(),
        *(HypercubeGraph(dim) for dim in (1, 3, 6)),
        *(Hypersphere(dim) for dim in (2, 3, 10)),
        Mesh.load_mesh(TEST_MESH_PATH),
        *(
            Graph(TEST_GRAPH_ADJACENCY, normalize_laplacian=normalize)
            for normalize in (False, True)
        ),
        *compact_matrix_lie_groups(),
        *product_discrete_spectrum_spaces(),
        *hodge_discrete_spectrum_spaces(),
    ]

85 

86 

def noncompact_symmetric_spaces() -> List[NoncompactSymmetricSpace]:
    """
    Build the non-compact symmetric spaces used as test spaces.

    :return:
        Hyperbolic spaces of dimension 2, 3, 8 and 9, followed by
        SymmetricPositiveDefiniteMatrices of size 2, 3, 6 and 7, all freshly
        constructed.
    """
    hyperbolic_spaces = [Hyperbolic(dim) for dim in (2, 3, 8, 9)]
    spd_spaces = [SymmetricPositiveDefiniteMatrices(size) for size in (2, 3, 6, 7)]
    return hyperbolic_spaces + spd_spaces

98 

99 

def spaces() -> List[Space]:
    """
    Build every test space: the discrete-spectrum ones first, then the
    non-compact symmetric ones.

    :return:
        A new list concatenating `discrete_spectrum_spaces()` and
        `noncompact_symmetric_spaces()`.
    """
    return [*discrete_spectrum_spaces(), *noncompact_symmetric_spaces()]

102 

103 

def np_to_backend(value: B.NPNumeric, backend: str):
    """
    Converts a numpy array to the desired backend.

    :param value:
        A numpy array.
    :param backend:
        The backend to use, one of the strings "tensorflow", "torch" (also
        accepted as "pytorch"), "numpy", "jax", "scipy_sparse".

    :raises ValueError:
        If the backend is not recognized.

    :return:
        The array `value` converted to the desired backend. For "numpy" the
        input object itself is returned unchanged; for "scipy_sparse" the
        result is a `scipy.sparse.csr_array`.
    """
    # Backend libraries are imported lazily so that only the backend actually
    # requested needs to be installed.
    if backend == "tensorflow":
        import tensorflow as tf

        return tf.convert_to_tensor(value)
    elif backend in ["torch", "pytorch"]:
        import torch

        return torch.tensor(value)
    elif backend == "numpy":
        return value
    elif backend == "jax":
        import jax.numpy as jnp

        return jnp.array(value)
    elif backend == "scipy_sparse":
        import scipy.sparse as sp

        return sp.csr_array(value)
    else:
        raise ValueError(f"Unknown backend: {backend}")

139 

140 

def create_random_state(backend: str, seed: int = 0):
    """
    Create a `lab` random state suited to the given backend.

    The dtype is read from a 1x1 float64 numpy array converted with
    `np_to_backend`, i.e. the dtype that backend produces for float input.

    :param backend:
        A backend name accepted by `np_to_backend`.
    :param seed:
        The seed forwarded to `B.create_random_state`.

    :return:
        Whatever `B.create_random_state` returns for that dtype and seed.
    """
    probe = np_to_backend(np.array([[1.0]]), backend)
    probe_dtype = B.dtype(probe)
    return B.create_random_state(probe_dtype, seed=seed)

144 

145 

def array_type(backend: str):
    """
    Returns the array type corresponding to the given backend.

    :param backend:
        The backend to use, one of the strings "tensorflow", "torch" (also
        accepted as "pytorch"), "numpy", "jax", "scipy_sparse".

    :raises ValueError:
        If the backend is not recognized.

    :return:
        The array type corresponding to the given backend, as resolved by
        plum's `resolve_type_hint`, suitable for checks such as
        `is_bearable(x, array_type(backend))`.
    """
    if backend == "tensorflow":
        # Eager execution yields `EagerTensor` objects, which are accepted
        # alongside the types covered by `B.TFNumeric`.
        return resolve_type_hint(Union[B.TFNumeric, EagerTensor])
    elif backend in ["torch", "pytorch"]:
        return resolve_type_hint(B.TorchNumeric)
    elif backend == "numpy":
        return resolve_type_hint(B.NPNumeric)
    elif backend == "jax":
        return resolve_type_hint(B.JAXNumeric)
    elif backend == "scipy_sparse":
        return resolve_type_hint(SparseArray)
    else:
        raise ValueError(f"Unknown backend: {backend}")

169 

170 

def apply_recursive(data: Any, func: Callable[[Any], Any]) -> Any:
    """
    Map `func` over the leaves of a nested structure of lists and dicts.

    Lists and dicts (including subclasses) are traversed and rebuilt as plain
    `list` / `dict` objects, preserving element order and keys. Every other
    object -- tuples included -- is a leaf and is passed to `func` whole.

    :param data:
        The (possibly nested) data structure to transform.
    :param func:
        The function applied to each leaf.

    :return:
        A new structure of the same nesting whose leaves are `func(leaf)`.
    """

    def recurse(item: Any) -> Any:
        return apply_recursive(item, func)

    if isinstance(data, list):
        return list(map(recurse, data))
    if isinstance(data, dict):
        return {key: recurse(value) for key, value in data.items()}
    return func(data)

190 

191 

def check_function_with_backend(
    backend: str,
    result: Any,
    f: Callable,
    *args: Any,
    compare_to_result: Optional[Callable] = None,
    atol: float = 1e-4,
):
    """
    1. Casts the arguments `*args` to the backend `backend`.
    2. Runs the function `f` on the cast arguments.
    3. Checks that the result is of the backend `backend`.
    4. If no `compare_to_result` kwarg is provided, checks that the result,
       cast back to the numpy backend, has the same shape as the given
       `result` and coincides with it up to `atol`.
       If `compare_to_result` is provided, checks that
       `compare_to_result(result, f_output)` is truthy.

    :param backend:
        The backend to use, one of the strings "tensorflow", "torch" (also
        accepted as "pytorch"), "numpy", "jax", "scipy_sparse".
    :param result:
        The expected result of the function. If no `compare_to_result` kwarg
        is provided, expected to be a numpy array (scalars and nested lists
        are also accepted, as anything `np.shape` and
        `np.testing.assert_allclose` understand). Otherwise, can be anything.
    :param f:
        The backend-independent function to run.
    :param args:
        The arguments to pass to the function `f`, expected to be numpy arrays
        or non-array arguments, possibly nested inside lists and dicts.
    :param compare_to_result:
        A function that takes two arguments -- first the expected result
        (`result`), then the computed result (`f_output`) -- and returns a
        boolean.
    :param atol:
        The absolute tolerance to use when comparing the computed result with
        the expected result. Unused when `compare_to_result` is provided.

    :raises AssertionError:
        If the output is not of the backend's array type, or does not match
        `result`.
    """

    def cast(arg):
        if is_bearable(arg, B.Numeric):
            # We only expect numpy arrays here; `die_if_unbearable` raises
            # for any other array type.
            die_if_unbearable(arg, B.NPNumeric)
            return np_to_backend(arg, backend)
        else:
            return arg

    args_casted = tuple(apply_recursive(arg, cast) for arg in args)

    f_output = f(*args_casted)
    expected_type = array_type(backend)
    assert is_bearable(
        f_output, expected_type
    ), f"The output is not of the expected type. Expected: {expected_type}, got: {type(f_output)}"
    if compare_to_result is None:
        # we convert `f_output` to a numpy array to compare with `result`
        if is_bearable(f_output, SparseArray):
            f_output = f_output.toarray()
        else:
            f_output = B.to_numpy(f_output)
        # `np.shape` equals `.shape` for arrays, and also works when `result`
        # (or the converted output) is a scalar or a nested list.
        result_shape = np.shape(result)
        output_shape = np.shape(f_output)
        assert (
            result_shape == output_shape
        ), f"Shapes do not match: {result_shape} (for result) vs {output_shape} (for f_output)"
        np.testing.assert_allclose(f_output, result, atol=atol)
    else:
        assert compare_to_result(
            result, f_output
        ), f"compare_to_result(result, f_output) failed with\n result:\n{result}\n\nf_output:\n{f_output}"