From f24d5fb3ee29cc9c7b2d622531a9769eb0d4d4a0 Mon Sep 17 00:00:00 2001 From: inisis Date: Mon, 11 Mar 2024 17:32:16 +0800 Subject: [PATCH 1/2] fix shape inference bug --- .../python/tools/symbolic_shape_infer.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 4b029f9b172b0..ca0802635e963 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -498,6 +498,26 @@ def _onnx_infer_single_node(self, node): if (name in self.initializers_ and name not in self.graph_inputs_) ] + if node.op_type in [ + "Add", + "Sub", + "Mul", + "Div", + "MatMul", + "MatMulInteger", + "MatMulInteger16", + "Where", + "Sum", + ]: + if node.output[0] in self.known_vi_: + vi = self.known_vi_[node.output[0]] + out_rank = len(get_shape_from_type_proto(vi.type)) + in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] + for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)): + in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] + if len(in_dims) > 1: + self._check_merged_dims(in_dims, allow_broadcast=True) + # run single node inference with self.known_vi_ shapes tmp_graph = helper.make_graph( [node], From f19ce16cd3fd1061341b8f0a9ce2c70c6fc07e87 Mon Sep 17 00:00:00 2001 From: inisis Date: Tue, 12 Mar 2024 10:09:41 +0800 Subject: [PATCH 2/2] fix lint --- onnxruntime/python/tools/symbolic_shape_infer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index ca0802635e963..de2583bbc300f 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -513,7 +513,9 @@ def _onnx_infer_single_node(self, node): vi = self.known_vi_[node.output[0]] out_rank = len(get_shape_from_type_proto(vi.type)) in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] - for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)): + for d in range( + out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0) + ): in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] if len(in_dims) > 1: self._check_merged_dims(in_dims, allow_broadcast=True)