Use Gemm to implement addmm | fix(torchlib) (#1113)
When I looked at the test coverage for `addmm` (below), I realized mat1
and mat2 are always 2d tensors. So the rank check is redundant. `addmm`
is now fully mapped to `Gemm`, which should completely resolve
#1089

Closes #1110


![image](https://github.com/microsoft/onnxscript/assets/11205048/073347ca-d677-4c87-94fa-e40a13642569)
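For context, a minimal PyTorch sketch of the two facts the change relies on (illustrative only, not part of the commit; shapes and values are made up): `addmm` computes `beta * self + alpha * (mat1 @ mat2)`, which is exactly what ONNX `Gemm` computes, and it only accepts 2-D `mat1`/`mat2`, so the old rank check could never take the fallback branch.

```python
import torch

self_ = torch.randn(3, 5)
mat1 = torch.randn(3, 4)
mat2 = torch.randn(4, 5)

# addmm semantics: beta * self + alpha * (mat1 @ mat2),
# i.e. Gemm(A=mat1, B=mat2, C=self, alpha=alpha, beta=beta).
out = torch.addmm(self_, mat1, mat2, beta=0.5, alpha=2.0)
torch.testing.assert_close(out, 0.5 * self_ + 2.0 * (mat1 @ mat2))

# addmm only accepts matrices; batched inputs go through bmm/baddbmm instead.
try:
    torch.addmm(self_, mat1.unsqueeze(0), mat2)  # 3-D mat1 is rejected
except RuntimeError as e:
    print(e)  # e.g. "mat1 must be a matrix"
```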
justinchuby authored Oct 25, 2023
1 parent b6ec405 commit c68468e
Showing 2 changed files with 17 additions and 31 deletions.
onnxscript/function_libs/torch_lib/ops/core.py (6 additions, 28 deletions)

@@ -222,37 +222,15 @@ def aten_addcmul(
 
 @torch_op("aten::addmm")
 def aten_addmm(
-    self: TInt, mat1: TInt, mat2: TInt, beta: float = 1.0, alpha: float = 1.0
-) -> TInt:
+    self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
+) -> TReal:
     """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
 
-    mat1_mat2 = op.MatMul(mat1, mat2)
-    scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
-    scaled_self = op.Mul(self, beta)
-    return op.Add(scaled_self, scaled_mat1_mat2)
-
 # NOTE: ONNX Runtime does not support int inputs to Gemm as of 1.16.
 # To support int inputs, consider an overriding implementation that casts to float and back.
 
-@torch_op("aten::addmm")
-def aten_addmm_gemm(
-    self: TFloat, mat1: TFloat, mat2: TFloat, beta: float = 1.0, alpha: float = 1.0
-) -> TFloat:
-    """addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
-
-    # A special case when rank of mat1 and mat2 are 2, we can use Gemm instead of MatMul
-    # We expect the if branches to be folded away by optimization passes
-    # TODO(#1110): Handle Gemm with a graph rewriting pass instead of hard coding the branching logic here
-    use_gemm = op.And(
-        op.Equal(Rank(mat1), op.Constant(value_int=2)),
-        op.Equal(Rank(mat2), op.Constant(value_int=2)),
-    )
-    if use_gemm:
-        result = op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
-    else:
-        mat1_mat2 = op.MatMul(mat1, mat2)
-        scaled_mat1_mat2 = op.Mul(mat1_mat2, alpha)
-        scaled_self = op.Mul(self, beta)
-        result = op.Add(scaled_self, scaled_mat1_mat2)
-    return result
+    # addmm only accepts 2d tensors: https://pytorch.org/docs/stable/generated/torch.addmm.html
+    return op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
 
 
 @torch_op("aten::addmv")
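As a sanity check on the new lowering (an illustrative sketch, not part of the commit; it assumes a recent `onnx` package with the reference evaluator), a single `Gemm` node reproduces what the removed `MatMul`/`Mul`/`Add` decomposition computed:

```python
import numpy as np
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

# Build a one-node model: out = alpha * (mat1 @ mat2) + beta * self
node = helper.make_node("Gemm", ["mat1", "mat2", "self"], ["out"], alpha=2.0, beta=0.5)
graph = helper.make_graph(
    [node],
    "addmm_as_gemm",
    [
        helper.make_tensor_value_info("mat1", TensorProto.FLOAT, [3, 4]),
        helper.make_tensor_value_info("mat2", TensorProto.FLOAT, [4, 5]),
        helper.make_tensor_value_info("self", TensorProto.FLOAT, [3, 5]),
    ],
    [helper.make_tensor_value_info("out", TensorProto.FLOAT, [3, 5])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

rng = np.random.default_rng(0)
mat1 = rng.standard_normal((3, 4), dtype=np.float32)
mat2 = rng.standard_normal((4, 5), dtype=np.float32)
self_ = rng.standard_normal((3, 5), dtype=np.float32)

(out,) = ReferenceEvaluator(model).run(None, {"mat1": mat1, "mat2": mat2, "self": self_})
# Same result as the old decomposition: Add(Mul(self, beta), Mul(MatMul(mat1, mat2), alpha)).
np.testing.assert_allclose(out, 0.5 * self_ + 2.0 * (mat1 @ mat2), rtol=1e-5, atol=1e-6)
```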
onnxscript/tests/function_libs/torch_lib/ops_test_data.py (11 additions, 3 deletions)

@@ -488,14 +488,23 @@ def _where_input_wrangler(
     TorchLibOpInfo("addbmm", core_ops.aten_addbmm, tolerance={torch.float32: (2e-5, 2e-5)}),
     TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv),
     TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
-    TorchLibOpInfo("addmm", core_ops.aten_addmm),
-    TorchLibOpInfo("addmm_gemm", core_ops.aten_addmm_gemm).xfail(
+    TorchLibOpInfo("addmm", core_ops.aten_addmm)
+    .xfail(
         "decomposed",
         reason=(
             "The float attributes alpha/beta come in as int in the test cases, which breaks"
             "eager mode. We don't need to care about this as long as the full graph tests pass"
         ),
         test_class_name="TestOutputConsistencyEager",
+    )
+    .xfail(
+        dtypes=(torch.int16, torch.int32, torch.int64),
+        reason="ONNX Runtime does not support int inputs to Gemm",
+    )
+    .xfail(
+        "decomposed",
+        dtypes=(torch.int16, torch.int32, torch.int64),
+        reason="ONNX Runtime does not support int inputs to Gemm",
     ),
     TorchLibOpInfo("addmv", core_ops.aten_addmv),
     TorchLibOpInfo(
@@ -1976,7 +1985,6 @@ def _where_input_wrangler(
     TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like, trace_only=True),
 )
 
-ops_test_common.duplicate_opinfo(OPS_DB, "addmm", ("addmm_gemm",))
 ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "arange", ("arange_start", "arange_start_step"))
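The new int-dtype xfails reflect the note kept in `core.py`: the ONNX spec allows integer inputs to `Gemm`, but ONNX Runtime (as of 1.16) ships no integer kernels for it. A hedged sketch of what that failure looks like (not part of the commit; the exact exception type and message depend on the onnxruntime version):

```python
import onnxruntime as ort
from onnx import TensorProto, helper

node = helper.make_node("Gemm", ["mat1", "mat2", "self"], ["out"])
graph = helper.make_graph(
    [node],
    "int_gemm",
    [
        helper.make_tensor_value_info(name, TensorProto.INT64, [2, 2])
        for name in ("mat1", "mat2", "self")
    ],
    [helper.make_tensor_value_info("out", TensorProto.INT64, [2, 2])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

try:
    ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
except Exception as e:
    # Expected: session creation fails because no Gemm kernel is registered for int64.
    print(type(e).__name__, e)
```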
