Coverage for geometric_kernels/utils/utils.py: 64%

163 statements  

coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2Convenience utilities. 

3""" 

4 

5import inspect 

6import sys 

7from contextlib import contextmanager 

8from importlib import resources as impresources 

9from itertools import combinations 

10 

11import einops 

12import lab as B 

13import numpy as np 

14from beartype.typing import Callable, Generator, List, Set, Tuple 

15 

16from geometric_kernels import resources 

17from geometric_kernels.lab_extras import ( 

18 count_nonzero, 

19 get_random_state, 

20 logical_xor, 

21 restore_random_state, 

22) 

23 

24 

def chain(elements: B.Numeric, repetitions: List[int]) -> B.Numeric:
    """
    Repeats each element in `elements` the number of times specified in
    `repetitions`. The lengths of `elements` and `repetitions` should match.

    :param elements:
        An [N,]-shaped array of elements to repeat.
    :param repetitions:
        A list specifying the number of times to repeat each of the elements
        in `elements`. The length of `repetitions` should be equal to N.

    :return:
        An [M,]-shaped array.

    EXAMPLE:

    .. code-block:: python

        elements = np.array([1, 2, 3])
        repetitions = [2, 1, 3]
        out = chain(elements, repetitions)
        print(out)  # [1, 1, 2, 3, 3, 3]
    """
    values = [
        einops.repeat(elements[i : i + 1], "j -> (tile j)", tile=repetitions[i])
        for i in range(len(repetitions))
    ]
    return B.concat(*values, axis=0)

def make_deterministic(f: Callable, key: B.RandomState) -> Callable:
    """
    Returns a deterministic version of a function that uses a random
    number generator.

    :param f:
        The function to make deterministic.
    :param key:
        The key used to generate the random state.

    :return:
        A function representing the deterministic version of the input function.

    .. note::
        This function assumes that the input function has a 'key' positional
        or keyword-only argument that is used to generate random numbers.
        Otherwise, the function is returned as is.

    EXAMPLE:

    .. code-block:: python

        key = tf.random.Generator.from_seed(1234)
        feature_map = default_feature_map(kernel=base_kernel)
        sample_paths = make_deterministic(sampler(feature_map), key)
        _, ys_train = sample_paths(xs_train, params)
        key, ys_test = sample_paths(xs_test, params)
    """
    f_argspec = inspect.getfullargspec(f)
    f_varnames = f_argspec.args
    key_argtype = None
    if "key" in f_varnames:
        key_argtype = "pos"
        key_position = f_varnames.index("key")
    elif "key" in f_argspec.kwonlyargs:
        key_argtype = "kwonly"

    if key_argtype is None:
        return f  # already deterministic

    saved_random_state = get_random_state(key)

    def deterministic_f(*args, **kwargs):
        restored_key = restore_random_state(key, saved_random_state)
        if key_argtype == "kwonly":
            kwargs["key"] = restored_key
            new_args = args
        elif key_argtype == "pos":
            new_args = args[:key_position] + (restored_key,) + args[key_position:]
        else:
            raise ValueError("Unknown key_argtype %s" % key_argtype)
        return f(*new_args, **kwargs)

    if hasattr(f, "__name__"):
        f_name = f.__name__
    elif hasattr(f, "__class__") and hasattr(f.__class__, "__name__"):
        f_name = f.__class__.__name__
    else:
        f_name = "<anonymous>"

    if hasattr(f, "__doc__"):
        f_doc = f.__doc__
    else:
        f_doc = ""

    new_docstring = f"""
        This is a deterministic version of the function {f_name}.

        The original docstring follows.

        {f_doc}
        """
    deterministic_f.__doc__ = new_docstring

    return deterministic_f

def ordered_pairwise_differences(X: B.Numeric) -> B.Numeric:
    """
    Compute the ordered pairwise differences between elements of a vector.

    :param X:
        A [..., D]-shaped array, a batch of D-dimensional vectors.

    :return:
        A [..., C]-shaped array, where C = D*(D-1)//2, containing the ordered
        pairwise differences between the elements of X. That is, the array
        containing differences X[...,i] - X[...,j] for all i < j.
    """
    diffX = B.expand_dims(X, -2) - B.expand_dims(X, -1)  # [B, D, D]
    # diffX[i, j] = X[j] - X[i]
    # lower triangle is i > j
    # so, lower triangle is X[k] - X[l] with k < l

    diffX = B.tril_to_vec(diffX, offset=-1)  # don't take the diagonal

    return diffX
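
# A minimal usage sketch for ordered_pairwise_differences, assuming a NumPy
# backend for concreteness:
#
#   X = np.array([[1.0, 2.0, 4.0]])  # one 3-dimensional vector, shape [1, 3]
#   ordered_pairwise_differences(X)
#   # -> a [1, 3]-shaped array containing the differences 1-2, 1-4 and 2-4
#   #    (that is, -1, -3 and -2), flattened from the strict lower triangle
#   #    by B.tril_to_vec.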

def fixed_length_partitions(  # noqa: C901
    n: int, L: int
) -> Generator[List[int], None, None]:
    """
    A generator for integer partitions of n into L parts, in colex order.

    :param n:
        The number to partition.
    :param L:
        The number of parts (the length of each partition).

    Developed by D. Eppstein in 2005, taken from
    https://www.ics.uci.edu/~eppstein/PADS/IntegerPartitions.py
    The algorithm follows Knuth v4 fasc3 p38 in rough outline;
    Knuth credits it to Hindenburg, 1779.
    """

    # guard against special cases
    if L == 0:
        if n == 0:
            yield []
        return
    if L == 1:
        if n > 0:
            yield [n]
        return
    if n < L:
        return

    partition = [n - L + 1] + (L - 1) * [1]
    while True:
        yield partition.copy()
        if partition[0] - 1 > partition[1]:
            partition[0] -= 1
            partition[1] += 1
            continue
        j = 2
        s = partition[0] + partition[1] - 1
        while j < L and partition[j] >= partition[0] - 1:
            s += partition[j]
            j += 1
        if j >= L:
            return
        partition[j] = x = partition[j] + 1
        j -= 1
        while j > 0:
            partition[j] = x
            s -= x
            j -= 1
        partition[0] = s
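
# A minimal usage sketch for fixed_length_partitions: the partitions of 5 into
# exactly 3 positive parts, generated in colex order.
#
#   list(fixed_length_partitions(5, 3))
#   # -> [[3, 1, 1], [2, 2, 1]]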

def partition_dominance_cone(partition: Tuple[int, ...]) -> Set[Tuple[int, ...]]:
    """
    Calculates partitions dominated by a given one and having the same number
    of parts (including the zero parts of the original).

    :param partition:
        A partition.

    :return:
        A set of partitions.
    """
    cone = {partition}
    new_partitions: Set[Tuple[int, ...]] = {(0,)}
    prev_partitions = cone
    while new_partitions:
        new_partitions = set()
        for partition in prev_partitions:
            for i in range(len(partition) - 1):
                if partition[i] > partition[i + 1]:
                    for j in range(i + 1, len(partition)):
                        if (
                            partition[i] > partition[j] + 1
                            and partition[j] < partition[j - 1]
                        ):
                            new_partition = list(partition)
                            new_partition[i] -= 1
                            new_partition[j] += 1
                            new_partition = tuple(new_partition)
                            if new_partition not in cone:
                                new_partitions.add(new_partition)
        cone.update(new_partitions)
        prev_partitions = new_partitions
    return cone
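
# A minimal usage sketch for partition_dominance_cone: the partitions of 4
# dominated by (3, 1) that fit into three parts (zeros included).
#
#   partition_dominance_cone((3, 1, 0))
#   # -> {(3, 1, 0), (2, 2, 0), (2, 1, 1)}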

def partition_dominance_or_subpartition_cone(
    partition: Tuple[int, ...],
) -> Set[Tuple[int, ...]]:
    """
    Calculates subpartitions and partitions dominated by a given one and having
    the same number of parts (including zero parts of the original).

    :param partition:
        A partition.

    :return:
        A set of partitions.
    """
    cone = {partition}
    new_partitions: Set[Tuple[int, ...]] = {(0,)}
    prev_partitions = cone
    while new_partitions:
        new_partitions = set()
        for partition in prev_partitions:
            for i in range(len(partition) - 1):
                if partition[i] > partition[i + 1]:
                    new_partition = list(partition)
                    new_partition[i] -= 1
                    new_partition = tuple(new_partition)
                    if new_partition not in cone:
                        new_partitions.add(new_partition)
                    for j in range(i + 1, len(partition)):
                        if (
                            partition[i] > partition[j] + 1
                            and partition[j] < partition[j - 1]
                        ):
                            new_partition = list(partition)  # type: ignore[assignment]
                            new_partition[i] -= 1  # type: ignore[index]
                            new_partition[j] += 1  # type: ignore[index]
                            new_partition = tuple(new_partition)
                            if new_partition not in cone:
                                new_partitions.add(new_partition)
        cone.update(new_partitions)
        prev_partitions = new_partitions
    return cone
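
# A minimal usage sketch for partition_dominance_or_subpartition_cone: compared
# to partition_dominance_cone((3, 1)) == {(3, 1), (2, 2)}, the result also
# contains the subpartitions (2, 1) and (1, 1), obtained by decreasing a part
# without compensating elsewhere.
#
#   partition_dominance_or_subpartition_cone((3, 1))
#   # -> {(3, 1), (2, 2), (2, 1), (1, 1)}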

@contextmanager
def get_resource_file_path(filename: str):
    """
    A contextmanager wrapper around `impresources` that supports both
    Python>=3.9 and Python==3.8 with a unified interface.

    :param filename:
        The name of the file.
    """
    if sys.version_info >= (3, 9):
        with impresources.as_file(impresources.files(resources) / filename) as path:
            yield path
    else:
        with impresources.path(resources, filename) as path:
            yield path
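
# A minimal usage sketch for get_resource_file_path ("some_data.npz" is a
# hypothetical file name, used purely for illustration):
#
#   with get_resource_file_path("some_data.npz") as path:
#       data = np.load(path)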

def hamming_distance(x1: B.Bool, x2: B.Bool):
    """
    Hamming distance between two batches of boolean vectors.

    :param x1:
        Array of any backend, of shape [N, D].
    :param x2:
        Array of any backend, of shape [M, D].

    :return:
        An array of shape [N, M] whose entry n, m contains the Hamming distance
        between x1[n, :] and x2[m, :]. It is of the same backend as x1 and x2.
    """
    return count_nonzero(logical_xor(x1[:, None, :], x2[None, :, :]), axis=-1)
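
# A minimal usage sketch for hamming_distance, assuming a NumPy backend:
#
#   x1 = np.array([[True, False, True]])                        # shape [1, 3]
#   x2 = np.array([[True, True, True], [False, False, False]])  # shape [2, 3]
#   hamming_distance(x1, x2)
#   # -> [[1, 2]]: x1[0] differs from x2[0] in one position and from x2[1] in two.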

def log_binomial(n: B.Int, k: B.Int) -> B.Float:
    """
    Compute the logarithm of the binomial coefficient.

    :param n:
        The number of elements in the set.
    :param k:
        The number of elements to choose.

    :return:
        The logarithm of the binomial coefficient binom(n, k).
    """
    if not B.all(0 <= k <= n):
        raise ValueError("Incorrect parameters of the binomial coefficient.")

    return B.loggamma(n + 1) - B.loggamma(k + 1) - B.loggamma(n - k + 1)
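
# A minimal usage sketch for log_binomial, assuming plain integer inputs
# (handled by the NumPy backend):
#
#   log_binomial(5, 2)   # == log(binom(5, 2)) = log(10) ≈ 2.302585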

def binary_vectors_and_subsets(d: int):
    r"""
    Generates all possible binary vectors of size d and all possible subsets of
    the set $\{0, .., d-1\}$ as a byproduct.

    :param d:
        The dimension of binary vectors and the size of the set to take subsets of.

    :return:
        A tuple (x, combs), where x is a matrix of size (2**d, d) whose rows
        are all possible binary vectors of size d, and combs is a list of all
        possible subsets of the set $\{0, .., d-1\}$, each subset being
        represented by a list of integers itself.
    """
    x = np.zeros((2**d, d), dtype=bool)
    combs = []
    i = 0
    for level in range(d + 1):
        for cur_combination in combinations(range(d), level):
            indices = list(cur_combination)
            combs.append(indices)
            x[i, indices] = 1
            i += 1

    return x, combs
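
# A minimal usage sketch for binary_vectors_and_subsets:
#
#   x, combs = binary_vectors_and_subsets(2)
#   # x is a boolean (4, 2) matrix whose rows are [0, 0], [1, 0], [0, 1], [1, 1];
#   # combs == [[], [0], [1], [0, 1]], the corresponding subsets of {0, 1}.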

def _check_field_in_params(params, field):
    """
    Raise an error if `params` does not contain `field`.
    """
    if field not in params:
        raise ValueError(f"`params` must contain `{field}`.")


def _check_1_vector(x, desc):
    """
    Raise an error if `x` is not a vector of shape [1,].
    """
    if B.shape(x) != (1,):
        raise ValueError(
            f"`{desc}` must have shape `[1,]`, but has shape {B.shape(x)}."
        )


def _check_rank_1_array(x, desc):
    """
    Raise an error if `x` is not a rank-1 array.
    """
    if B.rank(x) != 1:
        raise ValueError(
            f"`{desc}` must have 1 dimension (`ndim` == 1), but has shape {B.shape(x)}."
        )


def _check_matrix(x, desc):
    """
    Raise an error if `x` is not a matrix.
    """
    if B.rank(x) != 2:
        raise ValueError(f"`{desc}` must be a matrix, but has shape {B.shape(x)}.")