Skip to content

Commit

Permalink
Fix aten_cumsum traced mode (#1605)
Browse files Browse the repository at this point in the history
aten_cumsum was not traceable when the input is casted because the shape
information was not propagated. Turning the implementation to use pure
tracing instead.
  • Loading branch information
justinchuby authored Jun 12, 2024
1 parent 165ba5c commit 505e154
Showing 1 changed file with 3 additions and 10 deletions.
13 changes: 3 additions & 10 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2416,18 +2416,11 @@ def aten_cumsum(
cast = self
else:
cast = op.Cast(self, to=dtype)
return _aten_cumsum_onnx(cast, dim)


@torch_op("aten::cumsum", private=True, traceable=True)
def _aten_cumsum_onnx(
self: TRealUnlessInt16OrInt8, dim: Union[INT32, INT64]
) -> TRealUnlessInt16OrInt8:
if IsScalar(self):
if len(self.shape) == 0:
# A scalar
result = op.Identity(self)
result = op.Identity(cast)
else:
result = op.CumSum(self, dim)
result = op.CumSum(cast, dim)
return result


Expand Down

0 comments on commit 505e154

Please sign in to comment.