Coverage for geometric_kernels / utils / utils.py: 64%

163 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-15 13:41 +0000

1""" 

2Convenience utilities. 

3""" 

4 

5import inspect 

6import sys 

7from contextlib import contextmanager 

8from importlib import resources as impresources 

9from itertools import combinations 

10 

11import einops 

12import lab as B 

13import numpy as np 

14from beartype.typing import Callable, Generator, List, Set, Tuple 

15 

16from geometric_kernels import resources 

17from geometric_kernels.lab_extras import ( 

18 count_nonzero, 

19 get_random_state, 

20 restore_random_state, 

21) 

22 

23 

def chain(elements: B.Numeric, repetitions: List[int]) -> B.Numeric:
    """
    Build an array in which the i-th entry of `elements` occurs
    `repetitions[i]` times in a row, preserving the order of `elements`.

    :param elements:
        An [N,]-shaped array of elements to repeat.
    :param repetitions:
        A list specifying the number of times to repeat each of the elements
        in `elements`. The length of `repetitions` should be equal to N.

    :return:
        An [M,]-shaped array.

    EXAMPLE:

    .. code-block:: python

        elements = np.array([1, 2, 3])
        repetitions = [2, 1, 3]
        out = chain(elements, repetitions)
        print(out)  # [1, 1, 2, 3, 3, 3]
    """
    repeated_blocks = []
    for index, count in enumerate(repetitions):
        # Take a length-1 slice (not a scalar) so that the "j" axis of the
        # einops pattern exists.
        single = elements[index : index + 1]
        repeated_blocks.append(einops.repeat(single, "j -> (tile j)", tile=count))
    return B.concat(*repeated_blocks, axis=0)

53 

54 

def make_deterministic(f: Callable, key: B.RandomState) -> Callable:
    """
    Returns a deterministic version of a function that uses a random
    number generator.

    The random state of `key` is saved once, when this function is called,
    and restored before every call of the returned function, so that every
    call sees the same random state.

    :param f:
        The function to make deterministic.
    :param key:
        The key used to generate the random state.

    :return:
        A function representing the deterministic version of the input function.
        It takes the same arguments as `f` except for `key`, which is
        supplied automatically.

    .. note::
        This function assumes that the input function has a 'key' argument
        or keyword-only argument that is used to generate random numbers.
        Otherwise, the function is returned as is.

    .. note::
        If `key` is a regular (positional-or-keyword) argument of `f` and the
        caller supplies fewer positional arguments than precede `key` in the
        signature of `f`, the key is passed by name rather than by position,
        so that it cannot be bound to another parameter by mistake.

    EXAMPLE:

    .. code-block:: python

        key = tf.random.Generator.from_seed(1234)
        feature_map = default_feature_map(kernel=base_kernel)
        sample_paths = make_deterministic(sampler(feature_map), key)
        _, ys_train = sample_paths(xs_train, params)
        key, ys_test = sample_paths(xs_test, params)
    """
    f_argspec = inspect.getfullargspec(f)
    if "key" in f_argspec.args:
        key_position = f_argspec.args.index("key")
    elif "key" in f_argspec.kwonlyargs:
        key_position = None  # keyword-only: always passed by name
    else:
        return f  # already deterministic

    saved_random_state = get_random_state(key)

    def deterministic_f(*args, **kwargs):
        restored_key = restore_random_state(key, saved_random_state)
        if key_position is None or len(args) < key_position:
            # Either `key` is keyword-only, or not enough positional arguments
            # were given to reach the slot of `key`. Splicing the key into
            # `args` in the latter case would bind it to an earlier parameter.
            kwargs["key"] = restored_key
            return f(*args, **kwargs)
        new_args = args[:key_position] + (restored_key,) + args[key_position:]
        return f(*new_args, **kwargs)

    # Callables without a `__name__` (e.g. `functools.partial` objects) are
    # described by the name of their class instead.
    f_name = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__

    # `__doc__` exists on every object but may be None; do not render "None".
    f_doc = getattr(f, "__doc__", None) or ""

    new_docstring = f"""
        This is a deterministic version of the function {f_name}.

        The original docstring follows.

        {f_doc}
        """
    deterministic_f.__doc__ = new_docstring

    return deterministic_f

130 

131 

def ordered_pairwise_differences(X: B.Numeric) -> B.Numeric:
    """
    Compute the ordered pairwise differences between elements of a vector.

    :param X:
        A [..., D]-shaped array, a batch of D-dimensional vectors.

    :return:
        A [..., C]-shaped array, where C = D*(D-1)//2, containing the ordered
        pairwise differences between the elements of X. That is, the array
        containing differences X[...,i] - X[...,j] for all i < j.
    """
    as_row = B.expand_dims(X, -2)  # [..., 1, D]
    as_column = B.expand_dims(X, -1)  # [..., D, 1]

    # all_differences[..., i, j] = X[..., j] - X[..., i], shape [..., D, D].
    all_differences = as_row - as_column

    # The strictly lower triangle (i > j) holds X[k] - X[l] with k < l;
    # offset=-1 leaves out the (all-zero) diagonal.
    return B.tril_to_vec(all_differences, offset=-1)

