diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 8a911071864aa..040bf5ae76fff 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -495,6 +495,28 @@ def _onnx_infer_single_node(self, node): if (name in self.initializers_ and name not in self.graph_inputs_) ] + if node.op_type in [ + "Add", + "Sub", + "Mul", + "Div", + "MatMul", + "MatMulInteger", + "MatMulInteger16", + "Where", + "Sum", + ]: + if node.output[0] in self.known_vi_: + vi = self.known_vi_[node.output[0]] + out_rank = len(get_shape_from_type_proto(vi.type)) + in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] + for d in range( + out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0) + ): + in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] + if len(in_dims) > 1: + self._check_merged_dims(in_dims, allow_broadcast=True) + # run single node inference with self.known_vi_ shapes tmp_graph = helper.make_graph( [node],