Decompose addmm with Gemm | feat(torchlib) (#1111)
Decompose addmm with Gemm by creating a special variant for `FLOAT` and
conditionally checking the ranks of the input tensors. The if branch is
expected to be folded away by constant-folding passes.
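
As a rough illustration (not part of this commit), the rewrite relies on the fact that for rank-2 inputs the ONNX `Gemm` op computes `alpha * (A @ B) + beta * C`, which is the same value as the `MatMul`/`Mul`/`Add` decomposition of `aten::addmm`. A minimal numpy sketch of that equivalence:

```python
# Minimal numpy sketch (illustration only): for rank-2 inputs,
# Gemm(mat1, mat2, self, alpha, beta) produces the same value as the
# MatMul-based decomposition of aten::addmm.
import numpy as np

rng = np.random.default_rng(0)
self_, mat1, mat2 = rng.random((3, 4)), rng.random((3, 5)), rng.random((5, 4))
alpha, beta = 2.0, 0.5

decomposed = beta * self_ + alpha * (mat1 @ mat2)  # MatMul + Mul + Add path
gemm_like = alpha * (mat1 @ mat2) + beta * self_   # what a single Gemm node computes

assert np.allclose(decomposed, gemm_like)
```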

I have not found other instances where Gemm is used in the torch.onnx
exporter.

Fixes #1089
justinchuby authored Oct 25, 2023
1 parent 9fb0a7d commit b6ec405
Showing 3 changed files with 41 additions and 9 deletions.
@@ -30,7 +30,6 @@ class TestDeduceTypeConstraints(unittest.TestCase):
"_aten_embedding_bag_onnx",
"_aten_embedding_bag_1d_padding_idx_onnx",
)
_SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ("aten_all",)

@parameterized.parameterized.expand(
((op,) for op in torch_lib_onnx_functions_from_registry()),
@@ -41,11 +40,13 @@ def test_deduce_type_constraints_does_not_crash_for_onnx_function(
):
if onnx_function.name in self._SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN:
self.skipTest("Unimplemented: function contains loop or scan node.")
if onnx_function.name in self._SKIP_FUNCTIONS_WITH_NESTED_FUNCTION:
self.skipTest("Unimplemented: function contains nested function.")
signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
onnx_function
)
try:
signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
onnx_function
)
except NotImplementedError as e:
if "Nested function" in str(e):
self.skipTest("Unimplemented: function contains nested function.")
logger.info(
"Original signature: %s%s",
onnx_function.name,
28 changes: 25 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -222,8 +222,8 @@ def aten_addcmul(

@torch_op("aten::addmm")
def aten_addmm(
self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
self: TInt, mat1: TInt, mat2: TInt, beta: float = 1.0, alpha: float = 1.0
) -> TInt:
"""addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

mat1_mat2 = op.MatMul(mat1, mat2)
@@ -232,6 +232,29 @@ def aten_addmm(
return op.Add(scaled_self, scaled_mat1_mat2)


@torch_op("aten::addmm")
def aten_addmm_gemm(
self: TFloat, mat1: TFloat, mat2: TFloat, beta: float = 1.0, alpha: float = 1.0
) -> TFloat:
"""addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""

# Special case: when mat1 and mat2 are both rank 2, we can use Gemm instead of MatMul
# We expect the if branches to be folded away by optimization passes
# TODO(#1110): Handle Gemm with a graph rewriting pass instead of hard coding the branching logic here
use_gemm = op.And(
op.Equal(Rank(mat1), op.Constant(value_int=2)),
op.Equal(Rank(mat2), op.Constant(value_int=2)),
)
if use_gemm:
result = op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
else:
mat1_mat2 = op.MatMul(mat1, mat2)
scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
scaled_self = op.Mul(self, beta)
result = op.Add(scaled_self, scaled_mat1_mat2)
return result


@torch_op("aten::addmv")
def aten_addmv(
self: TReal, mat: TReal, vec: TReal, beta: float = 1.0, alpha: float = 1.0
@@ -5235,7 +5258,6 @@ def aten_mm(
) -> TRealUnlessInt16OrInt8:
"""mm(Tensor self, Tensor mat2) -> Tensor"""

# TODO(justinchuby): Specify type conversion for uint8/int8/int16
return op.MatMul(self, mat2)


9 changes: 9 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -489,6 +489,14 @@ def _where_input_wrangler(
TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
TorchLibOpInfo("addmm", core_ops.aten_addmm),
TorchLibOpInfo("addmm_gemm", core_ops.aten_addmm_gemm).xfail(
"decomposed",
reason=(
"The float attributes alpha/beta come in as int in the test cases, which breaks"
"eager mode. We don't need to care about this as long as the full graph tests pass"
),
test_class_name="TestOutputConsistencyEager",
),
TorchLibOpInfo("addmv", core_ops.aten_addmv),
TorchLibOpInfo(
"addr",
@@ -1968,6 +1976,7 @@ def _where_input_wrangler(
TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
)

ops_test_common.duplicate_opinfo(OPS_DB, "addmm", ("addmm_gemm",))
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
