Coverage for geometric_kernels/spaces/product.py: 94%

159 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2This module provides the :class:`ProductDiscreteSpectrumSpace` space and the 

3respective :class:`~.eigenfunctions.Eigenfunctions` subclass 

4:class:`ProductEigenfunctions`. 

5 

6See :doc:`this page </theory/product_spaces>` for a brief account of the 

7theory behind product spaces and the :doc:`Torus.ipynb </examples/Torus>` 

8notebook for a tutorial on how to use them. 

9""" 

10 

11import itertools 

12import math 

13 

14import lab as B 

15import numpy as np 

16from beartype.typing import List, Optional, Tuple 

17 

18from geometric_kernels.lab_extras import from_numpy, int_like, take_along_axis 

19from geometric_kernels.spaces.base import DiscreteSpectrumSpace 

20from geometric_kernels.spaces.eigenfunctions import Eigenfunctions 

21from geometric_kernels.utils.product import make_product, project_product 

22from geometric_kernels.utils.utils import chain 

23 

24 

def _find_lowest_sum_combinations(arr: B.Numeric, k: int) -> B.Numeric:
    """
    For an [N, D] array `arr`, with `arr[i, :]` assumed sorted, find `k`
    smallest sums of one element per each row, i.e. sums of form
    `sum_i arr[i, j_i]`, return the array of indices of the summands.

    The search keeps a frontier of candidate index tuples, starting from the
    all-zeros tuple. At every step all candidates tied for the smallest sum
    are accepted, and each accepted tuple is replaced in the frontier by its
    neighbours obtained by incrementing a single coordinate. Because the rows
    are sorted, the next smallest sum is always found on this frontier.

    The search is carried out in NumPy. It relies on in-place writes and on
    NumPy-only helpers such as :func:`numpy.unique` and
    :func:`numpy.ravel_multi_index`. The result is converted back to the
    backend of `arr`.

    :param arr:
        An [N, D]-shaped array.
    :param k:
        Number of smallest sums to find. Must satisfy `0 <= k <= D**N`.

    :return:
        A [k, N]-shaped integer array of the same backend as `arr`. Row `i`
        holds the indices `(j_1, ..., j_N)` of the summands of the `i`-th
        smallest sum. Tuples accepted in the same step (tied sums) appear in
        lexicographic order.

    :raises ValueError:
        If `k` is negative, if `k` exceeds the number `D**N` of possible
        combinations, or if `arr` contains NaN values.
    """
    values = B.to_numpy(arr)
    N, D = values.shape

    if k < 0:
        raise ValueError(f"`k` must be non-negative, got {k}.")
    if k > D**N:
        raise ValueError(
            f"Cannot find {k} smallest sums: an array of shape {(N, D)} only "
            f"admits {D**N} combinations."
        )

    rows = np.arange(N)  # pairs row n with column candidate[n] when gathering
    steps = np.eye(N, dtype=np.int64)  # one unit increment per row
    index_array = np.zeros((k, N), dtype=np.int64)

    # The first (smallest) sum uses the first element of every row.
    candidates = np.zeros((1, N), dtype=np.int64)
    i = 0
    while i < k:
        # sums[m] = sum_n values[n, candidates[m, n]]
        sums = values[rows, candidates].sum(axis=1)

        # Accept every candidate tied for the smallest sum.
        tied = sums == sums.min()
        if not tied.any():
            # Only happens when `sums` contains NaN (NaN != NaN). Without this
            # guard nothing would ever be accepted and the loop would not end.
            raise ValueError("`arr` must not contain NaN values.")
        for combination in candidates[tied]:
            index_array[i] = combination
            i += 1
            if i >= k:
                break
        if i >= k:
            break

        # Replace the accepted candidates by their single-step neighbours.
        # Clamping keeps indices in range; the clamped tuples duplicate
        # existing ones and are removed below.
        neighbours = (candidates[tied][:, None, :] + steps).reshape(-1, N)
        neighbours = np.minimum(neighbours, D - 1)
        candidates = np.unique(
            np.concatenate([candidates[~tied], neighbours], axis=0), axis=0
        )

        # Drop tuples that have already been accepted. Unfilled rows of
        # `index_array` are all zeros, which equals the first accepted tuple,
        # so they do not remove anything extra.
        # See https://stackoverflow.com/a/40056251.
        dims = np.maximum(candidates.max(axis=0), index_array.max(axis=0)) + 1
        already_accepted = np.isin(
            np.ravel_multi_index(candidates.T, dims),
            np.ravel_multi_index(index_array.T, dims),
        )
        candidates = candidates[~already_accepted]

    return from_numpy(arr, index_array)

88 

89 

90def _num_per_level_to_mapping(num_per_level: List[int]) -> List[List[int]]: 

91 """ 

