Coverage for tests / helper.py: 96%

74 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-15 13:41 +0000

1import lab as B 

2import numpy as np 

3from beartype.door import die_if_unbearable, is_bearable 

4from beartype.typing import Any, Callable, List, Optional, Union 

5from plum import ModuleType, resolve_type_hint 

6 

7from geometric_kernels.lab_extras import SparseArray 

8from geometric_kernels.spaces import ( 

9 Circle, 

10 CompactMatrixLieGroup, 

11 DiscreteSpectrumSpace, 

12 Graph, 

13 GraphEdges, 

14 HammingGraph, 

15 HodgeDiscreteSpectrumSpace, 

16 Hyperbolic, 

17 HypercubeGraph, 

18 Hypersphere, 

19 Mesh, 

20 NoncompactSymmetricSpace, 

21 ProductDiscreteSpectrumSpace, 

22 Space, 

23 SpecialOrthogonal, 

24 SpecialUnitary, 

25 SymmetricPositiveDefiniteMatrices, 

26) 

27 

28from .data import ( 

29 TEST_GRAPH_ADJACENCY, 

30 TEST_GRAPH_EDGES_NUM_NODES, 

31 TEST_GRAPH_EDGES_ORIENTED_EDGES, 

32 TEST_GRAPH_EDGES_ORIENTED_TRIANGLES, 

33 TEST_MESH_PATH, 

34) 

35 

# TensorFlow's eager tensor class, referenced by module path and class name
# (as strings) through plum's `ModuleType` instead of importing TensorFlow here.
EagerTensor = ModuleType(
    "tensorflow.python.framework.ops",
    "EagerTensor",
)

37 

38 

def compact_matrix_lie_groups() -> List[CompactMatrixLieGroup]:
    """
    Build the compact matrix Lie groups used as test fixtures.

    :return:
        SO(3), SO(8), SU(2) and SU(5), in that order.
    """
    orthogonal_groups = [SpecialOrthogonal(n) for n in (3, 8)]
    unitary_groups = [SpecialUnitary(n) for n in (2, 5)]
    return orthogonal_groups + unitary_groups

46 

47 

def product_discrete_spectrum_spaces() -> List[ProductDiscreteSpectrumSpace]:
    """
    Build the product spaces (of discrete-spectrum factors) used in tests.

    :return:
        Three product spaces: circle x 3-sphere x circle, circle x graph,
        and mesh x 2-sphere.
    """
    circle_sphere_circle = ProductDiscreteSpectrumSpace(
        Circle(), Hypersphere(3), Circle()
    )
    # TEST_GRAPH_ADJACENCY on its own is too small for the default parameters
    # of ProductDiscreteSpectrumSpace, so its Kronecker square (a graph with
    # the square of the number of nodes) is used instead.
    circle_graph = ProductDiscreteSpectrumSpace(
        Circle(), Graph(np.kron(TEST_GRAPH_ADJACENCY, TEST_GRAPH_ADJACENCY))
    )
    mesh_sphere = ProductDiscreteSpectrumSpace(
        Mesh.load_mesh(TEST_MESH_PATH), Hypersphere(2)
    )
    return [circle_sphere_circle, circle_graph, mesh_sphere]

56 

57 

def hodge_discrete_spectrum_spaces() -> List[HodgeDiscreteSpectrumSpace]:
    """
    Build the Hodge discrete-spectrum spaces used in tests.

    :return:
        A one-element list holding the edge space of the test graph, built
        from its node count, oriented edges and oriented triangles.
    """
    test_graph_edges = GraphEdges(
        TEST_GRAPH_EDGES_NUM_NODES,
        TEST_GRAPH_EDGES_ORIENTED_EDGES,
        TEST_GRAPH_EDGES_ORIENTED_TRIANGLES,
    )
    return [test_graph_edges]

66 

67 

