Refactor the Shape / Shape->Gather constant-folding helpers in constant_folding.cc.

Summary of what this patch does:
  * GetShapeInfo(node, start, end, dim_values) becomes GetShapeValues(node, dim_values).
    start/end are now locals, and dim_values receives only the dims inside the
    [start, end) slice (-1 for a dim without a concrete value).
  * New helper AddIntInitializerToGraph(graph, dims, values, output_arg) builds an INT64
    TensorProto named after output_arg, sets output_arg's shape from dims, and adds the
    tensor to the graph as an initializer.
  * ConstantFoldShapeNode and ConstantFoldGatherNode call the new helper instead of
    building the TensorProto inline.
  * ConstantFoldGatherNode no longer rejects indices with dims_size() >= 2 and no longer
    returns early when the element count is 0; the output dims are copied from the
    indices tensor's dims.

NOTE(review): blank context lines were re-inserted so that each hunk matches its @@ line
counts; confirm their exact positions against the upstream file before applying.

diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc
index 664401fc709a4..42d4cc2be6fcf 100644
--- a/onnxruntime/core/optimizer/constant_folding.cc
+++ b/onnxruntime/core/optimizer/constant_folding.cc
@@ -28,18 +28,13 @@ ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider,
       execution_provider_(execution_provider) {
 }
 