92 Given a list of numbers of eigenfunctions per level, return a list `mapping` 

93 such that `mapping[i][j]` is the index of the `j`-th eigenfunction of the 

94 `i`-th level. 

95 

96 :param num_per_level: 

97 A `num_eigenfunctions_per_level` list of some space. 

98 """ 

99 mapping = [] 

100 i = 0 

101 for num in num_per_level: 

102 mapping.append([i + j for j in range(num)]) 

103 i += num 

104 return mapping 

105 

106 

def _eigenlevelindices_to_eigenfunctionindices(
    eigenlevelindices: B.Numeric, nums_per_level: List[List[int]]
) -> B.Numeric:
    """
    Expand level-index tuples into eigenfunction-index tuples.

    `eigenlevelindices` maps the product space's level indices to the indices
    of the factor spaces' levels. This returns an `eigenfunctionindices` array
    that maps the product space's eigenfunction index to the indices of the
    factor spaces' individual eigenfunctions.

    :param eigenlevelindices:
        A [L, S]-shaped array, where `S` is the number of factor spaces and `L`
        is the number of levels, such that `eigenlevelindices[i, :]` are the
        indices of the levels of the factor spaces that correspond to the
        `i`'th level of the product space.
    :param nums_per_level:
        Each `nums_per_level[j]` represents the `num_eigenfunctions_per_level`
        list of the `j`-th factor space.

    :return:
        A [J, S]-shaped integer array of the same backend as
        `eigenlevelindices`, where `J` is the total number of eigenfunctions.
        Rows are grouped by product-space level, in level order. Within a
        level, the rows enumerate every combination of one eigenfunction per
        factor, in :func:`itertools.product` order.
    """
    L = eigenlevelindices.shape[0]
    S = eigenlevelindices.shape[1]
    # levels_to_eigenfunction_indices[s][l] lists the eigenfunction indices
    # of level `l` of factor `s`. Each inner list has one entry per level of
    # that factor, so the lengths differ between factors.
    levels_to_eigenfunction_indices: List[List[List[int]]] = [
        _num_per_level_to_mapping(npl) for npl in nums_per_level
    ]

    eigenfunctionindices: List[Tuple[int, ...]] = []
    for cur_level in range(L):
        cur_factor_eigenfunction_indices: List[List[int]] = [
            levels_to_eigenfunction_indices[s][int(eigenlevelindices[cur_level, s])]
            for s in range(S)
        ]
        # Every combination of one eigenfunction per factor is a product-space
        # eigenfunction belonging to `cur_level`.
        eigenfunctionindices.extend(
            itertools.product(*cur_factor_eigenfunction_indices)
        )

    # The explicit reshape keeps the documented [J, S] shape even when J == 0
    # (np.array([]) on its own would be 1-D).
    out = np.array(eigenfunctionindices, dtype=int).reshape(
        len(eigenfunctionindices), S
    )
    return from_numpy(eigenlevelindices, out)

147 

148 

class ProductEigenfunctions(Eigenfunctions):
    r"""
    Eigenfunctions of the Laplacian on the product space are outer products
    of the eigenfunctions on factors.

    Levels correspond to tuples of levels of the factors.

    :param element_shapes:
        Shapes of the elements in each factor. Can be obtained as properties
        `space.element_shape` of any given factor `space`.
    :param element_dtypes:
        Abstract lab data types of the elements in each factor. Can be obtained
        as properties `space.element_dtype` of any given factor `space`.
    :param eigenindicies:
        A [L, S]-shaped array, where `S` is the number of factor spaces and `L`
        is the number of levels, such that `eigenindicies[i, :]` are the indices
        of the levels of the factor spaces that correspond to the `i`'th level
        of the product space.

        This parameter implicitly determines the number of levels.
    :param ``*eigenfunctions``:
        The :class:`~.Eigenfunctions` subclasses corresponding to each of
        the factor spaces.
    :param dimension_indices:
        Determines how a vector `x` representing a product space element is to
        be mapped into the arrays `xi` that represent elements of the factor
        spaces. `xi` are assumed to be equal to `x[dimension_indices[i]]`,
        possibly up to a reshape. Such a reshape might be necessary to
        accommodate the spaces whose elements are matrices rather than vectors,
        as determined by `element_shapes`. The transformation of `x` into the
        list of `xi`\ s is performed by :func:`~.project_product`.

        If None, assumes that each input is laid out flattened and
        concatenated, in the same order as the factor spaces. In this case,
        the inverse to :func:`~.project_product` is :func:`~.make_product`.

        Defaults to None.

    :raises ValueError:
        If the number of columns of `eigenindicies` differs from the number
        of `eigenfunctions`.
    """

    def __init__(
        self,
        element_shapes: List[List[int]],
        element_dtypes: List[B.DType],
        eigenindicies: B.Numeric,
        *eigenfunctions: Eigenfunctions,
        dimension_indices: B.Numeric = None,
    ):
        # Validate first. Every later step indexes `eigenindicies` per factor,
        # so a mismatch would otherwise surface as an IndexError instead.
        if eigenindicies.shape[-1] != len(eigenfunctions):
            raise ValueError(
                "Expected to have S `eigenfunctions` and `eigenindicies` of shape [L, S], "
                "where S is the number of spaces and L is the number of levels, "
                f"but got S1={len(eigenfunctions)} eigenfunctions and "
                f"the shape of `eigenindicies` is {eigenindicies.shape}, which is incompatible."
            )

        self._num_levels = eigenindicies.shape[0]
        self.element_shapes = element_shapes
        self.element_dtypes = element_dtypes

        if dimension_indices is None:
            # Factor `s` occupies the next prod(element_shapes[s]) coordinates
            # of the flattened, concatenated product-space vector.
            self.dimension_indices = []
            start = 0
            for element_shape in self.element_shapes:
                dim = math.prod(element_shape)
                self.dimension_indices.append(list(range(start, start + dim)))
                start += dim
        else:
            # Previously this branch was missing, leaving the attribute unset.
            self.dimension_indices = dimension_indices

        self.eigenindicies = eigenindicies
        self.eigenfunctions = eigenfunctions

        # nums_per_level[s][l]: number of eigenfunctions in level `l` of
        # factor `s`.
        self.nums_per_level: List[List[int]] = [
            eigenfunctions_factor.num_eigenfunctions_per_level
            for eigenfunctions_factor in self.eigenfunctions
        ]

        # [J, S] array: row `j` holds, for each factor, the index of the factor
        # eigenfunction entering the `j`-th product eigenfunction.
        self._eigenfunctionindices = _eigenlevelindices_to_eigenfunctionindices(
            self.eigenindicies, self.nums_per_level
        )

    @staticmethod
    def _product_over_factors(factor_arrays: List[B.Numeric]) -> B.Numeric:
        """
        Multiply same-shaped per-factor arrays elementwise.

        The arrays are first cast to their common (promoted) dtype so that
        stacking works even when the factors use different dtypes.
        """
        common_dtype = B.promote_dtypes(*[B.dtype(x) for x in factor_arrays])
        stacked = B.stack(
            *[B.cast(common_dtype, x) for x in factor_arrays], axis=-1
        )  # [..., S]
        return B.prod(stacked, axis=-1)  # [...]

    def __call__(self, X: B.Numeric, **kwargs) -> B.Numeric:
        """
        Evaluate the individual eigenfunctions at a batch of input locations.

        .. warning::
            Will `raise NotImplementedError` if some of the factors do not
            implement their `__call__`.

        :param X:
            Points to evaluate the eigenfunctions at, an array of shape [N, D],
            where `N` is the number of points and `D` is the dimension of the
            vectors that represent points in the product space.

            Each point `x` in the product space is a `D`-dimensional vector such
            that `x[dimension_indices[i]]` is the vector that represents the
            flattened element of the `i`-th factor space.

            .. note::
                If the instance was created with `dimension_indices=None`, then
                you can use the :func:`~.make_product` function to convert a
                list of (batches of) points in factor spaces into a (batch of)
                points in the product space.

        :param ``**kwargs``:
            Any additional parameters.

        :return:
            An [N, J]-shaped array, where `J` is the number of eigenfunctions.
        """
        Xs = project_product(
            X, self.dimension_indices, self.element_shapes, self.element_dtypes
        )  # List of S arrays, each of shape [N, *element_shape_s]

        factor_eigenfunction_values = [
            eigenfunction(X_s, **kwargs)  # [N, J_s], J_s differs per factor
            for eigenfunction, X_s in zip(self.eigenfunctions, Xs)
        ]  # Length S; arrays of different shapes (not stackable)

        num_factors = len(self.eigenfunctions)
        eigenfunction_values = []
        for j in range(self.num_eigenfunctions):
            factor_columns = [
                factor_eigenfunction_values[s][:, self._eigenfunctionindices[j, s]]
                for s in range(num_factors)
            ]  # S arrays of shape [N,]
            eigenfunction_values.append(
                self._product_over_factors(factor_columns)  # [N,]
            )

        return B.stack(*eigenfunction_values, axis=-1)  # [N, J]

    @property
    def num_eigenfunctions(self) -> int:
        return self._eigenfunctionindices.shape[0]

    @property
    def num_levels(self) -> int:
        return self._num_levels

    def phi_product(
        self, X: B.Numeric, X2: Optional[B.Numeric] = None, **kwargs
    ) -> B.Numeric:
        """
        Per-level product of the factors' `phi_product` values.

        For every product level, each factor's `phi_product` is evaluated at
        the factor level selected by `eigenindicies`. The results are then
        multiplied across factors.

        :return:
            An [N, N2, L]-shaped array.
        """
        if X2 is None:
            X2 = X
        Xs = project_product(
            X, self.dimension_indices, self.element_shapes, self.element_dtypes
        )  # List of S arrays, each of shape [N, *element_shape_s]
        Xs2 = project_product(
            X2, self.dimension_indices, self.element_shapes, self.element_dtypes
        )  # List of S arrays, each of shape [N2, *element_shape_s]

        factor_phi_products = [
            take_along_axis(
                eigenfunction.phi_product(X_s, X2_s, **kwargs),  # [N, N2, L_s]
                from_numpy(X_s, self.eigenindicies[None, None, :, s]),
                -1,
            )  # [N, N2, L]
            for s, (eigenfunction, X_s, X2_s) in enumerate(
                zip(self.eigenfunctions, Xs, Xs2)
            )
        ]

        return self._product_over_factors(factor_phi_products)  # [N, N2, L]

    def phi_product_diag(self, X: B.Numeric, **kwargs):
        """
        Diagonal counterpart of :meth:`phi_product`.

        Uses the factors' `phi_product_diag` at the factor levels selected by
        `eigenindicies`, multiplied across factors.

        :return:
            An [N, L]-shaped array.
        """
        Xs = project_product(
            X, self.dimension_indices, self.element_shapes, self.element_dtypes
        )  # List of S arrays, each of shape [N, *element_shape_s]

        factor_phi_diags = [
            take_along_axis(
                eigenfunction.phi_product_diag(X_s, **kwargs),  # [N, L_s]
                from_numpy(X_s, self.eigenindicies[None, :, s]),
                -1,
            )  # [N, L]
            for s, (eigenfunction, X_s) in enumerate(zip(self.eigenfunctions, Xs))
        ]

        # Cast to a common dtype before stacking, consistent with phi_product.
        return self._product_over_factors(factor_phi_diags)  # [N, L]

    @property
    def num_eigenfunctions_per_level(self) -> List[int]:
        # A product level holds every combination of one eigenfunction from
        # the chosen level of each factor, so its size is the product of the
        # factor level sizes.
        num_factors = self.eigenindicies.shape[1]
        return [
            math.prod(
                self.nums_per_level[s][int(self.eigenindicies[level, s])]
                for s in range(num_factors)
            )
            for level in range(self.eigenindicies.shape[0])
        ]

362 

363 

class ProductDiscreteSpectrumSpace(DiscreteSpectrumSpace):
    r"""
    The GeometricKernels space representing a product

    .. math:: \mathcal{S} = \mathcal{S}_1 \times \ldots \times \mathcal{S}_S

    of :class:`~.spaces.DiscreteSpectrumSpace`-s $\mathcal{S}_i$.

    Levels are indexed by tuples of levels of the factors.

    .. admonition:: Precomputing optimal levels

        The eigenvalue corresponding to a level of the product space
        represented by a tuple $(j_1, .., j_S)$ is the sum

        .. math:: \lambda^{(1)}_{j_1} + \ldots + \lambda^{(S)}_{j_S}

        of the eigenvalues $\lambda^{(s)}_{j_s}$ of the factors.

        Computing the `num_levels` smallest eigenvalues is thus a
        combinatorial problem, whose solution we precompute. To make this
        precomputation possible, you need to provide the `num_levels`
        parameter when constructing the :class:`ProductDiscreteSpectrumSpace`,
        unlike for other spaces. Moreover, the `num` parameter of
        :meth:`get_eigenfunctions` cannot be larger than the `num_levels`
        value.

    .. note::
        See :doc:`this page </theory/product_spaces>` for a brief account of
        the theory behind product spaces and the :doc:`Torus.ipynb
        </examples/Torus>` notebook for a tutorial on how to use them.

        An alternative to using :class:`ProductDiscreteSpectrumSpace` is to use
        the :class:`~.kernels.ProductGeometricKernel` kernel. The latter is more
        flexible, allowing more general spaces as factors and automatic
        relevance determination-like behavior.

    :param spaces:
        The factors, subclasses of :class:`~.spaces.DiscreteSpectrumSpace`.
    :param num_levels:
        The number of levels to pre-compute for this product space.
    :param num_levels_per_space:
        Number of levels to fetch for each of the factor spaces, to compute
        the product-space levels. This is a single number rather than a list,
        because we currently only support fetching the same number of levels
        for all the factor spaces.

        If not given, `num_levels` levels will be fetched for each factor.

    :raises ValueError:
        If no spaces are given, if one of them is not a
        :class:`~.spaces.DiscreteSpectrumSpace`, or if `num_levels` exceeds
        the number `num_levels_per_space ** len(spaces)` of level
        combinations.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please consider
        citing :cite:t:`borovitskiy2020`.
    """

    def __init__(
        self,
        *spaces: DiscreteSpectrumSpace,
        num_levels: int = 25,
        num_levels_per_space: Optional[int] = None,
    ):
        if not spaces:
            # Without factors the combination check below would report a
            # misleading error, or B.stack() would fail on an empty list.
            raise ValueError("At least one factor space must be provided.")
        for space in spaces:
            if not isinstance(space, DiscreteSpectrumSpace):
                raise ValueError(
                    "One of the spaces is not an instance of DiscreteSpectrumSpace."
                )

        self.factor_spaces = spaces  # Tuple of length S
        self.num_levels = num_levels

        if num_levels_per_space is None:
            num_levels_per_space = num_levels
        if num_levels > num_levels_per_space ** len(spaces):
            raise ValueError(
                "Cannot have more levels than there are possible combinations."
            )

        # Prefetch the eigenvalues of the factor spaces.
        factor_space_eigenvalues = B.stack(
            *[
                space.get_eigenvalues(num_levels_per_space)[:, 0]
                for space in self.factor_spaces
            ],
            axis=0,
        )  # [S, num_levels_per_space]

        # Search for the smallest sums of factor eigenvalues. Assuming the
        # eigenvalues come sorted, the next smallest one is found by taking a
        # one-index step in any direction from the current edge of the search
        # space.
        self.factor_space_eigenindices = _find_lowest_sum_combinations(
            factor_space_eigenvalues, self.num_levels
        )  # [self.num_levels, S]
        self.factor_space_eigenvalues = factor_space_eigenvalues

        # Eigenvalue of product level l: sum over s of
        # factor_space_eigenvalues[s, factor_space_eigenindices[l, s]].
        self._eigenvalues = factor_space_eigenvalues[
            range(len(self.factor_spaces)),
            self.factor_space_eigenindices,
        ].sum(axis=1)

    def _validate_num_levels(self, num: int) -> None:
        """Raise a ValueError if `num` exceeds the precomputed `num_levels`."""
        if num > self.num_levels:
            raise ValueError(
                "`num` cannot be larger than the `num_levels` provided in the constructor."
            )

    def __str__(self):
        return f"ProductDiscreteSpectrumSpace({', '.join(str(space) for space in self.factor_spaces)})"

    @property
    def dimension(self) -> int:
        """
        Returns the dimension of the product space, equal to the sum of
        dimensions of the factors.
        """
        return sum(space.dimension for space in self.factor_spaces)

    def random(self, key, number):
        """
        Sample random points on the product space by concatenating random points
        on the factor spaces via :func:`~.make_product`.

        :param key:
            Either `np.random.RandomState`, `tf.random.Generator`,
            `torch.Generator` or `jax.tensor` (representing random state).
        :param number:
            Number of samples to draw.

        :return:
            An array of `number` uniformly random samples on the space.
        """
        random_points = []
        for factor in self.factor_spaces:
            # The key is threaded through the factors in order.
            key, factor_random_points = factor.random(key, number)
            random_points.append(factor_random_points)

        return key, make_product(random_points)

    def get_eigenfunctions(self, num: int) -> Eigenfunctions:
        """
        Returns the :class:`~.ProductEigenfunctions` object with `num` levels.

        :param num:
            Number of levels. Cannot be larger than the `num_levels` parameter
            of the constructor.
        """
        self._validate_num_levels(num)

        # Each factor only needs enough levels to cover the highest factor
        # level referenced by the first `num` product levels.
        max_level = int(self.factor_space_eigenindices[:num, :].max() + 1)

        factor_space_eigenfunctions = [
            space.get_eigenfunctions(max_level) for space in self.factor_spaces
        ]

        return ProductEigenfunctions(
            [space.element_shape for space in self.factor_spaces],
            [space.element_dtype for space in self.factor_spaces],
            self.factor_space_eigenindices[:num],
            *factor_space_eigenfunctions,
        )

    def get_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the Laplacian corresponding to the first `num` levels.

        :param num:
            Number of levels. Cannot be larger than the `num_levels` parameter
            of the constructor.
        :return:
            (num, 1)-shaped array containing the eigenvalues.
        """
        self._validate_num_levels(num)

        return self._eigenvalues[:num, None]

    def get_repeated_eigenvalues(self, num: int) -> B.Numeric:
        """
        Eigenvalues of the Laplacian corresponding to the first `num` levels,
        repeated according to their multiplicity within levels.

        :param num:
            Number of levels. Cannot be larger than the `num_levels` parameter
            of the constructor.

        :return:
            (J, 1)-shaped array containing the repeated eigenvalues, where `J`
            is the resulting number of the repeated eigenvalues.
        """
        self._validate_num_levels(num)

        # Multiplicities are the per-level eigenfunction counts.
        eigenfunctions = self.get_eigenfunctions(num)
        eigenvalues = self._eigenvalues[:num]
        multiplicities = eigenfunctions.num_eigenfunctions_per_level

        repeated_eigenvalues = chain(eigenvalues, multiplicities)
        return B.reshape(repeated_eigenvalues, -1, 1)  # [J, 1]

    @property
    def element_shape(self):
        """
        :return:
            Sum of the products of the element shapes of the factor spaces.
        """
        return sum(math.prod(space.element_shape) for space in self.factor_spaces)

    @property
    def element_dtype(self):
        """
        :return:
            B.Numeric.
        """
        return B.Numeric