Skip to content

Commit

Permalink
[torchlib] Add the identity nodes back (#1703)
Browse files Browse the repository at this point in the history
In the modularization pass in the exporter, a single node like `clone`
can be lifted as a function. If we remove the only Identity node the
lifted function will have no nodes. This violates the ONNX standard.

Since removing identity nodes is fast, we are safe to include these
identity nodes in the torchlib.

onnxscript/tools/transformers_models/phi_test.py broke after #1613, it
is fixed by this change.

---------

Signed-off-by: Xavier Dupre <[email protected]>
Co-authored-by: Justin Chu <[email protected]>
  • Loading branch information
xadupre and justinchuby authored Jun 25, 2024
1 parent 1aa7a70 commit be00339
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 26 deletions.
50 changes: 25 additions & 25 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def aten_affine_grid_generator_backward(
def aten_alias(self: TTensor) -> TTensor:
"""alias(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


def aten_alias_copy(self: TensorType) -> TensorType:
Expand Down Expand Up @@ -374,7 +374,7 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
self = aten_all_dim(self, d, keepdim=True)
if not keepdim:
self = op.Squeeze(self, list(dim))
return self
return op.Identity(self)


@torch_op("aten::all.dims", traceable=True)
Expand Down Expand Up @@ -499,7 +499,7 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
self = aten_any_dim(self, d, keepdim=True)
if not keepdim:
self = op.Squeeze(self, list(dim))
return self
return op.Identity(self)


@torch_op("aten::any.dims", traceable=True)
Expand Down Expand Up @@ -940,7 +940,7 @@ def aten_atleast_1d(self: TTensor) -> TTensor:

if IsScalar(self):
self = op.Reshape(self, op.Constant(value_ints=[1]))
return self
return op.Identity(self)


@torch_op("aten::atleast_1d.Sequence")
Expand All @@ -964,7 +964,7 @@ def aten_atleast_2d(self: TTensor) -> TTensor:

if Rank(self) <= 1:
self = op.Reshape(self, op.Constant(value_ints=[1, -1]))
return self
return op.Identity(self)


@torch_op("aten::atleast_2d.Sequence")
Expand All @@ -991,7 +991,7 @@ def aten_atleast_3d(self: TTensor) -> TTensor:
self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1]))
elif rank == 2:
self = op.Unsqueeze(self, op.Constant(value_ints=[-1]))
return self
return op.Identity(self)


@torch_op("aten::atleast_3d.Sequence")
Expand Down Expand Up @@ -1691,7 +1691,7 @@ def aten_clone(
) -> TTensor:
"""clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor"""

return self
return op.Identity(self)


def aten_coalesce(self: TensorType) -> TensorType:
Expand Down Expand Up @@ -1749,7 +1749,7 @@ def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
def aten_conj(self: TTensor) -> TTensor:
"""conj(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


@torch_op("aten::conj", complex=True, private=True)
Expand Down Expand Up @@ -1825,7 +1825,7 @@ def aten_contiguous(
"""contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)"""

# ONNX does not have the notion of memory_format. It is always treated as a no-op.
return self
return op.Identity(self)


@torch_op("aten::conv1d", trace_only=True)
Expand Down Expand Up @@ -2168,7 +2168,7 @@ def aten__to_copy(
"""_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor"""

if dtype == -1:
return self
return op.Identity(self)
else:
return common_ops.cast_to(self, dtype=dtype)

Expand Down Expand Up @@ -2493,7 +2493,7 @@ def aten_dense_dim(self: TensorType) -> int:
def aten_detach(self: TensorType) -> TensorType:
"""detach(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


def aten_detach_copy(self: TensorType) -> TensorType:
Expand Down Expand Up @@ -4061,7 +4061,7 @@ def _aten_index_onnx(
if _has_none_in_middle(indices):
# If there is None in the middle, Advanced Indexing cannot decide where to put
# the new dimensions. So it places them in the front, like GatherND does.
return self
return op.Identity(self)

# When the indices are consecutive, Advanced Indexing will place the new dimensions
# (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
Expand Down Expand Up @@ -4227,7 +4227,7 @@ def aten_index_put_bool(
index = op.SequenceAt(indices, 0) # assume indices only have 1 element
# FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it
index_int = op.Cast(index, to=INT32.dtype)
# if all False, return self
# if all False, return op.Identity(self)
if op.ReduceSum(index_int) == 0:
result = self
else:
Expand Down Expand Up @@ -4700,7 +4700,7 @@ def aten_lift_fresh(self: TensorType) -> TensorType:
def aten_lift_fresh_copy(self: TensorType) -> TensorType:
"""lift_fresh_copy(Tensor self) -> Tensor"""

return self
return op.Identity(self)


def aten_linear_backward(
Expand Down Expand Up @@ -7082,14 +7082,14 @@ def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType:
def aten_resolve_conj(self: TTensor) -> TTensor:
"""resolve_conj(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


@torch_op("aten::resolve_neg", trace_only=True)
def aten_resolve_neg(self: TTensor) -> TTensor:
"""resolve_neg(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


def aten_result_type(tensor: TensorType, other: TensorType) -> int:
Expand Down Expand Up @@ -7142,9 +7142,9 @@ def aten_roll(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) -> TTensor

self_rank = len(self.shape)
if self_rank == 0:
return self
return op.Identity(self)
elif self.shape[0] == 0: # empty tensor
return self
return op.Identity(self)
else:
# NOTE: In pytorch, default value of dims is an empty list.
if len(dims) == 0: # Empty sequence
Expand All @@ -7166,10 +7166,10 @@ def aten_roll_complex(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) ->

self_rank = len(self.shape)
if self_rank == 1:
return self
return op.Identity(self)

if self.shape[0] == 0: # empty tensor
return self
return op.Identity(self)

self_real = op.Slice(self, [0], [1], axes=[-1])
self_imag = op.Slice(self, [1], [2], axes=[-1])
Expand Down Expand Up @@ -7819,7 +7819,7 @@ def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT6
if signal_rank == 1:
# Add a batch dimension
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
return self, signal_rank
return op.Identity(self), signal_rank


@torch_op("aten::stft", private=True)
Expand Down Expand Up @@ -8768,7 +8768,7 @@ def aten_view_as_complex(self: TTensor) -> TTensor:

# We always operate on the real representation of a complex number in torchlib
# So this is a no-op
return self
return op.Identity(self)


@torch_op("aten::view_as_complex_copy", trace_only=True)
Expand All @@ -8777,7 +8777,7 @@ def aten_view_as_complex_copy(self: TTensor) -> TTensor:

# We always operate on the real representation of a complex number in torchlib
# So this is a no-op
return self
return op.Identity(self)


@torch_op("aten::view_as_real", complex=True, trace_only=True)
Expand All @@ -8786,7 +8786,7 @@ def aten_view_as_real(self: TTensor) -> TTensor:

# We always operate on the real representation of a complex number in torchlib
# So this is a no-op
return self
return op.Identity(self)


@torch_op("aten::view_as_real_copy", complex=True, trace_only=True)
Expand All @@ -8795,7 +8795,7 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor:

# We always operate on the real representation of a complex number in torchlib
# So this is a no-op
return self
return op.Identity(self)


@torch_op("aten::view_copy")
Expand Down
3 changes: 2 additions & 1 deletion onnxscript/rewriter/llama_rule_sets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_llama_p0_rule_set_cast_cast(self):
rewritten_model = ir.serde.serialize_model(ir_model)

self.assertEqual(["Cast"], [n.op_type for n in rewritten_model.graph.node])
self._check_model(model_proto, rewritten_model, atol=1e-3)
self._check_model(model_proto, rewritten_model, atol=1e-2)

@classmethod
def _cast_identity_models(cls):
Expand Down Expand Up @@ -376,6 +376,7 @@ def _slides_split_models(cls):
]
return models

@unittest.skipIf(True, reason="see https://github.com/microsoft/onnxscript/issues/1642")
def test_llama_p0_rule_set_slice_split(self):
for model_proto in self._slides_split_models():
ir_model = ir.serde.deserialize_model(model_proto)
Expand Down

0 comments on commit be00339

Please sign in to comment.