-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
262 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (C) 2023 Adam Lugowski. | ||
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file. | ||
# SPDX-License-Identifier: BSD-2-Clause | ||
|
||
from typing import Any, Iterable | ||
|
||
from . import Driver, MatrixSpyAdapter | ||
|
||
|
||
class PyDataSparseDriver(Driver): | ||
@staticmethod | ||
def get_supported_type_prefixes() -> Iterable[str]: | ||
return ["sparse."] | ||
|
||
@staticmethod | ||
def adapt_spy(mat: Any) -> MatrixSpyAdapter: | ||
from .sparse_impl import PyDataSparseSpy | ||
return PyDataSparseSpy(mat) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright (C) 2023 Adam Lugowski. | ||
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file. | ||
# SPDX-License-Identifier: BSD-2-Clause | ||
|
||
from typing import Tuple | ||
|
||
import numpy as np | ||
import sparse | ||
|
||
from . import describe, generate_spy_triple_product, MatrixSpyAdapter | ||
|
||
|
||
def generate_spy_triple_product_sparse(matrix_shape, spy_shape) -> Tuple[sparse.SparseArray, sparse.SparseArray]: | ||
# construct a triple product that will scale the matrix | ||
left, right = generate_spy_triple_product(matrix_shape, spy_shape) | ||
|
||
left_shape, (left_rows, left_cols) = left | ||
right_shape, (right_rows, right_cols) = right | ||
left_mat = sparse.COO(coords=(left_rows, left_cols), data=np.ones(len(left_rows)), shape=left_shape) | ||
right_mat = sparse.COO(coords=(right_rows, right_cols), data=np.ones(len(right_rows)), shape=right_shape) | ||
|
||
return left_mat, right_mat | ||
|
||
|
||
class PyDataSparseSpy(MatrixSpyAdapter): | ||
def __init__(self, mat): | ||
super().__init__() | ||
self.mat = mat | ||
|
||
def get_shape(self) -> tuple: | ||
return self.mat.shape | ||
|
||
def describe(self) -> str: | ||
parts = [ | ||
self.mat.format, | ||
] | ||
|
||
return describe(shape=self.mat.shape, | ||
nnz=self.mat.nnz, nz_type=self.mat.dtype, | ||
notes=", ".join(parts)) | ||
|
||
def get_spy(self, spy_shape: tuple) -> np.array: | ||
if isinstance(self.mat, sparse.DOK): | ||
self.mat = self.mat.asformat("coo") | ||
|
||
# construct a triple product that will scale the matrix | ||
left, right = generate_spy_triple_product_sparse(self.mat.shape, spy_shape) | ||
|
||
# save existing matrix data | ||
mat_data_save = self.mat.data | ||
|
||
# replace with all ones | ||
self.mat.data = np.ones(self.mat.data.shape) | ||
|
||
# triple product | ||
try: | ||
spy = left @ self.mat @ right | ||
except ValueError: | ||
# broken matmul on some types | ||
temp = self.mat.asformat("coo") | ||
spy = left @ temp @ right | ||
|
||
# restore original matrix data | ||
self.mat.data = mat_data_save | ||
|
||
return np.array(spy.todense()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright (C) 2023 Adam Lugowski. | ||
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file. | ||
# SPDX-License-Identifier: BSD-2-Clause | ||
|
||
import unittest | ||
|
||
try: | ||
import sparse | ||
except ImportError: | ||
sparse = None | ||
|
||
import numpy as np | ||
import scipy.sparse | ||
|
||
from matspy import spy_to_mpl, to_sparkline, to_spy_heatmap | ||
|
||
np.random.seed(123) | ||
|
||
|
||
@unittest.skipIf(sparse is None, "pydata/sparse not installed") | ||
class PyDataSparseTests(unittest.TestCase): | ||
def setUp(self): | ||
self.mats = [ | ||
sparse.COO.from_scipy_sparse(scipy.sparse.random(10, 10, density=0.4)), | ||
sparse.COO.from_scipy_sparse(scipy.sparse.random(5, 10, density=0.4)), | ||
sparse.COO.from_scipy_sparse(scipy.sparse.random(5, 1, density=0.4)), | ||
sparse.COO.from_scipy_sparse(scipy.sparse.coo_matrix(([], ([], [])), shape=(10, 10))), | ||
] | ||
|
||
def test_no_crash(self): | ||
import matplotlib.pyplot as plt | ||
for fmt in "coo", "gcxs", "dok", "csr", "csc": | ||
for source_mat in self.mats: | ||
mat = source_mat.asformat(fmt) | ||
|
||
fig, ax = spy_to_mpl(mat) | ||
plt.close(fig) | ||
|
||
res = to_sparkline(mat) | ||
self.assertGreater(len(res), 10) | ||
|
||
def test_count(self): | ||
arrs = [ | ||
(0, sparse.COO(np.array([[0]]))), | ||
(1, sparse.COO(np.array([[1]]))), | ||
(0, sparse.COO(np.array([[0, 0], [0, 0]]))), | ||
(1, sparse.COO(np.array([[1, 0], [0, 0]]))), | ||
] | ||
|
||
for count, arr in arrs: | ||
area = np.prod(arr.shape) | ||
heatmap = to_spy_heatmap(arr, buckets=1, shading="absolute") | ||
self.assertEqual(len(heatmap), 1) | ||
self.assertAlmostEqual( count / area, heatmap[0][0], places=2) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |