[rewriter] Remove redundant op.Slice and op.ScatterND (#1925)
Fixes microsoft/onnx-converters-private#270

It is observed that ExportedProgram can generate aten::slice.Tensor
and aten::slice_scatter.default ops that slice nothing:

![Screenshot 2024-10-29 112157](https://github.com/user-attachments/assets/6274e71c-f5a8-4fdc-b885-ff1365b4c245)

These slices turn into redundant op.Slice ops in the ONNX graph that do
nothing, and into op.ScatterND ops that simply replace the whole input
with the updates, which wastes significant inference time.

This rule set recognizes a redundant Slice by checking that the
following requirements are met (see the numpy sketch below):
(1) starts == 0
(2) ends >= data.shape[dim], or ends == _INT64_MAX
(3) steps == 1
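
For intuition, a minimal numpy sketch of the same semantics (an illustration only, not the rewriter code): slicing clamps `ends` to the axis size, so `starts=0, ends>=dim, steps=1` selects the whole axis.

```python
import numpy as np

_INT64_MAX = 9223372036854775807

data = np.random.rand(512, 16, 112).astype(np.float32)
# starts=0, ends=_INT64_MAX (clamped to the axis size), steps=1, axis 0
sliced = data[0:_INT64_MAX:1]
assert np.array_equal(sliced, data)  # the slice selects everything, so it is a no-op
```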

This rule set recognizes a redundant ScatterND by checking that the
following requirements are met (see the sketch after this list):
(1) indices has the same length as the first dim of data
(2) indices enumerates 0 through data.shape[0] - 1 in order
(3) data has the same shape as updates
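
A matching numpy sketch for the ScatterND case (again an illustration, using the reference scatter semantics): when `indices` enumerates every position along the first dim and `updates` matches the shape of the input, the output is exactly `updates`.

```python
import numpy as np

data = np.zeros((4, 2), dtype=np.float32)
updates = np.ones((4, 2), dtype=np.float32)
indices = np.arange(4).reshape(4, 1)  # [[0], [1], [2], [3]]

# Reference ScatterND semantics for single-element index tuples
output = data.copy()
for i, idx in enumerate(indices):
    output[tuple(idx)] = updates[i]

assert np.array_equal(output, updates)  # data is fully replaced by updates
```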


Benchmark on ghostnet_100 (the original speedup was 0.0256):

| Suite | Model Name   | Speedup (onnx_dynamo) | Increase (onnx_dynamo) | Med (onnx_dynamo) |
|-------|--------------|-----------------------|------------------------|-------------------|
| Timm  | ghostnet_100 | 1.1599                | 15.986%                | 1.1580            |
titaiwangms authored Nov 1, 2024
1 parent c13e4fd commit 1ceb85b
Showing 5 changed files with 317 additions and 1 deletion.
2 changes: 2 additions & 0 deletions onnxscript/optimizer/_optimizer.py
@@ -10,6 +10,7 @@
from onnxscript.rewriter import (
    broadcast_to_matmul,
    cast_constant_of_shape,
    collapse_slices,
    gemm_to_matmul_add,
    no_op,
)
@@ -21,6 +22,7 @@
    *broadcast_to_matmul.rules.rules,
    gemm_to_matmul_add.rule,
    *cast_constant_of_shape.rules.rules,
    *collapse_slices.rules.rules,
]


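With the registration above, the new rules run as part of the default optimizer pass. A hedged usage sketch (assuming the top-level `optimizer.optimize` entry point; the model path is hypothetical):

```python
import onnx

from onnxscript import optimizer

model = onnx.load("model.onnx")  # hypothetical path
optimized = optimizer.optimize(model)  # runs the default rewrite rules, now including collapse_slices
onnx.save(optimized, "model_optimized.onnx")
```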
140 changes: 140 additions & 0 deletions onnxscript/rewriter/collapse_slices.py
@@ -0,0 +1,140 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import logging

from onnxscript import ir
from onnxscript.rewriter import pattern

logger = logging.getLogger(__name__)
_INT64_MAX = 9223372036854775807


def _check_if_redundant_slice(
    context,
    data: ir.Value,
    starts: ir.Value,
    ends: ir.Value,
    axes: ir.Value,
    steps: ir.Value,
    **_,
) -> bool:
"""If the starts is 0, and the ends is equal to or grater than the shape of the specified axis, then the slice is redundant."""
del context # Reserved for future extensions

starts_const = starts.const_value
ends_const = ends.const_value
axes_const = axes.const_value
steps_const = steps.const_value

# Check if the values are scalar
if starts_const.numpy().size != 1: # type: ignore[union-attr]
logger.info("The value 'start' is not a scalar.")
return False
if ends_const.numpy().size != 1: # type: ignore[union-attr]
logger.info("The value 'end' is not a scalar.")
return False
if axes_const.numpy().size != 1: # type: ignore[union-attr]
logger.info("The value 'axis' is not a scalar.")
return False
if steps_const.numpy().size != 1: # type: ignore[union-attr]
logger.info("The value 'step' is not a scalar.")
return False

if starts_const is None or ends_const is None or axes_const is None or steps_const is None:
logger.info("The value 'start', 'end', 'axis', 'step' is not statically known.")
return False
if steps_const.numpy().item() != 1:
logger.info("The value 'step' is not 1.")
return False
# starts is 0
if starts_const.numpy().item() != 0:
logger.info("The value 'start' is not 0.")
return False
# In case data.shape is not statically known, we still can tell the slice is redundant if ends is sys.maxsize
if ends_const.numpy().item() == _INT64_MAX:
return True
if data.shape is None:
logger.info("The value 'data' shape is not statically known.")
return False
if ends_const.numpy().item() < data.shape[axes_const.numpy().item()]:
logger.info("The value 'end' is less than the shape of the specified axis.")
return False

return True


def _identity_to_itself(op, data, **_):
    """Return the input data as the output."""
    return op.Identity(data)


def _identity_to_updates(op, data, indices, updates, **_):
    """Return the updates as the output.

    This is used when the ScatterND is redundant in terms of
    updating the whole data with the updates.
    """
    return op.Identity(updates)


def _potential_redundant_slice(op, data, starts, ends, axes, steps):
    """Pattern to identify a Slice op."""
    return op.Slice(data, starts, ends, axes, steps)


def _potential_redundant_scatternd(op, data, indices, updates):
    """Pattern to identify a ScatterND op."""
    return op.ScatterND(data, indices, updates)


def _check_if_redundant_scatternd(
    context,
    data: ir.Value,
    indices: ir.Value,
    updates: ir.Value,
    **_,
) -> bool:
    """The ScatterND is redundant if indices covers the whole first dim of data in order and updates has the same shape as data, so the output is simply updates."""
    del context  # Reserved for future extensions

    # To validate that data can be replaced directly by updates, we need to check the following:
    # 1. They have the same shape
    if data.shape is None:
        logger.info("The value 'data' shape is not statically known.")
        return False
    if updates.shape is None:
        logger.info("The value 'updates' shape is not statically known.")
        return False
    if data.shape != updates.shape:
        logger.info("The shapes of 'data' and 'updates' are different.")
        return False

    # 2. The indices cover the whole data, enumerating 0 to data.shape[0] - 1 in order
    if indices.const_value is None:
        logger.info("The value 'indices' is not statically known.")
        return False
    if indices.const_value.numpy().tolist() != [[i] for i in range(data.shape[0])]:  # type: ignore[arg-type]
        logger.info("The 'indices' do not cover the whole data.")
        return False

    return True


# Register the rewrite rules
remove_redundant_slice = pattern.RewriteRule(
    _potential_redundant_slice,
    _identity_to_itself,
    _check_if_redundant_slice,
)

remove_redundant_scatternd = pattern.RewriteRule(
    _potential_redundant_scatternd,
    _identity_to_updates,
    _check_if_redundant_scatternd,
)

# NOTE: The order of the rules matters: larger patterns should be checked first.
rules = pattern.RewriteRuleSet([remove_redundant_slice, remove_redundant_scatternd])
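
The rule set can also be applied standalone to an `ir.Model`, mirroring how the tests below drive it (a sketch; `model_proto` is assumed to be an already-loaded onnx.ModelProto):

```python
from onnxscript import ir
from onnxscript.rewriter import collapse_slices

model = ir.serde.deserialize_model(model_proto)  # model_proto: an onnx.ModelProto
count = collapse_slices.rules.apply_to_model(model)  # returns the number of rewrites applied
rewritten_proto = ir.serde.serialize_model(model)
```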
98 changes: 98 additions & 0 deletions onnxscript/rewriter/collapse_slices_test.py
@@ -0,0 +1,98 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np
import onnx.parser
import onnx.shape_inference

from onnxscript import ir
from onnxscript.rewriter import collapse_slices, testing

_INT64_MAX = 9223372036854775807


class CollapseSlicesTest(unittest.TestCase):
    def test_slice_is_redundant_when_ends_is_greater_than_input_shape(self):
        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[512, 16, 112] data) => (float[512, 16, 112] output)
            {
                starts = Constant<value: tensor = int64[1] {0}>()
                ends = Constant<value: tensor = int64[1] {9999}>()
                axes = Constant<value: tensor = int64[1] {0}>()
                steps = Constant<value: tensor = int64[1] {1}>()
                output = Slice (data, starts, ends, axes, steps)
            }
            """
        )
        model = ir.serde.deserialize_model(model_proto)
        count = collapse_slices.rules.apply_to_model(model)
        self.assertEqual(count, 1)
        self.assertEqual(len(model.graph), 5)
        self.assertIn("Identity", [node.op_type for node in model.graph])
        testing.assert_numerically_equal(
            model_proto,
            model,
            (np.random.rand(512, 16, 112).astype(np.float32),),
        )

    def test_slice_is_redundant_when_ends_reaches_int64_max(self):
        model_proto = onnx.parser.parse_model(
            f"""
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[512, 16, 112] data) => (float[512, 16, 112] output)
            {{
                starts = Constant<value: tensor = int64[1] {{0}}>()
                ends = Constant<value: tensor = int64[1] {{{_INT64_MAX}}}>()
                axes = Constant<value: tensor = int64[1] {{0}}>()
                steps = Constant<value: tensor = int64[1] {{1}}>()
                output = Slice (data, starts, ends, axes, steps)
            }}
            """
        )
        model = ir.serde.deserialize_model(model_proto)
        count = collapse_slices.rules.apply_to_model(model)
        self.assertEqual(count, 1)
        self.assertEqual(len(model.graph), 5)
        self.assertIn("Identity", [node.op_type for node in model.graph])
        testing.assert_numerically_equal(
            model_proto,
            model,
            (np.random.rand(512, 16, 112).astype(np.float32),),
        )

    def test_scatternd_is_redundant_when_it_is_updating_the_whole_input_in_order(self):
        model_proto = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[112, 16, 512] data, float[112, 16, 512] updates) => (float[112, 16, 512] output)
            {
                output = ScatterND (data, indices, updates)
            }
            """
        )
        # Use inserted initializers to avoid manually coding the large constants
        indices = np.arange(112).reshape(112, 1)
        model = ir.serde.deserialize_model(model_proto)
        # Convert from numpy to ir.Tensor
        indices_ir_tensor = ir.Tensor(
            name="indices",
            value=indices,
        )
        # Assign the tensor to a value
        indices = model.graph[0].inputs[1]
        indices.const_value = indices_ir_tensor
        model.graph.initializers["indices"] = indices
        original_model_proto = ir.serde.serialize_model(model)

        count = collapse_slices.rules.apply_to_model(model)
        self.assertEqual(count, 1)
        self.assertEqual(len(model.graph), 1)
        self.assertIn("Identity", [node.op_type for node in model.graph])

        input = np.random.rand(112, 16, 512).astype(np.float32)
        testing.assert_numerically_equal(original_model_proto, model, (input, input))
2 changes: 1 addition & 1 deletion onnxscript/rewriter/no_op.py
@@ -32,7 +32,7 @@ def dropout_inference(op, x):


 # Replacement
-def identity(op, x):
+def identity(op, x, **_):
     return op.Identity(x)


76 changes: 76 additions & 0 deletions onnxscript/rewriter/testing.py
@@ -0,0 +1,76 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import Any

import numpy as np
import onnx
import onnxruntime as ort

from onnxscript import ir

def assert_numerically_equal(
    original_model_proto: onnx.ModelProto | ir.Model,
    rewritten_model_proto: onnx.ModelProto | ir.Model,
    args: tuple[Any, ...],
    rtol: float = 1,
    atol: float = 1e-3,
):
    """Assert that the two models are numerically equal.

    Args:
        original_model_proto: The original model proto or ir.Model.
        rewritten_model_proto: The model proto or ir.Model rewritten by the rules.
        args: The positional arguments to pass to the models.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
    """

    if isinstance(original_model_proto, ir.Model):
        original_model_proto = ir.serde.serialize_model(original_model_proto)
    if isinstance(rewritten_model_proto, ir.Model):
        rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto)

    original_proto_ort_inputs = {
        k.name: v for k, v in zip(original_model_proto.graph.input, args)
    }
    original_proto_ort_inference_session = _ort_session_initializer(
        original_model_proto.SerializeToString()
    )
    run_options = ort.RunOptions()
    run_options.log_severity_level = 3  # 3: Error
    original_outputs = original_proto_ort_inference_session.run(
        None, original_proto_ort_inputs, run_options=run_options
    )

    the_rewritten_proto_ort_inputs = {
        k.name: v for k, v in zip(rewritten_model_proto.graph.input, args)
    }
    the_rewritten_proto_ort_inference_session = _ort_session_initializer(
        rewritten_model_proto.SerializeToString()
    )
    the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run(
        None, the_rewritten_proto_ort_inputs, run_options=run_options
    )

    np.testing.assert_allclose(
        original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True
    )


def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession:
    """Initialize an ONNX Runtime inference session with the specified model."""
    session_options = ort.SessionOptions()
    session_options.log_severity_level = 3  # 3: Error
    possible_providers = (
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    )
    available_providers = set(ort.get_available_providers())
    providers = [
        provider for provider in possible_providers if provider in available_providers
    ]
    return ort.InferenceSession(model, providers=providers, sess_options=session_options)
