diff --git a/onnxscript/rewriter/collapse_slices.py b/onnxscript/rewriter/collapse_slices.py index 2615432e7..689557af1 100644 --- a/onnxscript/rewriter/collapse_slices.py +++ b/onnxscript/rewriter/collapse_slices.py @@ -56,7 +56,7 @@ def _check_if_redundant_slice( # In case data.shape is not statically known, we still can tell the slice is redundant if ends is sys.maxsize if ends_const.numpy().item() == _INT64_MAX: return True - if data.shape is None: + if data.shape is None or data.shape.is_dynamic(axes_const.numpy().item()): logger.info("The value 'data' shape is not statically known.") return False if ends_const.numpy().item() < data.shape[axes_const.numpy().item()]: diff --git a/onnxscript/rewriter/collapse_slices_test.py b/onnxscript/rewriter/collapse_slices_test.py index 8632f61ca..6a11bd202 100644 --- a/onnxscript/rewriter/collapse_slices_test.py +++ b/onnxscript/rewriter/collapse_slices_test.py @@ -65,6 +65,24 @@ def test_slice_is_redundant_when_ends_reaches_int64_max(self): (np.random.rand(512, 16, 112).astype(np.float32),), ) + def test_slice_pattern_is_not_matched_when_input_is_dynamic(self): + model_proto = onnx.parser.parse_model( + f""" + + agraph (float[L, M, N] data) => (float[L, M, N] output) + {{ + starts = Constant() + ends = Constant() + axes = Constant() + steps = Constant() + output = Slice (data, starts, ends, axes, steps) + }} + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = collapse_slices.rules.apply_to_model(model) + self.assertEqual(count, 0) + def test_scatternd_is_redundant_when_it_is_updating_the_whole_input_in_order(self): model_proto = onnx.parser.parse_model( """