Skip to content

Commit

Permalink
At-operator is matmul not dot (#1140)
Browse files Browse the repository at this point in the history
* At-operator is matmul not dot

* Reduce dimensionality of inputs, to work under default max-dim=4
  • Loading branch information
manopapad authored May 31, 2024
1 parent b1cff77 commit 503affb
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 4 deletions.
4 changes: 2 additions & 2 deletions cmake/thirdparty/get_tblis.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@ function(find_or_configure_tblis)
endfunction()

if(NOT DEFINED cunumeric_TBLIS_BRANCH)
set(cunumeric_TBLIS_BRANCH master)
set(cunumeric_TBLIS_BRANCH arm-build)
endif()

if(NOT DEFINED cunumeric_TBLIS_REPOSITORY)
set(cunumeric_TBLIS_REPOSITORY https://github.com/devinamatthews/tblis.git)
set(cunumeric_TBLIS_REPOSITORY https://github.com/nv-legate/tblis.git)
endif()

find_or_configure_tblis(VERSION 1.2.0
Expand Down
32 changes: 31 additions & 1 deletion cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,20 @@ def __ilshift__(self, rhs: Any) -> ndarray:

return left_shift(self, rhs, out=self)

def __imatmul__(self, rhs: Any) -> ndarray:
"""a.__imatmul__(value, /)
Return ``self@=value``.
Availability
--------
Multiple GPUs, Multiple CPUs
"""
from .module import matmul

return matmul(self, rhs, out=self)

def __imod__(self, rhs: Any) -> ndarray:
"""a.__imod__(value, /)
Expand Down Expand Up @@ -1329,7 +1343,9 @@ def __matmul__(self, value: Any) -> ndarray:
Multiple GPUs, Multiple CPUs
"""
return self.dot(value)
from .module import matmul

return matmul(self, value)

def __mod__(self, rhs: Any) -> ndarray:
"""a.__mod__(value, /)
Expand Down Expand Up @@ -1551,6 +1567,20 @@ def __rfloordiv__(self, lhs: Any) -> ndarray:

return floor_divide(lhs, self)

def __rmatmul__(self, lhs: Any) -> ndarray:
"""a.__rmatmul__(value, /)
Return ``value@self``.
Availability
--------
Multiple GPUs, Multiple CPUs
"""
from .module import matmul

return matmul(lhs, self)

def __rmod__(self, lhs: Any) -> ndarray:
"""a.__rmod__(value, /)
Expand Down
54 changes: 53 additions & 1 deletion tests/integration/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
import pytest
from legate.core import LEGATE_MAX_DIM
from utils.comparisons import allclose
from utils.contractions import (
check_default,
check_permutations,
Expand All @@ -29,7 +30,7 @@

@pytest.mark.parametrize("a_ndim", range(1, LEGATE_MAX_DIM + 1))
@pytest.mark.parametrize("b_ndim", range(1, LEGATE_MAX_DIM + 1))
def test(a_ndim, b_ndim):
def test_function(a_ndim, b_ndim):
name = f"matmul({a_ndim} x {b_ndim})"
modes = matmul_modes(a_ndim, b_ndim)

Expand All @@ -43,6 +44,57 @@ def operation(lib, *args, **kwargs):
check_types(name, modes, operation)


@pytest.mark.parametrize(
"a_shape",
(
(3, 4, 5),
(4, 5),
(5,),
),
)
@pytest.mark.parametrize(
"b_shape",
(
(3, 5, 6),
(5, 6),
(5,),
),
)
def test_operator(a_shape, b_shape):
np_a = np.random.random(a_shape)
np_b = np.random.random(b_shape)
num_a = num.array(np_a)
num_b = num.array(np_b)
assert allclose(np_a @ np_b, num_a @ num_b)


@pytest.mark.parametrize(
"a_shape",
(
(3, 4, 5),
(4, 5),
(5,),
),
)
@pytest.mark.parametrize(
"b_shape",
(
(3, 5, 5),
(5, 5),
),
)
def test_inplace_operator(a_shape, b_shape):
if len(a_shape) < len(b_shape):
return
np_a = np.random.random(a_shape)
np_b = np.random.random(b_shape)
num_a = num.array(np_a)
num_b = num.array(np_b)
np_a @= np_b
num_a @= num_b
assert allclose(np_a, num_a)


class TestMatmulErrors:
@pytest.mark.parametrize(
"shapesAB",
Expand Down

0 comments on commit 503affb

Please sign in to comment.