def discrete_spectrum_spaces() -> List[DiscreteSpectrumSpace]:
    """
    Build every discrete-spectrum space used in tests.

    :return:
        A fresh list containing, in order:
        - a circle, hypercube graphs, Hamming graphs, hyperspheres;
        - the test mesh and the test graph (unnormalized, then normalized
          Laplacian);
        - the compact matrix Lie groups;
        - the product spaces;
        - the Hodge spaces.
    """
    collected = [Circle()]
    collected += [HypercubeGraph(dim) for dim in (1, 3, 6)]
    collected += [
        HammingGraph(dim, q) for dim, q in ((1, 2), (3, 3), (2, 6), (6, 4))
    ]
    collected += [Hypersphere(dim) for dim in (2, 3, 10)]
    collected.append(Mesh.load_mesh(TEST_MESH_PATH))
    collected += [
        Graph(TEST_GRAPH_ADJACENCY, normalize_laplacian=normalize)
        for normalize in (False, True)
    ]
    collected += compact_matrix_lie_groups()
    collected += product_discrete_spectrum_spaces()
    collected += hodge_discrete_spectrum_spaces()
    return collected

90 

91 

def noncompact_symmetric_spaces() -> List[NoncompactSymmetricSpace]:
    """
    Build the noncompact symmetric spaces used in tests.

    :return:
        Hyperbolic spaces of dimensions 2, 3, 8, 9 followed by SPD-matrix
        spaces of sizes 2, 3, 6, 7 (presumably chosen to mix even and odd
        dimensions — confirm against the tests that use them).
    """
    hyperbolic_spaces = [Hyperbolic(dim) for dim in (2, 3, 8, 9)]
    spd_spaces = [SymmetricPositiveDefiniteMatrices(n) for n in (2, 3, 6, 7)]
    return hyperbolic_spaces + spd_spaces

103 

104 

def spaces() -> List[Space]:
    """
    Return every test space: the discrete-spectrum spaces followed by the
    noncompact symmetric spaces.
    """
    every_space = discrete_spectrum_spaces()
    every_space.extend(noncompact_symmetric_spaces())
    return every_space

107 

108 

def np_to_backend(value: B.NPNumeric, backend: str):
    """
    Converts a numpy array to the desired backend.

    :param value:
        A numpy array.
    :param backend:
        The backend to use, one of the strings "tensorflow", "torch" (or its
        alias "pytorch"), "numpy", "jax", "scipy_sparse".

    :raises ValueError:
        If the backend is not recognized.

    :return:
        The array `value` converted to the desired backend. For "numpy" the
        very same object is returned (no copy). For "scipy_sparse" a
        `scipy.sparse.csr_array` built from `value` is returned.
    """
    # Each framework is imported inside its own branch, i.e. only when that
    # backend is actually requested.
    if backend == "tensorflow":
        import tensorflow as tf

        return tf.convert_to_tensor(value)
    elif backend in ["torch", "pytorch"]:
        import torch

        return torch.tensor(value)
    elif backend == "numpy":
        return value
    elif backend == "jax":
        import jax.numpy as jnp

        return jnp.array(value)
    elif backend == "scipy_sparse":
        import scipy.sparse as sp

        return sp.csr_array(value)
    else:
        raise ValueError("Unknown backend: {}".format(backend))

144 

145 

def create_random_state(backend: str, seed: int = 0):
    """
    Create a `lab` random state matching the given backend.

    The dtype is taken from a 1x1 float numpy array converted to `backend`
    via `np_to_backend`. That dtype and `seed` are passed to
    `B.create_random_state`, whose result is returned.

    :param backend:
        Backend name accepted by `np_to_backend`.
    :param seed:
        Seed forwarded to `B.create_random_state`.
    """
    probe_array = np_to_backend(np.array([[1.0]]), backend)
    return B.create_random_state(B.dtype(probe_array), seed=seed)

149 

150 

def array_type(backend: str):
    """
    Look up the array type produced by the given backend.

    :param backend:
        The backend name: "tensorflow", "torch" (or "pytorch"), "numpy",
        "jax", or "scipy_sparse".

    :raises ValueError:
        If `backend` is not one of the names above.

    :return:
        `resolve_type_hint` applied to the numeric type hint of that backend.
    """
    # Thunks keep the type hints from being touched until the requested
    # backend has been selected.
    hint_factories = {
        "tensorflow": lambda: Union[B.TFNumeric, EagerTensor],
        "torch": lambda: B.TorchNumeric,
        "pytorch": lambda: B.TorchNumeric,
        "numpy": lambda: B.NPNumeric,
        "jax": lambda: B.JAXNumeric,
        "scipy_sparse": lambda: SparseArray,
    }
    make_hint = hint_factories.get(backend)
    if make_hint is None:
        raise ValueError(f"Unknown backend: {backend}")
    return resolve_type_hint(make_hint())

