Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/microsoft/onnxscript into bbb
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Jun 28, 2024
2 parents 5735461 + be00339 commit 7019d1d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 26 deletions.
50 changes: 25 additions & 25 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def aten_affine_grid_generator_backward(
def aten_alias(self: TTensor) -> TTensor:
"""alias(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


def aten_alias_copy(self: TensorType) -> TensorType:
Expand Down Expand Up @@ -374,7 +374,7 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
self = aten_all_dim(self, d, keepdim=True)
if not keepdim:
self = op.Squeeze(self, list(dim))
return self
return op.Identity(self)


@torch_op("aten::all.dims", traceable=True)
Expand Down Expand Up @@ -499,7 +499,7 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False)
self = aten_any_dim(self, d, keepdim=True)
if not keepdim:
self = op.Squeeze(self, list(dim))
return self
return op.Identity(self)


@torch_op("aten::any.dims", traceable=True)
Expand Down Expand Up @@ -940,7 +940,7 @@ def aten_atleast_1d(self: TTensor) -> TTensor:

if IsScalar(self):
self = op.Reshape(self, op.Constant(value_ints=[1]))
return self
return op.Identity(self)


@torch_op("aten::atleast_1d.Sequence")
Expand All @@ -964,7 +964,7 @@ def aten_atleast_2d(self: TTensor) -> TTensor:

if Rank(self) <= 1:
self = op.Reshape(self, op.Constant(value_ints=[1, -1]))
return self
return op.Identity(self)


@torch_op("aten::atleast_2d.Sequence")
Expand All @@ -991,7 +991,7 @@ def aten_atleast_3d(self: TTensor) -> TTensor:
self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1]))
elif rank == 2:
self = op.Unsqueeze(self, op.Constant(value_ints=[-1]))
return self
return op.Identity(self)


@torch_op("aten::atleast_3d.Sequence")
Expand Down Expand Up @@ -1691,7 +1691,7 @@ def aten_clone(
) -> TTensor:
"""clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor"""

return self
return op.Identity(self)


def aten_coalesce(self: TensorType) -> TensorType:
Expand Down Expand Up @@ -1749,7 +1749,7 @@ def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
def aten_conj(self: TTensor) -> TTensor:
"""conj(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


@torch_op("aten::conj", complex=True, private=True)
Expand Down Expand Up @@ -1825,7 +1825,7 @@ def aten_contiguous(
"""contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)"""

# ONNX does not have the notion of memory_format. It is always treated as a no-op.
return self
return op.Identity(self)


@torch_op("aten::conv1d", trace_only=True)
Expand Down Expand Up @@ -2168,7 +2168,7 @@ def aten__to_copy(
"""_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor"""

if dtype == -1:
return self
return op.Identity(self)
else:
return common_ops.cast_to(self, dtype=dtype)

Expand Down Expand Up @@ -2493,7 +2493,7 @@ def aten_dense_dim(self: TensorType) -> int:
def aten_detach(self: TensorType) -> TensorType:
"""detach(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


def aten_detach_copy(self: TensorType) -> TensorType:
Expand Down Expand Up @@ -4061,7 +4061,7 @@ def _aten_index_onnx(
if _has_none_in_middle(indices):
# If there is None in the middle, Advanced Indexing cannot decide where to put
# the new dimensions. So it places them in the front, like GatherND does.
return self
return op.Identity(self)

# When the indices are consecutive, Advanced Indexing will place the new dimensions
# (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
Expand Down Expand Up @@ -4227,7 +4227,7 @@ def aten_index_put_bool(
index = op.SequenceAt(indices, 0) # assume indices only have 1 element
# FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it
index_int = op.Cast(index, to=INT32.dtype)
# if all False, return self
# if all False, return op.Identity(self)
if op.ReduceSum(index_int) == 0:
result = self
else:
Expand Down Expand Up @@ -4700,7 +4700,7 @@ def aten_lift_fresh(self: TensorType) -> TensorType:
def aten_lift_fresh_copy(self: TensorType) -> TensorType:
"""lift_fresh_copy(Tensor self) -> Tensor"""

return self
return op.Identity(self)


def aten_linear_backward(
Expand Down Expand Up @@ -7082,14 +7082,14 @@ def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType:
def aten_resolve_conj(self: TTensor) -> TTensor:
"""resolve_conj(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


@torch_op("aten::resolve_neg", trace_only=True)
def aten_resolve_neg(self: TTensor) -> TTensor:
"""resolve_neg(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


def aten_result_type(tensor: TensorType, other: TensorType) -> int:
Expand Down Expand Up @@ -7142,9 +7142,9 @@ def aten_roll(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) -> TTensor

self_rank = len(self.shape)
if self_rank == 0:
return self
return op.Identity(self)
elif self.shape[0] == 0: # empty tensor
return self
return op.Identity(self)
else:
# NOTE: In pytorch, default value of dims is an empty list.
if len(dims) == 0: # Empty sequence
Expand All @@ -7166,10 +7166,10 @@ def aten_roll_complex(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) ->

self_rank = len(self.shape)
if self_rank == 1:
return self
return op.Identity(self)

if self.shape[0] == 0: # empty tensor
return self
return op.Identity(self)

self_real = op.Slice(self, [0], [1], axes=[-1])
self_imag = op.Slice(self, [1], [2], axes=[-1])
Expand Down Expand Up @@ -7819,7 +7819,7 @@ def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT6
if signal_rank == 1:
# Add a batch dimension
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
return self, signal_rank
return op.Identity(self), signal_rank


@torch_op("aten::stft", private=True)
Expand Down Expand Up @@ -8768,7 +8768,7 @@ def aten_view_as_complex(self: TTensor) -> TTensor:

# We always operate on the real representation of a complex number in torchlib
# So this is a no-op
return self
return op.Identity(self)


@torch_op("aten::view_as_complex_copy", trace_only=True)
Expand All @@ -8777,7 +8777,7 @@ def aten_view_as_complex_copy(self: TTensor) -> TTensor:

# We always operate on the real representation of a complex number in torchlib
# So this is a no-op
return self
return op.Identity(self)


@torch_op("aten::view_as_real", complex=True, trace_only=True)
Expand All @@ -8786,7 +8786,7 @@ def aten_view_as_real(self: TTensor) -> TTensor:

# We always operate on the real representation of a complex number in torchlib
# So this is a no-op
return self
return op.Identity(self)


@torch_op("aten::view_as_real_copy", complex=True, trace_only=True)
Expand All @@ -8795,7 +8795,7 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor:

# We always operate on the real representation of a complex number in torchlib
# So this is a no-op
return self
return op.Identity(self)


@torch_op("aten::view_copy")
Expand Down
3 changes: 2 additions & 1 deletion onnxscript/rewriter/llama_rule_sets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_llama_p0_rule_set_cast_cast(self):
rewritten_model = ir.serde.serialize_model(ir_model)

self.assertEqual(["Cast"], [n.op_type for n in rewritten_model.graph.node])
self._check_model(model_proto, rewritten_model, atol=1e-3)
self._check_model(model_proto, rewritten_model, atol=1e-2)

@classmethod
def _cast_identity_models(cls):
Expand Down Expand Up @@ -376,6 +376,7 @@ def _slides_split_models(cls):
]
return models

@unittest.skipIf(True, reason="see https://github.com/microsoft/onnxscript/issues/1642")
def test_llama_p0_rule_set_slice_split(self):
for model_proto in self._slides_split_models():
ir_model = ir.serde.deserialize_model(model_proto)
Expand Down

0 comments on commit 7019d1d

Please sign in to comment.