Coverage for geometric_kernels/feature_maps/probability_densities.py: 83%

152 statements  

coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2This module provide the routines for sampling from the Gaussian and Student-t 

3probability densities in a backend-agnostic way. It also provides the routines 

4for sampling the non-standard probability densities that arise in relation to 

5the :class:`~.spaces.Hyperbolic` and 

6:class:`~.spaces.SymmetricPositiveDefiniteMatrices` spaces. 

7""" 


import operator
from functools import reduce

import lab as B
import numpy as np
from beartype.typing import Dict, List, Optional, Tuple
from sympy import Poly, Product, symbols

from geometric_kernels.lab_extras import (
    cumsum,
    dtype_double,
    dtype_integer,
    eigvalsh,
    from_numpy,
)
from geometric_kernels.utils.utils import (
    _check_1_vector,
    _check_field_in_params,
    _check_matrix,
    _check_rank_1_array,
    ordered_pairwise_differences,
)



def student_t_sample(
    key: B.RandomState,
    loc: B.Numeric,
    shape: B.Numeric,
    df: B.Numeric,
    size: Tuple[int, ...],
    dtype: Optional[B.DType] = None,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Sample from the multivariate Student's t-distribution with mean vector
    `loc`, shape matrix `shape` and `df` degrees of freedom, using the `key`
    random state and returning a sample of shape `size`.

    A multivariate Student's t random vector $T$ with $\nu$ degrees of
    freedom, mean vector $\mu$ and shape matrix $\Sigma$ can be represented as

    .. math:: T = \frac{Z}{\sqrt{V/\nu}} + \mu,

    where $Z \sim N(0, \Sigma)$ and $V \sim \chi^2(\nu)$. The $\chi^2(\nu)$
    distribution, in its turn, is the same as the $\Gamma(\nu / 2, 2)$
    distribution, and therefore $V/\nu \sim \Gamma(\nu / 2, 2 / \nu)$. We use
    these properties to sample a multivariate Student's t random vector by
    sampling a Gaussian random vector and a Gamma random variable.

    .. warning:: The `shape` parameter has little to do with tensor shapes;
        it is the analog of the covariance matrix of the Gaussian
        distribution.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param loc:
        Mean vector of the multivariate Student's t-distribution.
        An array of shape (n,).
    :param shape:
        Shape matrix of the multivariate Student's t-distribution.
        An array of shape (n, n).
    :param df:
        The number of degrees of freedom of the Student's t-distribution,
        represented as a (1,)-array of the used backend.
    :param size:
        The returned array `samples` will have the shape `(*size, n)`.
    :param dtype:
        dtype of the returned tensor.

    :return:
        `Tuple(key, samples)` where `samples` is a `(*size, n)`-shaped array
        of samples of type `dtype`, and `key` is the updated random key for
        `jax`, or the similar random state (generator) for any other backend.
    """
    _check_1_vector(df, "df")

    _check_rank_1_array(loc, "loc")
    _check_matrix(shape, "shape")

    n = B.shape(loc)[0]

    if tuple(B.shape(shape)) != (n, n):
        raise ValueError(
            f"Expected `shape` matrix to have shape [{n}, {n}], but got {B.shape(shape)}."
        )
    shape_sqrt = B.chol(shape)
    dtype = dtype or dtype_double(key)
    key, z = B.randn(key, dtype, *size, n)
    z = B.einsum("...i,ji->...j", z, shape_sqrt)

    key, g = B.randgamma(
        key,
        dtype,
        *size,
        alpha=df / 2,
        scale=2,
    )
    g = B.squeeze(g, axis=-1)

    u = (z / B.sqrt(g / df)[..., None]) + loc

    return key, u
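
# Example usage (a minimal sketch, assuming the NumPy backend, where the
# random state is an `np.random.RandomState`; the other backends follow the
# same pattern with their respective generators):
#
#     key = np.random.RandomState(0)
#     loc = np.zeros(2)        # mean vector, shape (2,)
#     shape = np.eye(2)        # shape matrix, shape (2, 2)
#     df = np.r_[4.0]          # degrees of freedom, a (1,)-array
#     key, t = student_t_sample(key, loc, shape, df, (5,))
#     # t has shape (5, 2): five draws from the bivariate Student's t.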



def base_density_sample(
    key: B.RandomState,
    size: Tuple[int, ...],
    params: Dict[str, B.Numeric],
    dim: int,
    rho: B.Numeric,
    shifted_laplacian: bool = True,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    The Matérn kernel's spectral density is $p_{\nu,\kappa}(\lambda)$, where
    $\nu$ is the smoothness parameter, $\kappa$ is the length scale and
    $p_{\nu,\kappa}$ is the Student's t or Gaussian density, depending on the
    smoothness.

    We call it "base density" and this function returns a sample from it.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param size:
        The returned array `samples` will have the shape `(*size, dim)`.
    :param params:
        Params of the kernel.
    :param dim:
        Dimensionality of the space the kernel is defined on.
    :param rho:
        $\rho$ vector of the space.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.

    :return:
        `Tuple(key, samples)` where `samples` is a `(*size, dim)`-shaped array
        of samples, and `key` is the updated random key for `jax`, or the
        similar random state (generator) for any other backend.
    """
    _check_field_in_params(params, "lengthscale")
    _check_1_vector(params["lengthscale"], 'params["lengthscale"]')

    _check_field_in_params(params, "nu")
    _check_1_vector(params["nu"], 'params["nu"]')

    nu = params["nu"]
    L = params["lengthscale"]

    rho_size = B.size(rho)

    # Note: 1.0 in safe_nu can be replaced by any finite positive value
    safe_nu = B.where(nu == np.inf, B.cast(B.dtype(L), np.r_[1.0]), nu)

    # for nu == np.inf
    # sample from Gaussian
    key, u_nu_infinite = B.randn(key, B.dtype(L), *size, rho_size)
    # for nu < np.inf
    # sample from the student-t with 2\nu + dim(space) - dim(rho) degrees of freedom
    df = 2 * safe_nu + dim - rho_size

    dtype = B.dtype(L)

    key, u_nu_finite = student_t_sample(
        key,
        B.zeros(dtype, rho_size),
        B.eye(dtype, rho_size),
        df,
        size,
        dtype,
    )

    u = B.where(nu == np.inf, u_nu_infinite, u_nu_finite)

    scale_nu_infinite = L
    if shifted_laplacian:
        scale_nu_finite = B.sqrt(df / (2 * safe_nu / L**2))
    else:
        scale_nu_finite = B.sqrt(df / (2 * safe_nu / L**2 + B.sum(rho**2)))

    scale = B.where(nu == np.inf, scale_nu_infinite, scale_nu_finite)

    scale = B.cast(B.dtype(u), scale)
    return key, u / scale
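
# Example usage (a minimal sketch, assuming the NumPy backend; the `rho`
# value here is a hypothetical placeholder, in practice it is provided by
# the space):
#
#     key = np.random.RandomState(0)
#     params = {"nu": np.r_[2.5], "lengthscale": np.r_[1.0]}
#     rho = np.r_[0.5]        # rho vector of the space (illustrative value)
#     key, s = base_density_sample(key, (10,), params, dim=2, rho=rho)
#     # s has shape (10, 1): draws from the base (Student's t) density.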



def _randcat_fix(
    key: B.RandomState, dtype: B.DType, size: int, p: B.Numeric
) -> Tuple[B.RandomState, B.Numeric]:
    """
    Sample from a categorical distribution with (potentially unnormalized)
    probabilities `p`.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param dtype:
        dtype of the returned tensor.
    :param size:
        Number of samples returned.
    :param p:
        Vector of (potentially unnormalized) probabilities
        defining the discrete distribution to sample from.

    :return:
        `Tuple(key, samples)` where `samples` is a `(size,)`-shaped array of
        samples of type `dtype`, and `key` is the updated random key for
        `jax`, or the similar random state (generator) for any other backend.
    """
    p = p / B.sum(p, axis=-1, squeeze=False)
    # Perform sampling routine: invert the CDF at uniform random points.
    cdf = cumsum(p, axis=-1)
    key, u = B.rand(key, dtype, size, *B.shape(p)[:-1])
    inds = B.argmax(B.cast(dtype_integer(key), u[..., None] < cdf[None]), axis=-1)
    return key, B.cast(dtype, inds)
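
# For instance, for p = [0.2, 0.3, 0.5] the normalized CDF is
# [0.2, 0.5, 1.0]; a uniform draw u = 0.4 satisfies u < cdf first at
# index 1, so category 1 is returned. argmax over the boolean mask picks
# the first position where the comparison holds.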



def _alphas(n: int) -> B.Numeric:
    r"""
    Compute the alphas for Prop. 16 & 17 of :cite:t:`azangulov2024b`
    for the hyperbolic space of dimension `n`.

    :param n:
        Dimension of the hyperbolic space, n >= 2.

    :return:
        Array of alphas.

    .. todo::
        Precompute these, rather than computing them at runtime.

    .. todo::
        Update proposition numbers when the paper gets published.
    """
    if n < 2:
        raise ValueError("Dimension of the hyperbolic space `n` must be >= 2.")
    x, j = symbols("x, j")
    if (n % 2) == 0:
        m = n // 2
        prod = x * Product(x**2 + (2 * j - 3) ** 2 / 4, (j, 2, m)).doit()
    else:
        m = (n - 1) // 2
        prod = Product(x**2 + j**2, (j, 0, m - 1)).doit()
    return np.array(Poly(prod, x).all_coeffs()).astype(np.float64)[::-1]
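
# For instance, for n = 2 the product above reduces to the polynomial x
# (so _alphas(2) gives [0., 1.]), and for n = 3 it is x**2 (so _alphas(3)
# gives [0., 0., 1.]); coefficients are returned lowest-degree first.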



def _sample_mixture_heat(
    key: B.RandomState, alpha: B.Numeric, lengthscale: B.Numeric
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Sample from the mixture distribution from Prop. 16 of
    :cite:t:`azangulov2024b` for specific alphas `alpha` and length scale
    ($\kappa$) `lengthscale`, using the `key` random state.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param alpha:
        Unnormalized coefficients of the mixture.
    :param lengthscale:
        Length scale ($\kappa$).

    :return:
        `Tuple(key, sample)` where `sample` is a single sample from the
        distribution, and `key` is the updated random key for `jax`, or the
        similar random state (generator) for any other backend.

    .. todo::
        Do we need the reparameterization trick for hyperparameter
        optimization?

    .. todo::
        Update proposition numbers when the paper gets published.
    """
    _check_rank_1_array(alpha, "alpha")
    m = B.shape(alpha)[0] - 1
    if m < 0:
        raise ValueError("The mixture must contain at least 1 component.")
    dtype = B.dtype(lengthscale)
    js = B.range(dtype, 0, m + 1)

    # Gamma((js+1)/2) is positive real, so it can be computed as exp(loggamma(.))
    beta = 2 ** ((1 - js) / 2) / B.exp(B.loggamma((js + 1) / 2)) * lengthscale

    alpha = B.cast(dtype, from_numpy(beta, alpha))
    cs_unnorm = alpha / beta
    cs = cs_unnorm / B.sum(cs_unnorm)
    key, ind = _randcat_fix(key, dtype, 1, cs)

    # Gamma(nu/2, 2) distribution is the same as chi2(nu) distribution
    key, s = B.randgamma(key, dtype, 1, alpha=(ind + 1) / 2, scale=2)
    s = B.sqrt(s) / lengthscale
    return key, s
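
# A minimal sketch of how this helper is used internally (NumPy backend
# assumed; `alpha` comes from `_alphas`):
#
#     key = np.random.RandomState(0)
#     alpha = _alphas(3)                # mixture coefficients for dim 3
#     key, s = _sample_mixture_heat(key, alpha, np.r_[1.0])
#     # s is a (1,)-shaped draw from the Prop. 16 mixture.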



def _sample_mixture_matern(
    key: B.RandomState,
    alpha: B.Numeric,
    lengthscale: B.Numeric,
    nu: B.Numeric,
    dim: int,
    shifted_laplacian: bool = True,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    Sample from the mixture distribution from Prop. 17 of
    :cite:t:`azangulov2024b` for specific alphas `alpha`, length
    scale ($\kappa$) `lengthscale`, smoothness `nu` and dimension `dim`,
    using the `key` random state.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param alpha:
        Unnormalized coefficients of the mixture.
    :param lengthscale:
        Length scale ($\kappa$).
    :param nu:
        Smoothness parameter of Matérn kernels ($\nu$).
    :param dim:
        Dimension of the hyperbolic space.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.

    :return:
        `Tuple(key, sample)` where `sample` is a single sample from the
        distribution, and `key` is the updated random key for `jax`, or the
        similar random state (generator) for any other backend.

    .. todo::
        Do we need the reparameterization trick for hyperparameter
        optimization?

    .. todo::
        Update proposition numbers when the paper gets published.
    """
    _check_rank_1_array(alpha, "alpha")
    m = B.shape(alpha)[0] - 1
    if m < 0:
        raise ValueError("The mixture must contain at least 1 component.")
    dtype = B.dtype(lengthscale)
    js = B.range(dtype, 0, m + 1)
    if shifted_laplacian:
        gamma = 2 * nu / lengthscale**2
    else:
        gamma = 2 * nu / lengthscale**2 + ((dim - 1) / 2) ** 2

    # B(x, y) = Gamma(x) Gamma(y) / Gamma(x+y)
    beta = B.exp(
        B.loggamma((js + 1) / 2)
        + B.loggamma(nu + (dim - js - 1) / 2)
        - B.loggamma(nu + dim / 2)
    )
    beta = 2.0 / beta
    beta = beta * B.power(gamma, nu + (dim - js - 1) / 2)

    alpha = B.cast(dtype, from_numpy(beta, alpha))
    cs_unnorm = alpha / beta
    cs = cs_unnorm / B.sum(cs_unnorm)

    key, ind = _randcat_fix(key, dtype, 1, cs)
    key, p = B.randbeta(
        key, dtype, 1, alpha=(ind + 1) / 2, beta=nu + (dim - ind - 1) / 2
    )
    p = p / (1 - p)  # beta prime distribution
    s = B.sqrt(gamma * p)
    return key, s
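
# Note: the draw above relies on the standard fact that if P ~ Beta(a, b),
# then P / (1 - P) follows the beta prime distribution with the same
# parameters; scaling by gamma and taking the square root then yields a
# draw from the Prop. 17 mixture.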



def hyperbolic_density_sample(
    key: B.RandomState,
    size: Tuple[int, ...],
    params: Dict[str, B.Numeric],
    dim: int,
    shifted_laplacian: bool = True,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    This function samples from the full (i.e., including the $c(\lambda)^{-2}$
    factor, where $c$ is Harish-Chandra's $c$-function) spectral density of
    the heat/Matérn kernel on the hyperbolic space, using rejection sampling.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param size:
        Shape of the returned sample.
    :param params:
        Params of the kernel.
    :param dim:
        Dimensionality of the space the kernel is defined on.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.

    :return:
        `Tuple(key, samples)` where `samples` is a `size`-shaped array of
        samples, and `key` is the updated random key for `jax`, or the similar
        random state (generator) for any other backend.
    """
    _check_field_in_params(params, "lengthscale")
    _check_1_vector(params["lengthscale"], 'params["lengthscale"]')

    _check_field_in_params(params, "nu")
    _check_1_vector(params["nu"], 'params["nu"]')

    nu = params["nu"]
    L = params["lengthscale"]

    alpha = _alphas(dim)

    def base_sampler(key):
        # Note: 1.0 in safe_nu can be replaced by any finite positive value
        safe_nu = B.where(nu == np.inf, B.cast(B.dtype(L), np.r_[1.0]), nu)

        # for nu == np.inf
        key, sample_mixture_nu_infinite = _sample_mixture_heat(key, alpha, L)
        # for nu < np.inf
        key, sample_mixture_nu_finite = _sample_mixture_matern(
            key, alpha, L, safe_nu, dim, shifted_laplacian
        )

        return key, B.where(
            nu == np.inf, sample_mixture_nu_infinite, sample_mixture_nu_finite
        )

    samples: List[B.Numeric] = []
    while len(samples) < reduce(operator.mul, size, 1):
        key, proposal = base_sampler(key)

        # accept with probability tanh(pi*proposal)
        key, u = B.rand(key, B.dtype(L), 1)
        acceptance = B.all(u < B.tanh(B.pi * proposal))

        key, sign_z = B.rand(key, B.dtype(L), 1)
        sign = B.sign(sign_z - 0.5)  # +1 or -1 with probability 0.5
        if ((dim % 2) == 1) or acceptance:
            samples.append(sign * proposal)

    samples = B.reshape(B.concat(*samples), *size)
    return key, B.cast(B.dtype(L), samples)
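
# Example usage (a minimal sketch, assuming the NumPy backend):
#
#     key = np.random.RandomState(0)
#     params = {"nu": np.r_[1.5], "lengthscale": np.r_[1.0]}
#     key, s = hyperbolic_density_sample(key, (10, 1), params, dim=2)
#     # s has shape (10, 1): draws from the full spectral density on H^2.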



def spd_density_sample(
    key: B.RandomState,
    size: Tuple[int, ...],
    params: Dict[str, B.Numeric],
    degree: int,
    rho: B.Numeric,
    shifted_laplacian: bool = True,
) -> Tuple[B.RandomState, B.Numeric]:
    r"""
    This function samples from the full (i.e., including the $c(\lambda)^{-2}$
    factor, where $c$ is Harish-Chandra's $c$-function) spectral density
    of the heat/Matérn kernel on the manifold of symmetric positive definite
    matrices, using rejection sampling.

    :param key:
        Either `np.random.RandomState`, `tf.random.Generator`,
        `torch.Generator` or `jax.tensor` (representing random state).
    :param size:
        The returned array `samples` will have the shape `(*size, D)`.
    :param params:
        Params of the kernel.
    :param degree:
        The degree D of the SPD(D) space.
    :param rho:
        $\rho$ vector of the space, D-dimensional.
    :param shifted_laplacian:
        If True, assumes that the kernels are defined in terms of the shifted
        Laplacian. This often makes Matérn kernels more flexible by widening
        the effective range of the length scale parameter.

        Defaults to True.

    :return:
        `Tuple(key, samples)` where `samples` is a `(*size, D)`-shaped array
        of samples, and `key` is the updated random key for `jax`, or the
        similar random state (generator) for any other backend.
    """
    _check_field_in_params(params, "nu")
    _check_1_vector(params["nu"], 'params["nu"]')

    _check_field_in_params(params, "lengthscale")
    _check_1_vector(params["lengthscale"], 'params["lengthscale"]')

    nu = params["nu"]
    L = params["lengthscale"]

    samples: List[B.Numeric] = []
    while len(samples) < reduce(operator.mul, size, 1):
        key, X = B.randn(key, B.dtype(L), degree, degree)
        M = (X + B.transpose(X)) / 2

        eigv = eigvalsh(M)  # [D]

        # Note: 1.0 in safe_nu can be replaced by any finite positive value
        safe_nu = B.where(nu == np.inf, B.cast(B.dtype(L), np.r_[1.0]), nu)

        # for nu == np.inf
        proposal_nu_infinite = eigv / L

        # for nu < np.inf
        if shifted_laplacian:
            eigv_nu_finite = eigv * B.sqrt(2 * safe_nu) / L
        else:
            eigv_nu_finite = eigv * B.sqrt(2 * safe_nu / L**2 + B.sum(rho**2))
        # Gamma(nu, 2) distribution is the same as chi2(2*nu) distribution
        key, chi2_sample = B.randgamma(key, B.dtype(L), 1, alpha=safe_nu, scale=2)
        chi_sample = B.sqrt(chi2_sample)
        proposal_nu_finite = eigv_nu_finite / chi_sample  # [D]

        proposal = B.where(nu == np.inf, proposal_nu_infinite, proposal_nu_finite)

        diffp = ordered_pairwise_differences(proposal)
        diffp = B.pi * B.abs(diffp)
        logprod = B.sum(B.log(B.tanh(diffp)), axis=-1)
        prod = B.exp(logprod)

        # accept with probability `prod`
        key, u = B.rand(key, B.dtype(L), 1)
        acceptance = B.all(u < prod)
        if acceptance:
            samples.append(proposal)

    samples = B.reshape(B.concat(*samples), *size, degree)
    return key, samples
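
# Example usage (a minimal sketch, assuming the NumPy backend; the `rho`
# value here is an illustrative placeholder, in practice it is provided by
# the space):
#
#     key = np.random.RandomState(0)
#     params = {"nu": np.r_[2.5], "lengthscale": np.r_[1.0]}
#     rho = np.r_[0.5, -0.5]  # rho vector for SPD(2) (illustrative value)
#     key, s = spd_density_sample(key, (10,), params, degree=2, rho=rho)
#     # s has shape (10, 2): draws from the full spectral density on SPD(2).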