diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
index a2882d283..4e01d37ac 100644
--- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
+++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py
@@ -30,7 +30,6 @@ class TestDeduceTypeConstraints(unittest.TestCase):
         "_aten_embedding_bag_onnx",
         "_aten_embedding_bag_1d_padding_idx_onnx",
     )
-    _SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ("aten_all",)
 
     @parameterized.parameterized.expand(
         ((op,) for op in torch_lib_onnx_functions_from_registry()),
@@ -41,11 +40,13 @@ def test_deduce_type_constraints_does_not_crash_for_onnx_function(
     ):
         if onnx_function.name in self._SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN:
             self.skipTest("Unimplemented: function contains loop or scan node.")
-        if onnx_function.name in self._SKIP_FUNCTIONS_WITH_NESTED_FUNCTION:
-            self.skipTest("Unimplemented: function contains nested function.")
-        signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
-            onnx_function
-        )
+        try:
+            signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
+                onnx_function
+            )
+        except NotImplementedError as e:
+            if "Nested function" in str(e):
+                self.skipTest("Unimplemented: function contains nested function.")
         logger.info(
             "Original signature: %s%s",
             onnx_function.name,
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 6d770254e..775e149da 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -222,8 +222,8 @@ def aten_addcmul(
 
 @torch_op("aten::addmm")
 def aten_addmm(
-    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
-) -> TReal:
+    self: TInt, mat1: TInt, mat2: TInt, beta: float = 1.0, alpha: float = 1.0
+) -> TInt:
     """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
 
     mat1_mat2 = op.MatMul(mat1, mat2)
@@ -232,6 +232,29 @@ def aten_addmm(
     return op.Add(scaled_self, scaled_mat1_mat2)
 
 
+@torch_op("aten::addmm")
+def aten_addmm_gemm(
+    self: TFloat, mat1: TFloat, mat2: TFloat, beta: float = 1.0, alpha: float = 1.0
+) -> TFloat:
+    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
+
+    # A special case when rank of mat1 and mat2 are 2, we can use Gemm instead of MatMul
+    # We expect the if branches to be folded away by optimization passes
+    # TODO(#1110): Handle Gemm with a graph rewriting pass instead of hard coding the branching logic here
+    use_gemm = op.And(
+        op.Equal(Rank(mat1), op.Constant(value_int=2)),
+        op.Equal(Rank(mat2), op.Constant(value_int=2)),
+    )
+    if use_gemm:
+        result = op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
+    else:
+        mat1_mat2 = op.MatMul(mat1, mat2)
+        scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
+        scaled_self = op.Mul(self, beta)
+        result = op.Add(scaled_self, scaled_mat1_mat2)
+    return result
+
+
 @torch_op("aten::addmv")
 def aten_addmv(
     self: TReal, mat: TReal, vec: TReal, beta: float = 1.0, alpha: float = 1.0
@@ -5235,7 +5258,6 @@ def aten_mm(
 ) -> TRealUnlessInt16OrInt8:
     """mm(Tensor self, Tensor mat2) -> Tensor"""
 
-    # TODO(justinchuby): Specify type conversion for uint8/int8/int16
     return op.MatMul(self, mat2)
 
 
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 7c6a64b49..7304aa8ba 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -489,6 +489,14 @@ def _where_input_wrangler(
     TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
     TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
     TorchLibOpInfo("addmm", core_ops.aten_addmm),
+    TorchLibOpInfo("addmm_gemm", core_ops.aten_addmm_gemm).xfail(
+        "decomposed",
+        reason=(
+            "The float attributes alpha/beta come in as int in the test cases, which breaks"
+            " eager mode. We don't need to care about this as long as the full graph tests pass"
+        ),
+        test_class_name="TestOutputConsistencyEager",
+    ),
     TorchLibOpInfo("addmv", core_ops.aten_addmv),
     TorchLibOpInfo(
         "addr",
@@ -1968,6 +1976,7 @@ def _where_input_wrangler(
     TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
 )
 
+ops_test_common.duplicate_opinfo(OPS_DB, "addmm", ("addmm_gemm",))
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
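Sanity check (not part of the diff): a minimal numpy sketch of why the Gemm branch added in aten_addmm_gemm is interchangeable with the MatMul decomposition for rank-2 inputs. The helper names addmm_reference and gemm_reference are illustrative only; numpy stands in for the ONNX ops.

import numpy as np

def addmm_reference(self_, mat1, mat2, beta=1.0, alpha=1.0):
    # Decomposed path: beta * self + alpha * (mat1 @ mat2), mirroring the MatMul/Mul/Add branch.
    return beta * self_ + alpha * (mat1 @ mat2)

def gemm_reference(mat1, mat2, c, alpha=1.0, beta=1.0):
    # ONNX Gemm computes alpha * (A @ B) + beta * C for rank-2 A and B.
    return alpha * (mat1 @ mat2) + beta * c

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4)).astype(np.float32)
b = rng.standard_normal((4, 5)).astype(np.float32)
c = rng.standard_normal((3, 5)).astype(np.float32)

# Both formulations agree, which is why the rank-2 case can be routed to Gemm.
np.testing.assert_allclose(
    addmm_reference(c, a, b, beta=0.5, alpha=2.0),
    gemm_reference(a, b, c, alpha=2.0, beta=0.5),
    rtol=1e-5,
)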