Trace Op (aten::addmm) | feat(torchlib) (#1825)
addmm is used a lot, but it is not traced yet. Tracing it gives a better
debugging and graph experience.
titaiwangms authored Aug 26, 2024
1 parent add9558 commit 63b1cdb
Showing 2 changed files with 4 additions and 9 deletions.
5 changes: 4 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
@@ -238,7 +238,7 @@ def aten_addcmul(
return op.Add(self, op.Mul(op.Mul(value, tensor1), tensor2))


-@torch_op("aten::addmm")
+@torch_op("aten::addmm", trace_only=True)
def aten_addmm(
self: TReal, mat1: TReal, mat2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
@@ -247,6 +247,9 @@ def aten_addmm(
# NOTE: ONNX Runtime does not support int inputs to Gemm as of 1.16.
# To support int inputs, consider an overriding implementation that casts to float and back.

+alpha = float(alpha)
+beta = float(beta)

# addmm only accepts 2d tensors: https://pytorch.org/docs/stable/generated/torch.addmm.html
return op.Gemm(mat1, mat2, self, alpha=alpha, beta=beta)
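
For reference, a small NumPy sketch (not part of the diff) of the mapping this implementation relies on: torch.addmm(self, mat1, mat2, beta=b, alpha=a) computes b * self + a * (mat1 @ mat2), which matches ONNX Gemm with A=mat1, B=mat2, C=self and the same alpha/beta attributes.

```python
# Illustrative only; plain NumPy standing in for both ops, not torchlib code.
import numpy as np

def gemm(a, b, c, alpha=1.0, beta=1.0):
    # ONNX Gemm semantics: Y = alpha * (A @ B) + beta * C
    return alpha * (a @ b) + beta * c

self_ = np.random.rand(2, 3).astype(np.float32)
mat1 = np.random.rand(2, 4).astype(np.float32)
mat2 = np.random.rand(4, 3).astype(np.float32)

# torch.addmm(self, mat1, mat2, beta=0.5, alpha=2.0) == 0.5 * self + 2.0 * (mat1 @ mat2)
expected = 0.5 * self_ + 2.0 * (mat1 @ mat2)
np.testing.assert_allclose(gemm(mat1, mat2, self_, alpha=2.0, beta=0.5), expected, rtol=1e-5)
```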

8 changes: 0 additions & 8 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -546,14 +546,6 @@ def _where_input_wrangler(
TorchLibOpInfo("addcdiv", core_ops.aten_addcdiv, tolerance={torch.float16: (3e-2, 1e-3)}),
TorchLibOpInfo("addcmul", core_ops.aten_addcmul, tolerance={torch.float16: (4e-3, 3e-3)}),
TorchLibOpInfo("addmm", core_ops.aten_addmm)
-.xfail(
-"decomposed",
-reason=(
-"The float attributes alpha/beta come in as int in the test cases, which breaks"
-"eager mode. We don't need to care about this as long as the full graph tests pass"
-),
-test_class_name="TestOutputConsistencyEager",
-)
.xfail(
dtypes=(torch.int16, torch.int32, torch.int64),
reason="ONNX Runtime does not support int inputs to Gemm",
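
Context for the removed xfail (illustrative only): the OpInfo samples pass alpha/beta as Python ints, which previously broke eager mode for the "decomposed" variant. With the explicit float() casts added in aten_addmm above, the attributes reach Gemm as proper floats, so that expected failure is no longer needed. A minimal PyTorch sketch of the semantics involved, assuming torch is available:

```python
# Illustrative only; not the actual OpInfo test harness.
import torch

self_ = torch.randn(2, 3)
mat1 = torch.randn(2, 4)
mat2 = torch.randn(4, 3)

# Sample inputs may give alpha/beta as ints; casting them to float does not
# change the numerics, it only fixes the attribute type handed to Gemm.
with_int_attrs = torch.addmm(self_, mat1, mat2, beta=1, alpha=2)
with_float_cast = float(1) * self_ + float(2) * (mat1 @ mat2)
assert torch.allclose(with_int_attrs, with_float_cast)
```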
