Coverage for geometric_kernels/lab_extras/numpy/sparse_extras.py: 84%

44 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import sys 

2 

3import lab as B 

4import scipy 

5import scipy.sparse as sp 

6from beartype.typing import Union 

7from lab import dispatch 

8from plum import Signature 

9 

10from .extras import _Numeric 

11 

12""" 

13SparseArray defines a lab data type that covers all possible sparse 

14scipy arrays, so that multiple dispatch works with such arrays. 

15""" 

# SparseArray is the lab data type that covers all scipy sparse containers,
# so that multiple dispatch works with such arrays.
#
# The presence of ``sp.sparray`` is detected directly instead of being
# inferred from the interpreter version: the attribute is provided by SciPy,
# not by Python, so a recent interpreter paired with an older SciPy would
# otherwise fail with AttributeError while importing this module.
if hasattr(sp, "sparray"):
    SparseArray = Union[
        sp.sparray,
        sp.spmatrix,
    ]
else:
    # Fallback for SciPy builds without ``sparray``: enumerate the concrete
    # sparse matrix classes explicitly.
    SparseArray = Union[
        sp.bsr_matrix,
        sp.coo_matrix,
        sp.csc_matrix,
        sp.csr_matrix,
        sp.dia_matrix,
        sp.dok_matrix,
        sp.lil_matrix,
    ]

31 

32 

@dispatch
def degree(a: SparseArray):  # type: ignore
    """
    Build the degree matrix of a graph from its adjacency matrix `a`.

    The result is a sparse diagonal matrix whose main diagonal holds the
    column sums of `a`, i.e. the number of nodes each node is connected to.
    """
    column_sums = a.sum(axis=0)  # type: ignore
    n = column_sums.size
    return sp.spdiags(column_sums, 0, n, n)

43 

44 

@dispatch
def eigenpairs(L: Union[SparseArray, _Numeric], k: int):
    """
    Obtain the eigenpairs that correspond to the `k` lowest eigenvalues
    of a symmetric positive semi-definite matrix `L`.

    `L` may be a scipy sparse array or a dense array. The result is a tuple
    `(eigenvalues, eigenvectors)` where the eigenvectors are stored as
    columns. When `k` is at least the size of `L`, all eigenpairs of `L`
    are returned, for sparse and dense inputs alike.
    """
    # The sparse solver only handles k strictly smaller than the matrix
    # size, so any larger request is served by the dense solver instead
    # (previously only k == size was redirected and k > size raised for
    # sparse input while the dense path silently returned all pairs).
    if sp.issparse(L) and k >= L.shape[0]:
        L = L.toarray()

    if sp.issparse(L):
        # Shift-invert around a tiny positive sigma targets the eigenvalues
        # closest to zero, which are the lowest ones for a PSD matrix.
        return sp.linalg.eigsh(L, k, sigma=1e-8)

    eigenvalues, eigenvectors = scipy.linalg.eigh(L)
    return (eigenvalues[:k], eigenvectors[:, :k])

58 

59 

@dispatch
def set_value(a: Union[SparseArray, _Numeric], index: int, value: float):
    """
    Return a copy of `a` in which `a[index]` has been replaced by `value`.

    The assignment is performed on the copy, so the input `a` itself is
    left untouched.
    """
    updated = a.copy()
    updated[index] = value
    return updated

69 

70 

71""" Register methods for simple ops for a sparse array. """ 

72 

73 

def pinv(a: Union[SparseArray]):
    """
    Pseudo-inverse of a diagonal sparse array.

    Every nonzero entry of `a` must lie on the main diagonal; each such
    entry is replaced by its reciprocal. The result is a new CSR matrix.
    Raises NotImplementedError if `a` has a nonzero off-diagonal entry.
    """
    rows, cols = a.nonzero()
    is_diagonal = (rows == cols).all()
    if not is_diagonal:
        raise NotImplementedError(
            "pinv is not supported for non-diagonal sparse arrays."
        )

    # Work on a CSR copy so that the caller's array is not modified.
    inverted = sp.csr_matrix(a.copy())
    inverted[rows, rows] = 1 / inverted[rows, rows]
    return inverted

84 

85 

# putting "ignore" here for now, seems like some plum/typing issue
_SparseArray = Signature(SparseArray)  # type: ignore


def _register_sparse_ops():
    """Register the sparse-array implementations of simple lab operations."""
    # Each entry pairs a lab function with the implementation to use when
    # it is called on a sparse array; entries are registered in order.
    implementations = (
        (B.T, lambda a: a.T),
        (B.shape, lambda a: a.shape),
        (B.sqrt, lambda a: a.sqrt()),
        (B.any, lambda a: bool((a == True).sum())),  # noqa
        (B.rank, lambda a: a.ndim),
        (B.linear_algebra.pinv, pinv),
    )
    for operation, implementation in implementations:
        operation.register(implementation, _SparseArray)


_register_sparse_ops()