diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 35bd0d8b..ac84a64e 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -48,12 +48,13 @@ jobs: - name: Install main package run: | - pip install -e .[test] + python setup.py develop env: WITH_METIS: 1 - name: Run test-suite run: | + pip install pytest pytest-cov pytest --cov --cov-report=xml - name: Upload coverage diff --git a/test/test_mul.py b/test/test_mul.py new file mode 100644 index 00000000..9e3d08e6 --- /dev/null +++ b/test/test_mul.py @@ -0,0 +1,53 @@ +from itertools import product + +import pytest +import torch + +from torch_sparse import SparseTensor, mul +from torch_sparse.testing import devices, dtypes, tensor + + +@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) +def test_sparse_sparse_mul(dtype, device): + rowA = torch.tensor([0, 0, 1, 2, 2], device=device) + colA = torch.tensor([0, 2, 1, 0, 1], device=device) + valueA = tensor([1, 2, 4, 1, 3], dtype, device) + A = SparseTensor(row=rowA, col=colA, value=valueA) + + rowB = torch.tensor([0, 0, 1, 2, 2], device=device) + colB = torch.tensor([1, 2, 2, 1, 2], device=device) + valueB = tensor([2, 3, 1, 2, 4], dtype, device) + B = SparseTensor(row=rowB, col=colB, value=valueB) + + C = A * B + rowC, colC, valueC = C.coo() + + assert rowC.tolist() == [0, 2] + assert colC.tolist() == [2, 1] + assert valueC.tolist() == [6, 6] + + @torch.jit.script + def jit_mul(A: SparseTensor, B: SparseTensor) -> SparseTensor: + return mul(A, B) + + jit_mul(A, B) + + +@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) +def test_sparse_sparse_mul_empty(dtype, device): + rowA = torch.tensor([0], device=device) + colA = torch.tensor([1], device=device) + valueA = tensor([1], dtype, device) + A = SparseTensor(row=rowA, col=colA, value=valueA) + + rowB = torch.tensor([1], device=device) + colB = torch.tensor([0], device=device) + valueB = tensor([2], dtype, device) + B = SparseTensor(row=rowB, col=colB, value=valueB) + + C = A * B + rowC, colC, valueC = C.coo() + + assert rowC.tolist() == [] + assert colC.tolist() == [] + assert valueC.tolist() == [] diff --git a/torch_sparse/mul.py b/torch_sparse/mul.py index 854f3fa9..7b467906 100644 --- a/torch_sparse/mul.py +++ b/torch_sparse/mul.py @@ -1,27 +1,83 @@ from typing import Optional import torch +from torch import Tensor from torch_scatter import gather_csr + from torch_sparse.tensor import SparseTensor -def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor: - rowptr, col, value = src.csr() - if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise... - other = gather_csr(other.squeeze(1), rowptr) - pass - elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise... - other = other.squeeze(0)[col] - else: - raise ValueError( - f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or ' - f'(1, {src.size(1)}, ...), but got size {other.size()}.') +@torch.jit._overload # noqa: F811 +def mul(src, other): # noqa: F811 + # type: (SparseTensor, Tensor) -> SparseTensor + pass - if value is not None: - value = other.to(value.dtype).mul_(value) + +@torch.jit._overload # noqa: F811 +def mul(src, other): # noqa: F811 + # type: (SparseTensor, SparseTensor) -> SparseTensor + pass + + +def mul(src, other): # noqa: F811 + if isinstance(other, Tensor): + rowptr, col, value = src.csr() + if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise... + other = gather_csr(other.squeeze(1), rowptr) + pass + # Col-wise... + elif other.size(0) == 1 and other.size(1) == src.size(1): + other = other.squeeze(0)[col] + else: + raise ValueError( + f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or ' + f'(1, {src.size(1)}, ...), but got size {other.size()}.') + + if value is not None: + value = other.to(value.dtype).mul_(value) + else: + value = other + return src.set_value(value, layout='coo') + + assert isinstance(other, SparseTensor) + + if not src.is_coalesced(): + raise ValueError("The `src` tensor is not coalesced") + if not other.is_coalesced(): + raise ValueError("The `other` tensor is not coalesced") + + rowA, colA, valueA = src.coo() + rowB, colB, valueB = other.coo() + + row = torch.cat([rowA, rowB], dim=0) + col = torch.cat([colA, colB], dim=0) + + if valueA is not None and valueB is not None: + value = torch.cat([valueA, valueB], dim=0) else: - value = other - return src.set_value(value, layout='coo') + raise ValueError('Both sparse tensors must contain values') + + M = max(src.size(0), other.size(0)) + N = max(src.size(1), other.size(1)) + sparse_sizes = (M, N) + + # Sort indices: + idx = col.new_full((col.numel() + 1, ), -1) + idx[1:] = row * sparse_sizes[1] + col + perm = idx[1:].argsort() + idx[1:] = idx[1:][perm] + + row, col, value = row[perm], col[perm], value[perm] + + valid_mask = idx[1:] == idx[:-1] + valid_idx = valid_mask.nonzero().view(-1) + + return SparseTensor( + row=row[valid_mask], + col=col[valid_mask], + value=value[valid_idx - 1] * value[valid_idx], + sparse_sizes=sparse_sizes, + ) def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor: @@ -43,8 +99,11 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor: return src.set_value_(value, layout='coo') -def mul_nnz(src: SparseTensor, other: torch.Tensor, - layout: Optional[str] = None) -> SparseTensor: +def mul_nnz( + src: SparseTensor, + other: torch.Tensor, + layout: Optional[str] = None, +) -> SparseTensor: value = src.storage.value() if value is not None: value = value.mul(other.to(value.dtype)) @@ -53,8 +112,11 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor, return src.set_value(value, layout=layout) -def mul_nnz_(src: SparseTensor, other: torch.Tensor, - layout: Optional[str] = None) -> SparseTensor: +def mul_nnz_( + src: SparseTensor, + other: torch.Tensor, + layout: Optional[str] = None, +) -> SparseTensor: value = src.storage.value() if value is not None: value = value.mul_(other.to(value.dtype))