Coverage for geometric_kernels/spaces/graph.py: 95%

65 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1""" 

2This module provides the :class:`Graph` space. 

3""" 

4 

5import lab as B 

6import numpy as np 

7from beartype.typing import Dict, Tuple 

8 

9from geometric_kernels.lab_extras import ( 

10 degree, 

11 dtype_integer, 

12 eigenpairs, 

13 reciprocal_no_nan, 

14 set_value, 

15) 

16from geometric_kernels.spaces.base import DiscreteSpectrumSpace 

17from geometric_kernels.spaces.eigenfunctions import ( 

18 Eigenfunctions, 

19 EigenfunctionsFromEigenvectors, 

20) 

21from geometric_kernels.utils.utils import _check_matrix 

22 

23 

class Graph(DiscreteSpectrumSpace):
    """
    The GeometricKernels space representing the node set of any user-provided
    weighted undirected graph.

    The elements of this space are represented by node indices, integer values
    from 0 to n-1, where n is the number of nodes in the user-provided graph.

    Each individual eigenfunction constitutes a *level*.

    .. note::
        A tutorial on how to use this space is available in the
        :doc:`Graph.ipynb </examples/Graph>` notebook.

    :param adjacency_matrix:
        A square, symmetric matrix of shape [n, n], where
        adjacency_matrix[i, j] is non-zero if there is an edge
        between nodes i and j. SciPy's sparse matrices are supported.

        .. warning::
            Make sure that the type of the `adjacency_matrix` is of the
            backend (NumPy (or SciPy) / JAX / TensorFlow / PyTorch) that
            you wish to use for internal computations.

    :param normalize_laplacian:
        If True, the graph Laplacian will be degree normalized (symmetrically):
        L_sym = D^-0.5 * L * D^-0.5.

        Defaults to False.

    .. admonition:: Citation

        If you use this GeometricKernels space in your research, please consider
        citing :cite:t:`borovitskiy2021`.
    """

    def __init__(self, adjacency_matrix: B.Numeric, normalize_laplacian: bool = False):  # type: ignore
        # Maps a requested number of eigenpairs `num` to the tuple
        # (eigenvectors, eigenvalues) computed by `get_eigensystem`.
        self.cache: Dict[int, Tuple[B.Numeric, B.Numeric]] = {}
        self._checks(adjacency_matrix)
        self._set_laplacian(adjacency_matrix, normalize_laplacian)  # type: ignore
        self._normalized = normalize_laplacian

    def __str__(self):
        return f"Graph({self.num_vertices}, {'normalized' if self._normalized else 'unnormalized'})"

    @staticmethod
    def _checks(adjacency_matrix):
        """
        Checks if `adjacency_matrix` is a square symmetric matrix.

        Further generic validation is delegated to `_check_matrix`.

        :raises ValueError:
            If `adjacency_matrix` is not square or not symmetric.
        """
        _check_matrix(adjacency_matrix, "adjacency_matrix")
        if B.shape(adjacency_matrix)[0] != B.shape(adjacency_matrix)[1]:
            raise ValueError("`adjacency_matrix` must be a square matrix.")

        if B.any(adjacency_matrix != B.T(adjacency_matrix)):
            raise ValueError("`adjacency_matrix` must be a symmetric matrix.")

    @property
    def dimension(self) -> int:
        """
        :return:
            0.
        """
        return 0  # this is needed for the kernel math to work out

    @property
    def num_vertices(self) -> int:
        """
        Number of vertices in the graph.
        """
        return self._laplacian.shape[0]

    def _set_laplacian(self, adjacency, normalize_laplacian=False):
        """
        Construct the appropriate graph Laplacian from the adjacency matrix
        and store it in `self._laplacian`.

        The unnormalized Laplacian is L = D - A, where D is the degree matrix
        returned by `degree`. If `normalize_laplacian` is True, it is replaced
        by the symmetric normalization D^-0.5 @ L @ D^-0.5.

        :param adjacency:
            The [n, n] adjacency matrix.
        :param normalize_laplacian:
            Whether to symmetrically degree-normalize the Laplacian.
        """
        degree_matrix = degree(adjacency)
        self._laplacian = degree_matrix - adjacency
        if normalize_laplacian:
            # `reciprocal_no_nan` is used instead of a plain reciprocal,
            # presumably so that zero entries (e.g. isolated nodes) do not
            # produce inf/NaN — NOTE(review): confirm against lab_extras.
            degree_inv_sqrt = reciprocal_no_nan(B.sqrt(degree_matrix))
            self._laplacian = degree_inv_sqrt @ self._laplacian @ degree_inv_sqrt

    def get_eigensystem(self, num):
        """
        Returns the first `num` eigenvalues and eigenvectors of the graph Laplacian.
        Caches the solution to prevent re-computing the same values.

        The eigenvectors are multiplied by sqrt(n), where n is the number of
        vertices. Eigenvalues smaller than machine epsilon (including negative
        ones caused by round-off) are clamped to machine epsilon.

        .. note::
            If the `adjacency_matrix` was a sparse SciPy array, requesting
            **all** eigenpairs will lead to a conversion of the sparse matrix
            to a dense one due to scipy.sparse.linalg.eigsh limitations.

        :param num:
            Number of eigenpairs to return. Performs the computation at the
            first call. Afterwards, fetches the result from cache.

        :return:
            A tuple of eigenvectors [n, num], eigenvalues [num, 1].

        :raises ValueError:
            If `num` is negative or exceeds the number of vertices.
        """
        if num < 0:
            raise ValueError("Number of eigenpairs must be non-negative.")
        if num > self.num_vertices:
            raise ValueError(
                "Number of eigenpairs cannot exceed the number of vertices."
            )
        if num not in self.cache:
            evals, evecs = eigenpairs(self._laplacian, num)

            evecs *= B.sqrt(self.num_vertices)

            # The lowest Laplacian eigenvalues are zero in exact arithmetic but
            # may come out tiny or negative numerically. They are clamped to a
            # small positive value (eps) rather than exactly zero. Since eps > 0,
            # this single comparison also catches negative values.
            eps = np.finfo(float).eps
            for i, evalue in enumerate(evals):
                if evalue < eps:
                    evals = set_value(evals, i, eps)

            self.cache[num] = (evecs, evals[:, None])

        return self.cache[num]

    def get_eigenfunctions(self, num: int) -> Eigenfunctions:
        """
        Returns the :class:`~.EigenfunctionsFromEigenvectors` object with
        `num` levels (i.e., in this case, `num` eigenpairs).

        :param num:
            Number of levels.

        :return:
            Eigenfunctions built from the first `num` (cached) eigenvectors.
        """
        eigenfunctions = EigenfunctionsFromEigenvectors(self.get_eigenvectors(num))
        return eigenfunctions

    def get_eigenvectors(self, num: int) -> B.Numeric:
        """
        :param num:
            Number of eigenvectors to return.

        :return:
            Array of eigenvectors, with shape [n, num].
        """
        return self.get_eigensystem(num)[0]

    def get_eigenvalues(self, num: int) -> B.Numeric:
        """
        :param num:
            Number of eigenvalues to return.

        :return:
            Array of eigenvalues, with shape [num, 1].
        """
        return self.get_eigensystem(num)[1]

    def get_repeated_eigenvalues(self, num: int) -> B.Numeric:
        """
        Same as :meth:`get_eigenvalues`.

        :param num:
            Same as :meth:`get_eigenvalues`.
        """
        return self.get_eigenvalues(num)

    def random(self, key, number):
        """
        Sample `number` random vertex indices.

        :param key:
            Random state/key of the active backend, passed to `B.randint`.
        :param number:
            Number of vertices to sample.

        :return:
            A tuple (updated key, samples), where samples has shape
            [number, 1] and is drawn by `B.randint` with lower=0 and
            upper=num_vertices.
        """
        key, random_vertices = B.randint(
            key, dtype_integer(key), number, 1, lower=0, upper=self.num_vertices
        )

        return key, random_vertices

    @property
    def element_shape(self):
        """
        :return:
            [1].
        """
        return [1]

    @property
    def element_dtype(self):
        """
        :return:
            B.Int.
        """
        return B.Int