Coverage for geometric_kernels/utils/kernel_formulas/euclidean.py: 83%

24 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2Implements the standard formulas for the RBF kernel and some Matérn kernels. 

3 

4The implementation is provided mainly for testing purposes. 

5""" 

6 

7from math import sqrt 

8 

9import lab as B 

10from beartype.typing import Optional 

11 

12 

13def euclidean_matern_12_kernel( 

14 r: B.Numeric, 

15 lengthscale: Optional[float] = 1.0, 

16): 

17 """ 

18 Analytic formula for the Matérn 1/2 kernel on R^d, as a function of 

19 distance `r` between inputs. 

20 

21 :param r: 

22 A batch of distances, an array of shape [...]. 

23 :param lengthscale: 

24 The length scale of the kernel, defaults to 1. 

25 

26 :return: 

27 The kernel values evaluated at `r`, an array of shape [...]. 

28 """ 

29 

30 if not B.all(r >= 0.0): 

31 raise ValueError("Distances must be non-negative.") 

32 

33 return B.exp(-r / lengthscale) 

34 

35 

36def euclidean_matern_32_kernel( 

37 r: B.Numeric, 

38 lengthscale: Optional[float] = 1.0, 

39): 

40 """ 

41 Analytic formula for the Matérn 3/2 kernel on R^d, as a function of 

42 distance `r` between inputs. 

43 

44 :param r: 

45 A batch of distances, an array of shape [...]. 

46 :param lengthscale: 

47 The length scale of the kernel, defaults to 1. 

48 

49 :return: 

50 The kernel values evaluated at `r`, an array of shape [...]. 

51 """ 

52 

53 if not B.all(r >= 0.0): 

54 raise ValueError("Distances must be non-negative.") 

55 

56 sqrt3 = sqrt(3.0) 

57 r = r / lengthscale 

58 return (1.0 + sqrt3 * r) * B.exp(-sqrt3 * r) 

59 

60 

61def euclidean_matern_52_kernel( 

62 r: B.Numeric, 

63 lengthscale: Optional[float] = 1.0, 

64): 

65 """ 

66 Analytic formula for the Matérn 5/2 kernel on R^d, as a function of 

67 distance `r` between inputs. 

68 

69 :param r: 

70 A batch of distances, an array of shape [...]. 

71 :param lengthscale: 

72 The length scale of the kernel, defaults to 1. 

73 

74 :return: 

75 The kernel values evaluated at `r`, an array of shape [...]. 

76 """ 

77 

78 if not B.all(r >= 0.0): 

79 raise ValueError("Distances must be non-negative.") 

80 

81 sqrt5 = sqrt(5.0) 

82 r = r / lengthscale 

83 return (1.0 + sqrt5 * r + 5.0 / 3.0 * (r**2)) * B.exp(-sqrt5 * r) 

84 

85 

86def euclidean_rbf_kernel( 

87 r: B.Numeric, 

88 lengthscale: Optional[float] = 1.0, 

89): 

90 """ 

91 Analytic formula for the RBF kernel on R^d, as a function of 

92 distance `r` between inputs. 

93 

94 :param r: 

95 A batch of distances, an array of shape [...]. 

96 :param lengthscale: 

97 The length scale of the kernel, defaults to 1. 

98 

99 :return: 

100 The kernel values evaluated at `r`, an array of shape [...]. 

101 """ 

102 

103 if not B.all(r >= 0.0): 

104 raise ValueError("Distances must be non-negative.") 

105 

106 r = r / lengthscale 

107 return B.exp(-0.5 * r**2)