Skip to content

Commit

Permalink
[rewriter] Fix slices pattern (#1949)
Browse files Browse the repository at this point in the history
The pattern did not cover the dynamic shapes case, so it leads to "'<'
not supported between instances of 'int' and 'SymbolicDim'" when input
is dynamic.

---------

Co-authored-by: Justin Chu <[email protected]>
  • Loading branch information
titaiwangms and justinchuby authored Nov 16, 2024
1 parent 88dca66 commit bd4233b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion onnxscript/rewriter/collapse_slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _check_if_redundant_slice(
# In case data.shape is not statically known, we still can tell the slice is redundant if ends is sys.maxsize
if ends_const.numpy().item() == _INT64_MAX:
return True
if data.shape is None:
if data.shape is None or data.shape.is_dynamic(axes_const.numpy().item()):
logger.info("The value 'data' shape is not statically known.")
return False
if ends_const.numpy().item() < data.shape[axes_const.numpy().item()]:
Expand Down
18 changes: 18 additions & 0 deletions onnxscript/rewriter/collapse_slices_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ def test_slice_is_redundant_when_ends_reaches_int64_max(self):
(np.random.rand(512, 16, 112).astype(np.float32),),
)

def test_slice_pattern_is_not_matched_when_input_is_dynamic(self):
model_proto = onnx.parser.parse_model(
f"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[L, M, N] data) => (float[L, M, N] output)
{{
starts = Constant<value: tensor = int64[1] {{0}}>()
ends = Constant<value: tensor = int64[1] {{{9}}}>()
axes = Constant<value: tensor = int64[1] {{0}}>()
steps = Constant<value: tensor = int64[1] {{1}}>()
output = Slice (data, starts, ends, axes, steps)
}}
"""
)
model = ir.serde.deserialize_model(model_proto)
count = collapse_slices.rules.apply_to_model(model)
self.assertEqual(count, 0)

def test_scatternd_is_redundant_when_it_is_updating_the_whole_input_in_order(self):
model_proto = onnx.parser.parse_model(
"""
Expand Down

0 comments on commit bd4233b

Please sign in to comment.