Coverage for tests/spaces/test_hypersphere.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import lab as B 

2import numpy as np 

3 

4from geometric_kernels.spaces import Hypersphere 

5 

6 

def test_sphere_heat_kernel():
    """Check that the heat kernel on the sphere solves the heat equation.

    For every diffusion time ``t`` and every pair of sample points ``(x, y)``,
    the time derivative ``d/dt k_t(x, y)`` must match the manifold Laplacian
    of ``u -> k_t(u, y)`` evaluated at ``u = x``.

    This test only uses torch, as lab doesn't support backend-independent
    autodiff.
    """
    import torch

    import geometric_kernels.torch  # noqa
    from geometric_kernels.kernels import MaternKarhunenLoeveKernel
    from geometric_kernels.utils.manifold_utils import manifold_laplacian

    _TRUNCATION_LEVEL = 10

    # Parameters
    grid_size = 4
    nb_samples = 10
    dimension = 3

    # Create manifold
    hypersphere = Hypersphere(dim=dimension)

    # Generate samples. Both tensors require grad so autograd can
    # differentiate the kernel with respect to time and position.
    ts = torch.linspace(0.1, 1, grid_size, requires_grad=True)
    xs = torch.tensor(
        np.array(hypersphere.random_point(nb_samples)), requires_grad=True
    )
    ys = xs

    # Define kernel
    kernel = MaternKarhunenLoeveKernel(hypersphere, _TRUNCATION_LEVEL, normalize=False)
    params = kernel.init_params()
    # nu = inf presumably selects the diffusion (heat) kernel limit of the
    # Matern family, which heat_kernel below relies on.
    params["nu"] = torch.tensor([torch.inf])

    def heat_kernel(t, x, y):
        """Evaluate the heat kernel at diffusion time ``t`` between ``x`` and ``y``.

        Time is mapped to lengthscale ``sqrt(2 t)``. This assumes the kernel's
        spectrum is exp(-lengthscale**2 / 2 * eigenvalue) -- TODO confirm.
        The shared ``params`` dict is mutated, which is fine because calls are
        sequential.
        """
        params["lengthscale"] = B.reshape(B.sqrt(2 * t), 1)
        return kernel.K(params, x, y)

    def euclidean_derivatives(t, y):
        """Build the Euclidean gradient and Hessian-vector product of u -> k_t(u, y).

        ``t`` and ``y`` are bound as arguments rather than captured from the
        enclosing loop, so there is no late-binding closure.
        """

        def fx(u):
            return heat_kernel(t, u[None], y[None])

        def egrad(u):
            # Only the gradient w.r.t. u is needed; t and y are held fixed.
            (grad_u,) = torch.autograd.grad(fx(u), u)
            return grad_u

        def ehess(u, h):
            # hvp returns (fx(u), Hessian @ h); keep the product.
            return torch.autograd.functional.hvp(fx, u, h)[1]

        return egrad, ehess

    for t in ts:
        for x in xs:
            for y in ys:
                # Compute the derivative of the kernel function wrt t only.
                (dfdt,) = torch.autograd.grad(heat_kernel(t, x[None], y[None]), t)

                # Compute the Laplacian of the kernel on the manifold at x.
                egrad, ehess = euclidean_derivatives(t, y)
                lapf = manifold_laplacian(x, hypersphere, egrad, ehess)

                # Check that they match
                np.testing.assert_allclose(dfdt.detach().numpy(), lapf, atol=1e-3)