152 

153 

def fixed_length_partitions(  # noqa: C901
    n: int, L: int
) -> Generator[List[int], None, None]:
    """
    Generate the integer partitions of n into exactly L positive parts, in
    colex order. Each partition is yielded as a fresh list whose parts are in
    non-increasing order.

    :param n:
        The number to partition.
    :param L:
        Size of partitions.

    Developed by D. Eppstein in 2005, taken from
    https://www.ics.uci.edu/~eppstein/PADS/IntegerPartitions.py
    The algorithm follows Knuth v4 fasc3 p38 in rough outline;
    Knuth credits it to Hindenburg, 1779.
    """

    # Degenerate cases that the main loop cannot handle.
    if L == 0:
        if n == 0:
            yield []
        return
    if L == 1:
        if n > 0:
            yield [n]
        return
    if n < L:
        return

    # Start from the partition with the largest possible first part.
    parts = [n - L + 1] + [1] * (L - 1)
    while True:
        yield list(parts)

        head = parts[0]
        if head - 1 > parts[1]:
            # Cheap step: move one unit from the first part to the second.
            parts[0], parts[1] = head - 1, parts[1] + 1
            continue

        # Find the first position (from index 2) that can still be increased,
        # accumulating the sum of everything to its left (minus one).
        pivot = 2
        remainder = head + parts[1] - 1
        while pivot < L and parts[pivot] >= head - 1:
            remainder += parts[pivot]
            pivot += 1
        if pivot >= L:
            return  # no position can be increased: enumeration is complete

        # Raise the pivot, set positions 1..pivot-1 to the same value, and put
        # whatever is left into the first part.
        level = parts[pivot] + 1
        parts[1 : pivot + 1] = [level] * pivot
        parts[0] = remainder - level * (pivot - 1)

204 

205 

def partition_dominance_cone(partition: Tuple[int, ...]) -> Set[Tuple[int, ...]]:
    """
    Calculates partitions dominated by a given one and having the same number
    of parts (including the zero parts of the original).

    The set is built by repeatedly moving a single unit from a part to a
    later, smaller part, starting from `partition`, until no new partitions
    appear.

    :param partition:
        A partition.

    :return:
        A set of partitions, which always contains `partition` itself.
    """
    cone = {partition}
    frontier: Set[Tuple[int, ...]] = {partition}
    while frontier:
        discovered: Set[Tuple[int, ...]] = set()
        for current in frontier:
            length = len(current)
            for i in range(length - 1):
                if current[i] > current[i + 1]:
                    for j in range(i + 1, length):
                        # Move one unit from part i to part j only if the
                        # result is still non-increasing around both parts.
                        if (
                            current[i] > current[j] + 1
                            and current[j] < current[j - 1]
                        ):
                            moved = list(current)
                            moved[i] -= 1
                            moved[j] += 1
                            candidate = tuple(moved)
                            if candidate not in cone:
                                discovered.add(candidate)
        cone |= discovered
        frontier = discovered
    return cone

239 

240 

def partition_dominance_or_subpartition_cone(
    partition: Tuple[int, ...],
) -> Set[Tuple[int, ...]]:
    """
    Calculates subpartitions and partitions dominated by a given one and having
    the same number of parts (including zero parts of the original).

    Starting from `partition`, two kinds of steps are applied repeatedly until
    no new partitions appear: removing one unit from a part that is strictly
    larger than the next part, and moving one unit from a part to a later,
    smaller part.

    :param partition:
        A partition.

    :return:
        A set of partitions, which always contains `partition` itself.
    """
    cone = {partition}
    frontier: Set[Tuple[int, ...]] = {partition}
    while frontier:
        discovered: Set[Tuple[int, ...]] = set()
        for current in frontier:
            length = len(current)
            for i in range(length - 1):
                if current[i] > current[i + 1]:
                    # Step 1: drop one unit from part i (a subpartition).
                    shrunk = list(current)
                    shrunk[i] -= 1
                    candidates = [tuple(shrunk)]

                    # Step 2: move one unit from part i to a later part j.
                    for j in range(i + 1, length):
                        if (
                            current[i] > current[j] + 1
                            and current[j] < current[j - 1]
                        ):
                            moved = list(current)
                            moved[i] -= 1
                            moved[j] += 1
                            candidates.append(tuple(moved))

                    for candidate in candidates:
                        if candidate not in cone:
                            discovered.add(candidate)
        cone |= discovered
        frontier = discovered
    return cone

