From 503affb8925ea576337f02da4e9dc52cca50d970 Mon Sep 17 00:00:00 2001 From: Manolis Papadakis Date: Fri, 31 May 2024 00:42:56 -0700 Subject: [PATCH] At-operator is matmul not dot (#1140) * At-operator is matmul not dot * Reduce dimensionality of inputs, to work under default max-dim=4 --- cmake/thirdparty/get_tblis.cmake | 4 +-- cunumeric/array.py | 32 ++++++++++++++++++- tests/integration/test_matmul.py | 54 +++++++++++++++++++++++++++++++- 3 files changed, 86 insertions(+), 4 deletions(-) diff --git a/cmake/thirdparty/get_tblis.cmake b/cmake/thirdparty/get_tblis.cmake index b02afbd7d..37433594c 100644 --- a/cmake/thirdparty/get_tblis.cmake +++ b/cmake/thirdparty/get_tblis.cmake @@ -171,11 +171,11 @@ function(find_or_configure_tblis) endfunction() if(NOT DEFINED cunumeric_TBLIS_BRANCH) - set(cunumeric_TBLIS_BRANCH master) + set(cunumeric_TBLIS_BRANCH arm-build) endif() if(NOT DEFINED cunumeric_TBLIS_REPOSITORY) - set(cunumeric_TBLIS_REPOSITORY https://github.com/devinamatthews/tblis.git) + set(cunumeric_TBLIS_REPOSITORY https://github.com/nv-legate/tblis.git) endif() find_or_configure_tblis(VERSION 1.2.0 diff --git a/cunumeric/array.py b/cunumeric/array.py index 3b628ae4d..569a3fbd6 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -1122,6 +1122,20 @@ def __ilshift__(self, rhs: Any) -> ndarray: return left_shift(self, rhs, out=self) + def __imatmul__(self, rhs: Any) -> ndarray: + """a.__imatmul__(value, /) + + Return ``self@=value``. + + Availability + -------- + Multiple GPUs, Multiple CPUs + + """ + from .module import matmul + + return matmul(self, rhs, out=self) + def __imod__(self, rhs: Any) -> ndarray: """a.__imod__(value, /) @@ -1329,7 +1343,9 @@ def __matmul__(self, value: Any) -> ndarray: Multiple GPUs, Multiple CPUs """ - return self.dot(value) + from .module import matmul + + return matmul(self, value) def __mod__(self, rhs: Any) -> ndarray: """a.__mod__(value, /) @@ -1551,6 +1567,20 @@ def __rfloordiv__(self, lhs: Any) -> ndarray: return floor_divide(lhs, self) + def __rmatmul__(self, lhs: Any) -> ndarray: + """a.__rmatmul__(value, /) + + Return ``value@self``. + + Availability + -------- + Multiple GPUs, Multiple CPUs + + """ + from .module import matmul + + return matmul(lhs, self) + def __rmod__(self, lhs: Any) -> ndarray: """a.__rmod__(value, /) diff --git a/tests/integration/test_matmul.py b/tests/integration/test_matmul.py index 66f6ad89a..a4c4e3fd0 100644 --- a/tests/integration/test_matmul.py +++ b/tests/integration/test_matmul.py @@ -16,6 +16,7 @@ import numpy as np import pytest from legate.core import LEGATE_MAX_DIM +from utils.comparisons import allclose from utils.contractions import ( check_default, check_permutations, @@ -29,7 +30,7 @@ @pytest.mark.parametrize("a_ndim", range(1, LEGATE_MAX_DIM + 1)) @pytest.mark.parametrize("b_ndim", range(1, LEGATE_MAX_DIM + 1)) -def test(a_ndim, b_ndim): +def test_function(a_ndim, b_ndim): name = f"matmul({a_ndim} x {b_ndim})" modes = matmul_modes(a_ndim, b_ndim) @@ -43,6 +44,57 @@ def operation(lib, *args, **kwargs): check_types(name, modes, operation) +@pytest.mark.parametrize( + "a_shape", + ( + (3, 4, 5), + (4, 5), + (5,), + ), +) +@pytest.mark.parametrize( + "b_shape", + ( + (3, 5, 6), + (5, 6), + (5,), + ), +) +def test_operator(a_shape, b_shape): + np_a = np.random.random(a_shape) + np_b = np.random.random(b_shape) + num_a = num.array(np_a) + num_b = num.array(np_b) + assert allclose(np_a @ np_b, num_a @ num_b) + + +@pytest.mark.parametrize( + "a_shape", + ( + (3, 4, 5), + (4, 5), + (5,), + ), +) +@pytest.mark.parametrize( + "b_shape", + ( + (3, 5, 5), + (5, 5), + ), +) +def test_inplace_operator(a_shape, b_shape): + if len(a_shape) < len(b_shape): + return + np_a = np.random.random(a_shape) + np_b = np.random.random(b_shape) + num_a = num.array(np_a) + num_b = num.array(np_b) + np_a @= np_b + num_a @= num_b + assert allclose(np_a, num_a) + + class TestMatmulErrors: @pytest.mark.parametrize( "shapesAB",