Coverage for geometric_kernels/lab_extras/numpy/sparse_extras.py: 84%
44 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-16 21:43 +0000
1import sys
3import lab as B
4import scipy
5import scipy.sparse as sp
6from beartype.typing import Union
7from lab import dispatch
8from plum import Signature
10from .extras import _Numeric
12"""
13SparseArray defines a lab data type that covers all possible sparse
14scipy arrays, so that multiple dispatch works with such arrays.
15"""
16if sys.version_info[:2] <= (3, 8):
17 SparseArray = Union[
18 sp.bsr_matrix,
19 sp.coo_matrix,
20 sp.csc_matrix,
21 sp.csr_matrix,
22 sp.dia_matrix,
23 sp.dok_matrix,
24 sp.lil_matrix,
25 ]
26else:
27 SparseArray = Union[
28 sp.sparray,
29 sp.spmatrix,
30 ]
@dispatch
def degree(a: SparseArray):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The column sums of `a` are placed on the main diagonal of a sparse
    diagonal matrix, so each diagonal entry is the (possibly weighted) number
    of connections of the corresponding node.
    """
    column_sums = a.sum(axis=0)  # type: ignore
    size = column_sums.size
    return sp.spdiags(column_sums, 0, size, size)
@dispatch
def eigenpairs(L: Union[SparseArray, _Numeric], k: int):
    """
    Obtain the eigenpairs that correspond to the `k` lowest eigenvalues
    of a symmetric positive semi-definite matrix `L`.

    :param L: A symmetric positive semi-definite matrix, either a scipy
        sparse array/matrix or a dense array.
    :param k: The number of eigenpairs to compute. If `k` is not smaller than
        the number of rows of `L`, all eigenpairs are returned (the same
        behavior as the dense path).
    :return: A tuple `(eigenvalues, eigenvectors)`; eigenvectors are stored
        column-wise.
    """
    # Import the solvers explicitly instead of relying on `sp.linalg` /
    # `scipy.linalg` having been loaded as a side effect elsewhere.
    from scipy.linalg import eigh
    from scipy.sparse.linalg import eigsh

    if sp.issparse(L) and k >= L.shape[0]:
        # ARPACK (`eigsh`) requires k < n, so use the dense solver instead.
        L = L.toarray()
    if sp.issparse(L):
        # Shift-invert mode around (almost) zero: for a PSD matrix the
        # eigenvalues nearest to sigma are the smallest ones.
        return eigsh(L, k, sigma=1e-8)
    eigenvalues, eigenvectors = eigh(L)
    return (eigenvalues[:k], eigenvectors[:, :k])
@dispatch
def set_value(a: Union[SparseArray, _Numeric], index: int, value: float):
    """
    Return a copy of `a` in which `[index]` has been assigned `value`.

    The input array is left untouched; the assignment is performed on a
    `.copy()` of it, and that copy is returned.
    """
    updated = a.copy()
    updated[index] = value
    return updated
# Methods for simple ops on sparse arrays, registered with lab further down.


def pinv(a: Union[SparseArray]):
    """
    Moore-Penrose pseudo-inverse of a diagonal sparse matrix `a`.

    For a (possibly rectangular) m x n diagonal matrix, the pseudo-inverse is
    the n x m diagonal matrix holding the reciprocals of the nonzero diagonal
    entries; zero diagonal entries stay zero.

    :param a: A scipy sparse array/matrix whose nonzero entries all lie on
        the main diagonal.
    :return: The pseudo-inverse as a `scipy.sparse.csr_matrix` of shape
        (n, m).
    :raises NotImplementedError: If `a` has a nonzero off-diagonal entry.
    """
    # Canonical COO copy: duplicates summed and explicit zeros dropped, so
    # only genuinely nonzero entries are checked and inverted, and the
    # caller's matrix is never mutated.
    coo = sp.coo_matrix(a, copy=True)
    coo.sum_duplicates()
    coo.eliminate_zeros()
    if not (coo.row == coo.col).all():
        raise NotImplementedError(
            "pinv is not supported for non-diagonal sparse arrays."
        )
    n_rows, n_cols = coo.shape
    # True division promotes integer input to float instead of truncating;
    # swapping row/col indices yields the transposed (n x m) shape.
    return sp.csr_matrix(
        (1 / coo.data, (coo.col, coo.row)), shape=(n_cols, n_rows)
    )
# The "ignore" works around what appears to be a plum/typing issue with
# building a Signature from a Union.
_SparseArray = Signature(SparseArray)  # type: ignore

B.T.register(lambda a: a.T, _SparseArray)
B.shape.register(lambda a: a.shape, _SparseArray)
B.sqrt.register(lambda a: a.sqrt(), _SparseArray)
# `any` is true when at least one element is nonzero (NumPy semantics).
# Comparing with `== True` would only match stored entries equal to 1.
B.any.register(lambda a: bool(a.count_nonzero()), _SparseArray)
B.rank.register(lambda a: a.ndim, _SparseArray)

B.linear_algebra.pinv.register(pinv, _SparseArray)