174 

175 

def apply_recursive(data: Any, func: Callable[[Any], Any]) -> Any:
    """
    Map `func` over the leaves of a nested structure of lists and dicts.

    Dicts are rebuilt with the same keys and recursively mapped values, and
    lists are rebuilt element by element. Anything else, including tuples,
    is treated as a leaf and passed to `func` as a whole.

    :param data:
        The (possibly nested) structure to transform.
    :param func:
        The function applied to every leaf.

    :return:
        A new structure of the same list/dict shape with `func` applied to
        each leaf.
    """

    def descend(item):
        return apply_recursive(item, func)

    if isinstance(data, dict):
        return dict(zip(data.keys(), map(descend, data.values())))
    if isinstance(data, list):
        return list(map(descend, data))
    return func(data)

195 

196 

def check_function_with_backend(
    backend: str,
    result: Any,
    f: Callable,
    *args: Any,
    compare_to_result: Optional[Callable] = None,
    atol=1e-4,
):
    """
    1. Casts the arguments `*args` to the backend `backend`.
    2. Runs the function `f` on the casted arguments.
    3. Checks that the result is of the backend `backend`.
    4. If no `compare_to_result` kwarg is provided, checks that the result,
       casted back to numpy backend, has the same shape as the given
       `result` and coincides with it up to `atol`.
       If `compare_to_result` is provided, checks if
       `compare_to_result(result, f_output)` is True.

    :param backend:
        The backend to use, one of the strings "tensorflow", "torch" (or its
        alias "pytorch"), "numpy", "jax", "scipy_sparse".
    :param result:
        The expected result of the function. If no `compare_to_result` kwarg
        is provided, it is expected to be a numpy array (or a scalar).
        Otherwise, it can be anything.
    :param f:
        The backend-independent function to run.
    :param args:
        The arguments to pass to the function `f`, expected to be numpy arrays
        or non-array arguments. Arrays nested inside lists/dicts are cast too.
    :param compare_to_result:
        A function that takes two arguments, the expected result `result`
        first and the computed output of `f` second, and returns a boolean.
    :param atol:
        The absolute tolerance to use when comparing the computed result with
        the expected result.

    :raises AssertionError:
        If the output has the wrong type, a mismatched shape, differing values,
        or `compare_to_result` returns a falsy value.
    """

    def cast(arg):
        # Only numeric leaves are converted; everything else passes through.
        if is_bearable(arg, B.Numeric):
            # We only expect numpy arrays here
            die_if_unbearable(arg, B.NPNumeric)
            return np_to_backend(arg, backend)
        return arg

    args_casted = [apply_recursive(arg, cast) for arg in args]

    f_output = f(*args_casted)

    # Resolve the expected type once and reuse it in the failure message.
    expected_type = array_type(backend)
    assert is_bearable(
        f_output, expected_type
    ), f"The output is not of the expected type. Expected: {expected_type}, got: {type(f_output)}"

    if compare_to_result is None:
        # Convert `f_output` to a dense numpy array to compare with `result`.
        if is_bearable(f_output, SparseArray):
            f_output = f_output.toarray()
        else:
            f_output = B.to_numpy(f_output)
        # `np.shape` equals `.shape` for arrays but also works for plain
        # Python scalars, which have no `.shape` attribute.
        result_shape = np.shape(result)
        output_shape = np.shape(f_output)
        assert (
            result_shape == output_shape
        ), f"Shapes do not match: {result_shape} (for result) vs {output_shape} (for f_output)"
        np.testing.assert_allclose(f_output, result, atol=atol)
    else:
        assert compare_to_result(
            result, f_output
        ), f"compare_to_result(result, f_output) failed with\n result:\n{result}\n\nf_output:\n{f_output}"