[rewriter] Remove redundant op.Slice and op.ScatterND (#1925)
Fixes microsoft/onnx-converters-private#270

It is observed that ExportedProgram can generate aten::slice.Tensor and aten::slice_scatter.default ops that slice nothing:

![Screenshot 2024-10-29 112157](https://github.com/user-attachments/assets/6274e71c-f5a8-4fdc-b885-ff1365b4c245)

Such slices become redundant op.Slice ops in the ONNX graph that do nothing, and op.ScatterND ops that simply replace the whole input with the updates, which wastes significant inference time.

This rule set recognizes a redundant Slice by checking that:
1. starts == 0
2. ends >= inputs[dim].shape or ends == _INT64_MAX
3. steps == 1

This rule set recognizes a redundant ScatterND by checking that:
1. indices has the same length as the first dim of input
2. indices covers 0 to input.shape[0]
3. input has the same shape as updates

Benchmark on ghostnet_100 (the original speedup was 0.0256):

| Suite | Model Name | Speedup (onnx_dynamo) | Increase (onnx_dynamo) | Med (onnx_dynamo) |
|-------|------------|-----------------------|------------------------|-------------------|
| Timm | ghostnet_100 | 1.1599 | 15.986% | 1.1580 |
1 parent c13e4fd · commit 1ceb85b
Showing 5 changed files with 317 additions and 1 deletion.
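For reference, a minimal sketch of how the new rule set is exercised (mirroring the unit tests added below); the graph is a hypothetical example of a Slice that selects the entire input:

```python
import onnx.parser

from onnxscript import ir
from onnxscript.rewriter import collapse_slices

# A Slice that selects the whole tensor: starts = 0, ends >= dim size, steps = 1.
model_proto = onnx.parser.parse_model(
    """
    <ir_version: 7, opset_import: [ "" : 17]>
    agraph (float[512, 16, 112] data) => (float[512, 16, 112] output)
    {
        starts = Constant<value: tensor = int64[1] {0}>()
        ends = Constant<value: tensor = int64[1] {9999}>()
        axes = Constant<value: tensor = int64[1] {0}>()
        steps = Constant<value: tensor = int64[1] {1}>()
        output = Slice (data, starts, ends, axes, steps)
    }
    """
)
model = ir.serde.deserialize_model(model_proto)
count = collapse_slices.rules.apply_to_model(model)  # the Slice becomes an Identity
print(count)  # 1
```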
@@ -0,0 +1,140 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import logging

from onnxscript import ir
from onnxscript.rewriter import pattern

logger = logging.getLogger(__name__)
_INT64_MAX = 9223372036854775807


def _check_if_redundant_slice(
    context,
    data: ir.Value,
    starts: ir.Value,
    ends: ir.Value,
    axes: ir.Value,
    steps: ir.Value,
    **_,
) -> bool:
    """If starts is 0 and ends is equal to or greater than the size of the specified axis, the slice is redundant."""
    del context  # Reserved for future extensions

    starts_const = starts.const_value
    ends_const = ends.const_value
    axes_const = axes.const_value
    steps_const = steps.const_value

    if starts_const is None or ends_const is None or axes_const is None or steps_const is None:
        logger.info("The values 'start', 'end', 'axis', and 'step' must be statically known.")
        return False

    # Check if the values are scalar
    if starts_const.numpy().size != 1:
        logger.info("The value 'start' is not a scalar.")
        return False
    if ends_const.numpy().size != 1:
        logger.info("The value 'end' is not a scalar.")
        return False
    if axes_const.numpy().size != 1:
        logger.info("The value 'axis' is not a scalar.")
        return False
    if steps_const.numpy().size != 1:
        logger.info("The value 'step' is not a scalar.")
        return False

    if steps_const.numpy().item() != 1:
        logger.info("The value 'step' is not 1.")
        return False
    # starts is 0
    if starts_const.numpy().item() != 0:
        logger.info("The value 'start' is not 0.")
        return False
    # Even when data.shape is not statically known, the slice is redundant if ends is INT64_MAX
    if ends_const.numpy().item() == _INT64_MAX:
        return True
    if data.shape is None:
        logger.info("The value 'data' shape is not statically known.")
        return False
    if ends_const.numpy().item() < data.shape[axes_const.numpy().item()]:
        logger.info("The value 'end' is less than the shape of the specified axis.")
        return False

    return True


def _identity_to_itself(op, data, **_):
    """Return the input data as the output."""
    return op.Identity(data)


def _identity_to_updates(op, data, indices, updates, **_):
    """Return the updates as the output.

    This is used when the ScatterND is redundant in terms of
    updating the whole data with the updates.
    """
    return op.Identity(updates)


def _potential_redundant_slice(op, data, starts, ends, axes, steps):
    """To identify a Slice op"""
    return op.Slice(data, starts, ends, axes, steps)


def _potential_redundant_scatternd(op, data, indices, updates):
    """To identify a ScatterND op"""
    return op.ScatterND(data, indices, updates)


def _check_if_redundant_scatternd(
    context,
    data: ir.Value,
    indices: ir.Value,
    updates: ir.Value,
    **_,
) -> bool:
    """If indices covers the whole first dim of data and updates has the same shape as data, the ScatterND simply replaces data with updates."""
    del context  # Reserved for future extensions

    # To validate that data can be replaced directly by updates, we need to check the following:
    # 1. they have the same shape
    if data.shape is None:
        logger.info("The value 'data' shape is not statically known.")
        return False
    if updates.shape is None:
        logger.info("The value 'updates' shape is not statically known.")
        return False
    if data.shape != updates.shape:
        logger.info("The shapes of 'data' and 'updates' are different.")
        return False

    # 2. the indices refers to the whole data, i.e. [0, ..., data.shape[0] - 1]
    if indices.const_value is None:
        logger.info("The value 'indices' is not statically known.")
        return False
    if indices.const_value.numpy().tolist() != [[i] for i in range(data.shape[0])]:  # type: ignore[arg-type]
        logger.info("The 'indices' does not refer to the whole data.")
        return False

    return True


# Register the rewrite rules
remove_redundant_slice = pattern.RewriteRule(
    _potential_redundant_slice,
    _identity_to_itself,
    _check_if_redundant_slice,
)

remove_redundant_scatternd = pattern.RewriteRule(
    _potential_redundant_scatternd,
    _identity_to_updates,
    _check_if_redundant_scatternd,
)

# NOTE: The order of the rules is important. Larger patterns should be checked first.
rules = pattern.RewriteRuleSet([remove_redundant_slice, remove_redundant_scatternd])
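As a quick NumPy sanity check (not part of the commit itself), the conditions checked above really do make the matched ops no-ops; the shapes here are arbitrary:

```python
import numpy as np

data = np.random.rand(4, 3).astype(np.float32)
updates = np.random.rand(4, 3).astype(np.float32)

# A slice with starts=0, ends past the dimension size, and steps=1 selects everything,
# so the Slice node can be replaced by Identity(data).
assert np.array_equal(data[0:9999:1], data)

# ScatterND with indices [[0], [1], ..., [n-1]] along axis 0 overwrites every row,
# so the result is exactly `updates` and the node can be replaced by Identity(updates).
indices = np.arange(data.shape[0]).reshape(-1, 1)
scattered = data.copy()
scattered[indices[:, 0]] = updates
assert np.array_equal(scattered, updates)
```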
@@ -0,0 +1,98 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np
import onnx.parser
import onnx.shape_inference

from onnxscript import ir
from onnxscript.rewriter import collapse_slices, testing

_INT64_MAX = 9223372036854775807


class CollapseSlicesTest(unittest.TestCase):
    def test_slice_is_redundant_when_ends_is_greater_than_input_shape(self):
        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[512, 16, 112] data) => (float[512, 16, 112] output)
            {
                starts = Constant<value: tensor = int64[1] {0}>()
                ends = Constant<value: tensor = int64[1] {9999}>()
                axes = Constant<value: tensor = int64[1] {0}>()
                steps = Constant<value: tensor = int64[1] {1}>()
                output = Slice (data, starts, ends, axes, steps)
            }
            """
        )
        model = ir.serde.deserialize_model(model_proto)
        count = collapse_slices.rules.apply_to_model(model)
        self.assertEqual(count, 1)
        self.assertEqual(len(model.graph), 5)
        self.assertIn("Identity", [node.op_type for node in model.graph])
        testing.assert_numerically_equal(
            model_proto,
            model,
            (np.random.rand(512, 16, 112).astype(np.float32),),
        )

    def test_slice_is_redundant_when_ends_reaches_int64_max(self):
        model_proto = onnx.parser.parse_model(
            f"""
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[512, 16, 112] data) => (float[512, 16, 112] output)
            {{
                starts = Constant<value: tensor = int64[1] {{0}}>()
                ends = Constant<value: tensor = int64[1] {{{_INT64_MAX}}}>()
                axes = Constant<value: tensor = int64[1] {{0}}>()
                steps = Constant<value: tensor = int64[1] {{1}}>()
                output = Slice (data, starts, ends, axes, steps)
            }}
            """
        )
        model = ir.serde.deserialize_model(model_proto)
        count = collapse_slices.rules.apply_to_model(model)
        self.assertEqual(count, 1)
        self.assertEqual(len(model.graph), 5)
        self.assertIn("Identity", [node.op_type for node in model.graph])
        testing.assert_numerically_equal(
            model_proto,
            model,
            (np.random.rand(512, 16, 112).astype(np.float32),),
        )

    def test_scatternd_is_redundant_when_it_is_updating_the_whole_input_in_order(self):
        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[112, 16, 512] data, float[112, 16, 512] updates) => (float[112, 16, 512] output)
            {
                output = ScatterND (data, indices, updates)
            }
            """
        )
        # Use inserted initializers to avoid manually coding the large constants
        indices = np.arange(112).reshape(112, 1)
        model = ir.serde.deserialize_model(model_proto)
        # from numpy to ir.Tensor
        indices_ir_tensor = ir.Tensor(
            name="indices",
            value=indices,
        )
        # assign the tensor to a value
        indices = model.graph[0].inputs[1]
        indices.const_value = indices_ir_tensor
        model.graph.initializers["indices"] = indices
        original_model_proto = ir.serde.serialize_model(model)

        count = collapse_slices.rules.apply_to_model(model)
        self.assertEqual(count, 1)
        self.assertEqual(len(model.graph), 1)
        self.assertIn("Identity", [node.op_type for node in model.graph])

        input = np.random.rand(112, 16, 512).astype(np.float32)
        testing.assert_numerically_equal(original_model_proto, model, (input, input))
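Beyond the unit tests, a hedged sketch of how the rule set might be applied to an exported model end to end; the file name and input shape are hypothetical, and `assert_numerically_equal` also accepts an `ir.Model` directly, as in the tests above:

```python
import numpy as np
import onnx

from onnxscript import ir
from onnxscript.rewriter import collapse_slices, testing

# Hypothetical path to a model produced by the ONNX exporter.
original_proto = onnx.load("ghostnet_100.onnx")
model = ir.serde.deserialize_model(original_proto)
count = collapse_slices.rules.apply_to_model(model)
print(f"Collapsed {count} redundant Slice/ScatterND nodes")

rewritten_proto = ir.serde.serialize_model(model)
# Hypothetical input shape; use the real model's input shape in practice.
testing.assert_numerically_equal(
    original_proto,
    rewritten_proto,
    (np.random.rand(1, 3, 224, 224).astype(np.float32),),
)
onnx.save(rewritten_proto, "ghostnet_100.rewritten.onnx")
```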
@@ -0,0 +1,76 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import Any

import numpy as np
import onnx
import onnxruntime as ort

from onnxscript import ir


def assert_numerically_equal(
    original_model_proto: onnx.ModelProto | ir.Model,
    rewritten_model_proto: onnx.ModelProto | ir.Model,
    args: tuple[Any, ...],
    rtol: float = 1,
    atol: float = 1e-3,
):
    """Assert that the two models are numerically equal.

    Args:
        original_model_proto: The original model proto or ir.Model.
        rewritten_model_proto: The model proto or ir.Model rewritten by the rules.
        args: The positional arguments to pass to the model.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
    """

    if isinstance(original_model_proto, ir.Model):
        original_model_proto = ir.serde.serialize_model(original_model_proto)
    if isinstance(rewritten_model_proto, ir.Model):
        rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto)

    original_proto_ort_inputs = {
        k.name: v for k, v in zip(original_model_proto.graph.input, args)
    }
    original_proto_ort_inference_session = _ort_session_initializer(
        original_model_proto.SerializeToString()
    )
    run_options = ort.RunOptions()
    run_options.log_severity_level = 3  # 3: Error
    original_outputs = original_proto_ort_inference_session.run(
        None, original_proto_ort_inputs, run_options=run_options
    )

    the_rewritten_proto_ort_inputs = {
        k.name: v for k, v in zip(rewritten_model_proto.graph.input, args)
    }
    the_rewritten_proto_ort_inference_session = _ort_session_initializer(
        rewritten_model_proto.SerializeToString()
    )
    the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run(
        None, the_rewritten_proto_ort_inputs, run_options=run_options
    )

    np.testing.assert_allclose(
        original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True
    )


def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession:
    """Initialize an ONNX Runtime inference session with the specified model."""
    session_options = ort.SessionOptions()
    session_options.log_severity_level = 3  # 3: Error
    possible_providers = (
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    )
    available_providers = set(ort.get_available_providers())
    providers = [
        provider for provider in possible_providers if provider in available_providers
    ]
    return ort.InferenceSession(model, providers=providers, sess_options=session_options)
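A note on the helper: `args` is zipped positionally with `graph.input`, so inputs must be passed in the order the graph declares them. A minimal self-contained sketch calling it directly; the graph here is a trivial placeholder, and comparing a model against itself trivially passes:

```python
import numpy as np
import onnx.parser

from onnxscript.rewriter import testing

model_proto = onnx.parser.parse_model(
    """
    <ir_version: 7, opset_import: [ "" : 17]>
    agraph (float[2, 3] x) => (float[2, 3] y)
    {
        y = Relu (x)
    }
    """
)
# In practice the second argument is the rewritten model.
testing.assert_numerically_equal(
    model_proto, model_proto, (np.random.rand(2, 3).astype(np.float32),)
)
```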