281 

282 

@contextmanager
def get_resource_file_path(filename: str):
    """
    A contextmanager wrapper around `impresources` that supports both
    Python>=3.9 and Python==3.8 with a unified interface.

    Yields the path object produced by `impresources` for `filename` inside
    the `resources` package; the path is only meant to be used within the
    `with` block.

    :param filename:
        The name of the file.
    """
    if sys.version_info >= (3, 9):
        # Python>=3.9 interface: `files` + `as_file`.
        path_manager = impresources.as_file(impresources.files(resources) / filename)
    else:
        # Python==3.8 interface.
        path_manager = impresources.path(resources, filename)

    with path_manager as path:
        yield path

298 

299 

def hamming_distance(x1: B.Int, x2: B.Int):
    """
    Hamming distance between two batches of boolean vectors, i.e. the number
    of positions at which two vectors differ.

    :param x1:
        Array of any backend, of shape [N, D].
    :param x2:
        Array of any backend, of shape [M, D].

    :return:
        An array of shape [N, M] whose entry n, m contains the Hamming distance
        between x1[n, :] and x2[m, :]. It is of the same backend as x1 and x2.
    """
    left = x1[:, None, :]  # [N, 1, D]
    right = x2[None, :, :]  # [1, M, D]
    mismatches = left != right  # broadcasts to [N, M, D]
    return count_nonzero(mismatches, axis=-1)

314 

315 

def log_binomial(n: B.Int, k: B.Int) -> B.Float:
    """
    Compute the logarithm of the binomial coefficient.

    :param n:
        The number of elements in the set.
    :param k:
        The number of elements to choose.

    :return:
        The logarithm of the binomial coefficient binom(n, k).

    :raises ValueError:
        If `0 <= k <= n` does not hold (for all entries, if arrays are given).
    """
    # Do not write `0 <= k <= n`: Python expands a chained comparison to
    # `(0 <= k) and (k <= n)`, which takes the truth value of `0 <= k` and
    # therefore fails for multi-element arrays before `B.all` is ever
    # reached. Reduce each bound separately instead.
    if not (B.all(0 <= k) and B.all(k <= n)):
        raise ValueError("Incorrect parameters of the binomial coefficient.")

    # log binom(n, k) = log n! - log k! - log (n - k)!, with log m! = loggamma(m + 1).
    return B.loggamma(n + 1) - B.loggamma(k + 1) - B.loggamma(n - k + 1)

332 

333 

def binary_vectors_and_subsets(d: int):
    r"""
    Generates all possible binary vectors of size d and all possible subsets of
    the set $\{0, .., d-1\}$ as a byproduct.

    Subsets are enumerated by increasing size and, within one size, in the
    order produced by `itertools.combinations`. The i-th binary vector is the
    indicator vector of the i-th subset.

    :param d:
        The dimension of binary vectors and the size of the set to take subsets of.

    :return:
        A tuple (x, combs), where x is a matrix of size (2**d, d) whose rows
        are all possible binary vectors of size d, and combs is a list of all
        possible subsets of the set $\{0, .., d-1\}$, each subset being
        represented by a list of integers itself.
    """
    x = np.zeros((2**d, d), dtype=bool)
    combs = [
        list(subset)
        for size in range(d + 1)
        for subset in combinations(range(d), size)
    ]
    for row, members in enumerate(combs):
        # Mark the coordinates that belong to the subset.
        x[row, members] = True

    return x, combs

359 

360 

361def _check_field_in_params(params, field): 

362 """ 

363 Raise an error if `params` does not contain a `field`. 

364 """ 

365 if field not in params: 

366 raise ValueError(f"`params` must contain `{field}`.") 

367 

368 

def _check_1_vector(x, desc):
    """
    Validate that `x` has shape [1,]; raise a `ValueError` mentioning `desc`
    (the name used for `x` in the message) otherwise.
    """
    shape = B.shape(x)
    if shape == (1,):
        return
    raise ValueError(f"`{desc}` must have shape `[1,]`, but has shape {shape}.")

377 

378 

def _check_rank_1_array(x, desc):
    """
    Validate that `x` is a rank-1 array; raise a `ValueError` mentioning
    `desc` (the name used for `x` in the message) otherwise.
    """
    if B.rank(x) == 1:
        return
    raise ValueError(
        f"`{desc}` must have 1 dimension (`ndim` == 1), but has shape {B.shape(x)}."
    )

387 

388 

def _check_matrix(x, desc):
    """
    Validate that `x` is a matrix, i.e. a rank-2 array; raise a `ValueError`
    mentioning `desc` (the name used for `x` in the message) otherwise.
    """
    if B.rank(x) == 2:
        return
    raise ValueError(f"`{desc}` must be a matrix, but has shape {B.shape(x)}.")