Skip to content

Commit

Permalink
Fix baddbmm and scalar_tensor (#1837)
Browse files Browse the repository at this point in the history
1. Handle baddbmm when scalars are SymFloat
2. Accept bool as scalar_tensor input

Fixes justinchuby/torch-onnx#42

---------

Co-authored-by: Ti-Tai Wang <[email protected]>
  • Loading branch information
justinchuby and titaiwangms authored Aug 30, 2024
1 parent 22708e8 commit fac4825
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,20 +1017,25 @@ def reshape_to_3d(tensor):
return op.SequenceMap(self, body=reshape_to_3d)


@torch_op("aten::baddbmm")
@torch_op("aten::baddbmm", trace_only=True)
def aten_baddbmm(
self: TRealOrUInt8,
batch1: TRealUnlessInt16OrInt8,
batch2: TRealUnlessInt16OrInt8,
beta: float = 1.0,
alpha: float = 1.0,
beta: Optional[TFloat] = None,
alpha: Optional[TFloat] = None,
) -> TRealUnlessInt16OrInt8:
"""baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"""
# beta and alpha can be SymFloat
batch_mul = op.MatMul(batch1, batch2)
alpha_cast = op.CastLike(alpha, self)
mul_a = op.Mul(batch_mul, alpha_cast)
beta_cast = op.CastLike(beta, self)
mul_b = op.Mul(self, beta_cast)
if alpha is None or alpha == 1:
mul_a = batch_mul
else:
mul_a = op.Mul(batch_mul, op.CastLike(alpha, self))
if beta is None or beta == 1:
mul_b = self
else:
mul_b = op.Mul(self, op.CastLike(beta, self))
return op.Add(mul_a, mul_b)


Expand Down Expand Up @@ -7413,7 +7418,7 @@ def aten_scalar_tensor_complex(

@torch_op("aten::scalar_tensor", trace_only=True)
def aten_scalar_tensor_sym_number(
s: RealType,
s: TensorType,
dtype: int = FLOAT.dtype,
layout: str = "",
device: str = "",
Expand All @@ -7422,8 +7427,6 @@ def aten_scalar_tensor_sym_number(
"""scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
if dtype == -1:
dtype = FLOAT.dtype
# Set trace_only=True because different if branches return different dtypes
# which is not supported in an ONNX function
return common_ops.cast_to(s, dtype=dtype)


Expand Down

0 comments on commit fac4825

Please sign in to comment.