Coverage for tests/utils/test_manifold_utils.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import numpy as np 

2import pytest 

3 

4from geometric_kernels.spaces import Hyperbolic 

5from geometric_kernels.utils.manifold_utils import hyperbolic_distance 

6 

7from ..helper import check_function_with_backend 

8 

9 

@pytest.mark.parametrize("dim", [2, 3, 9, 10])
@pytest.mark.parametrize("backend", ["numpy", "tensorflow", "torch", "jax"])
def test_hyperboloid_distance(dim, backend):
    """Check ``hyperbolic_distance`` against the space's metric distance.

    Two random batches of points are sampled on the ``dim``-dimensional
    hyperbolic space. The reference pairwise distances come from
    ``space.metric.dist`` (the geomstats implementation) evaluated on every
    (i, j) pair. ``hyperbolic_distance`` is then run on the unexpanded
    batches in each ``backend`` and compared against that reference.
    """
    hyperbolic_space = Hyperbolic(dim=dim)

    # Fixed seed so the batch sizes and sampled points are reproducible.
    rng = np.random.RandomState(0)
    n_rows, n_cols = rng.randint(low=2, high=15, size=2)

    # ``random`` returns the (possibly advanced) generator alongside the
    # samples, so the generator is threaded through both calls in order.
    rng, points_a = hyperbolic_space.random(rng, n_rows)
    rng, points_b = hyperbolic_space.random(rng, n_cols)

    # Materialise every (i, j) pairing as two aligned (N, M, n+1) arrays so
    # the reference metric can be evaluated elementwise.
    num_a = points_a.shape[0]
    num_b = points_b.shape[0]
    pairs_a = np.tile(points_a[..., None, :], (1, num_b, 1))  # (N, M, n+1)
    pairs_b = np.tile(points_b[None], (num_a, 1, 1))  # (N, M, n+1)
    reference_distances = hyperbolic_space.metric.dist(pairs_a, pairs_b)

    # Our implementation should agree with the geomstats reference on every
    # backend.
    check_function_with_backend(
        backend,
        reference_distances,
        hyperbolic_distance,
        points_a,
        points_b,
    )