diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 4b56bc1e8d828..4b029f9b172b0 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -1940,8 +1940,17 @@ def _infer_SoftmaxCrossEntropyLoss(self, node): # noqa: N802 def _infer_Split_Common(self, node, make_value_info_func): # noqa: N802 input_sympy_shape = self._get_sympy_shape(node, 0) axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape)) - split = get_attribute(node, "split") - if not split: + op_set = get_opset(self.out_mp_) + + # Depending on op-version 'split' are provided as attribute or via 2nd input + if op_set < 13: + split = get_attribute(node, "split") + assert self._try_get_value(node, 1) is None + else: + split = self._try_get_value(node, 1) + assert get_attribute(node, "split") is None + + if split is None: num_outputs = len(node.output) split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs self._update_computed_dims(split)