-static bool GetShapeInfo(Node& node, int64_t& start, int64_t& end, InlinedVector<int64_t>& dim_values) {
+static bool GetShapeValues(Node& node, InlinedVector<int64_t>& dim_values) {
   auto shape = node.InputDefs()[0]->Shape();
   if (!shape) return false;
 
-  dim_values.clear();
-  for (int dim_index = 0; dim_index < shape->dim_size(); dim_index++) {
-    auto dim = shape->dim(dim_index);
-    dim_values.emplace_back(utils::HasDimValue(dim) ? dim.dim_value() : -1);
-  }
-  int64_t rank = static_cast<int64_t>(dim_values.size());
-  start = 0;
-  end = std::numeric_limits<int64_t>::max();
+  int64_t rank = static_cast<int64_t>(shape->dim_size());
+  int64_t start = 0;
+  int64_t end = std::numeric_limits<int64_t>::max();
   const auto& shape_attributes = node.GetAttributes();
   for (const auto& attr : shape_attributes) {
     if (attr.first == "start") {
@@ -53,31 +48,41 @@ static bool GetShapeInfo(Node& node, int64_t& start, int64_t& end, InlinedVector
   end = end < 0 ? end + rank : end;
   end = end < 0 ? 0 : ((end > rank) ? rank : end);
   if (end < start) end = start;
+
+  dim_values.clear();
+  for (; start < end; ++start) {
+    auto dim = shape->dim(static_cast<int>(start));
+    dim_values.emplace_back(utils::HasDimValue(dim) ? dim.dim_value() : -1);
+  }
+
   return true;
 }
 
+static void AddIntInitializerToGraph(Graph& graph, const InlinedVector<int64_t>& dims,
+                                     const InlinedVector<int64_t>& values, NodeArg* output_arg) {
+  ONNX_NAMESPACE::TensorProto const_tensor_proto;
+  const_tensor_proto.set_name(output_arg->Name());
+  const_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+  const_tensor_proto.set_raw_data(values.data(), values.size() * sizeof(int64_t));
+  ONNX_NAMESPACE::TensorShapeProto output_shape;
+  for (auto dim : dims) {
+    const_tensor_proto.add_dims(dim);
+    output_shape.add_dim()->set_dim_value(dim);
+  }
+  output_arg->SetShape(output_shape);
+  graph.AddInitializedTensor(const_tensor_proto);
+}
+
 // We need to handle a Shape node separately as the input doesn't need to be a constant initializer for
 // Shape to be able to be constant folded.
 static bool ConstantFoldShapeNode(Graph& graph, Node& node) {
-  int64_t start = 0;
-  int64_t end = std::numeric_limits<int64_t>::max();
   InlinedVector<int64_t> dim_values;
-  if (!GetShapeInfo(node, start, end, dim_values) ||
-      std::any_of(dim_values.cbegin() + start, dim_values.cbegin() + end, [](int64_t dim) { return dim == -1; })) {
+  if (!GetShapeValues(node, dim_values) ||
+      std::any_of(dim_values.cbegin(), dim_values.cend(), [](int64_t dim) { return dim == -1; })) {
     return false;
   }
-
-  size_t slice_length = static_cast<size_t>(end - start);
-  ONNX_NAMESPACE::TensorProto shape_constant;
-  auto* constant_arg_out = node.MutableOutputDefs()[0];
-  shape_constant.set_name(constant_arg_out->Name());
-  shape_constant.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-  shape_constant.add_dims(slice_length);
-  shape_constant.set_raw_data(dim_values.data() + start, slice_length * sizeof(int64_t));
-  ONNX_NAMESPACE::TensorShapeProto result_shape;
-  result_shape.add_dim()->set_dim_value(slice_length);
-  constant_arg_out->SetShape(result_shape);
-  graph.AddInitializedTensor(shape_constant);
+  AddIntInitializerToGraph(graph, InlinedVector<int64_t>{static_cast<int64_t>(dim_values.size())}, dim_values,
+                           node.MutableOutputDefs()[0]);
   return true;
 }
 
@@ -88,46 +93,35 @@ static bool ConstantFoldGatherNode(Graph& graph, Node& node) {
   if (node.GetInputEdgesCount() == 0) return false;
   Node& pre_node = *graph.GetNode(node.InputNodesBegin()->Index());
   if (pre_node.OpType() != "Shape" || pre_node.OutputDefs()[0]->Name() != node.InputDefs()[0]->Name()) return false;
-  int64_t start = 0;
-  int64_t end = std::numeric_limits<int64_t>::max();
   InlinedVector<int64_t> dim_values;
-  if (!GetShapeInfo(pre_node, start, end, dim_values)) return false;
+  if (!GetShapeValues(pre_node, dim_values)) return false;
 
   // Get indices input.
   const ONNX_NAMESPACE::TensorProto* tensor_proto =
       graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name());
   if (!tensor_proto || !utils::HasDataType(*tensor_proto) ||
-      tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto::INT64 || tensor_proto->dims_size() >= 2) {
+      tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto::INT64) {
     return false;
   }
   size_t output_element_count = 1;
-  if (tensor_proto->dims_size() > 0) output_element_count = static_cast<size_t>(tensor_proto->dims()[0]);
-  if (output_element_count == 0) return false;
+  InlinedVector<int64_t> output_dims;
+  for (int i = 0; i < tensor_proto->dims_size(); ++i) {
+    int64_t dim = tensor_proto->dims()[i];
+    output_dims.emplace_back(dim);
+    output_element_count *= static_cast<size_t>(dim);
+  }
   InlinedVector<int64_t> output_values;
   Initializer init_const{*tensor_proto, graph.ModelPath()};
   const int64_t* indices_data = init_const.data<int64_t>();
-  int64_t sliced_rank = end - start;
+  int64_t rank = static_cast<int64_t>(dim_values.size());
   for (size_t i = 0; i < output_element_count; ++i) {
     int64_t index = indices_data[i];
-    if (index < 0) index += sliced_rank;
-    if (index < 0 || index >= sliced_rank || dim_values[static_cast<size_t>(index + start)] == -1) return false;
-    output_values.push_back(dim_values[static_cast<size_t>(index + start)]);
+    if (index < 0) index += rank;
+    if (index < 0 || index >= rank || dim_values[static_cast<size_t>(index)] == -1) return false;
+    output_values.emplace_back(dim_values[static_cast<size_t>(index)]);
   }
 
-  ONNX_NAMESPACE::TensorProto gather_output_constant;
-  auto* gather_output_arg = node.MutableOutputDefs()[0];
-  gather_output_constant.set_name(gather_output_arg->Name());
-  gather_output_constant.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-  if (tensor_proto->dims_size() == 1) {
-    gather_output_constant.add_dims(static_cast<int64_t>(output_element_count));
-  }
-  gather_output_constant.set_raw_data(output_values.data(), output_element_count * sizeof(int64_t));
-  ONNX_NAMESPACE::TensorShapeProto gather_output_shape;
-  if (tensor_proto->dims_size() == 1) {
-    gather_output_shape.add_dim()->set_dim_value(static_cast<int64_t>(output_element_count));
-  }
-  gather_output_arg->SetShape(gather_output_shape);
-  graph.AddInitializedTensor(gather_output_constant);
+  AddIntInitializerToGraph(graph, output_dims, output_values, node.MutableOutputDefs()[0]);
   return true;
 }
 