refactor
centwang committed Mar 10, 2024
1 parent e94effd commit 4b19595
Showing 1 changed file with 43 additions and 49 deletions:

onnxruntime/core/optimizer/constant_folding.cc
@@ -28,18 +28,13 @@ ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider,
       execution_provider_(execution_provider) {
 }

-static bool GetShapeInfo(Node& node, int64_t& start, int64_t& end, InlinedVector<int64_t>& dim_values) {
+static bool GetShapeValues(Node& node, InlinedVector<int64_t>& dim_values) {
   auto shape = node.InputDefs()[0]->Shape();
   if (!shape) return false;
-  dim_values.clear();
-  for (int dim_index = 0; dim_index < shape->dim_size(); dim_index++) {
-    auto dim = shape->dim(dim_index);
-    dim_values.emplace_back(utils::HasDimValue(dim) ? dim.dim_value() : -1);
-  }
-
-  int64_t rank = static_cast<int64_t>(dim_values.size());
-  start = 0;
-  end = std::numeric_limits<int64_t>::max();
+  int64_t rank = static_cast<int64_t>(shape->dim_size());
+  int64_t start = 0;
+  int64_t end = std::numeric_limits<int64_t>::max();
   const auto& shape_attributes = node.GetAttributes();
   for (const auto& attr : shape_attributes) {
     if (attr.first == "start") {
@@ -53,31 +48,41 @@ static bool GetShapeInfo(Node& node, int64_t& start, int64_t& end, InlinedVector
   end = end < 0 ? end + rank : end;
   end = end < 0 ? 0 : ((end > rank) ? rank : end);
   if (end < start) end = start;
+
+  dim_values.clear();
+  for (; start < end; ++start) {
+    auto dim = shape->dim(static_cast<int>(start));
+    dim_values.emplace_back(utils::HasDimValue(dim) ? dim.dim_value() : -1);
+  }

   return true;
 }

+static void AddIntInitializerToGraph(Graph& graph, const InlinedVector<int64_t>& dims,
+                                     const InlinedVector<int64_t>& values, NodeArg* output_arg) {
+  ONNX_NAMESPACE::TensorProto const_tensor_proto;
+  const_tensor_proto.set_name(output_arg->Name());
+  const_tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+  const_tensor_proto.set_raw_data(values.data(), values.size() * sizeof(int64_t));
+  ONNX_NAMESPACE::TensorShapeProto output_shape;
+  for (auto dim : dims) {
+    const_tensor_proto.add_dims(dim);
+    output_shape.add_dim()->set_dim_value(dim);
+  }
+  output_arg->SetShape(output_shape);
+  graph.AddInitializedTensor(const_tensor_proto);
+}
+
 // We need to handle a Shape node separately as the input doesn't need to be a constant initializer for
 // Shape to be able to be constant folded.
 static bool ConstantFoldShapeNode(Graph& graph, Node& node) {
-  int64_t start = 0;
-  int64_t end = std::numeric_limits<int64_t>::max();
   InlinedVector<int64_t> dim_values;
-  if (!GetShapeInfo(node, start, end, dim_values) ||
-      std::any_of(dim_values.cbegin() + start, dim_values.cbegin() + end, [](int64_t dim) { return dim == -1; })) {
+  if (!GetShapeValues(node, dim_values) ||
+      std::any_of(dim_values.cbegin(), dim_values.cend(), [](int64_t dim) { return dim == -1; })) {
     return false;
   }

-  size_t slice_length = static_cast<size_t>(end - start);
-  ONNX_NAMESPACE::TensorProto shape_constant;
-  auto* constant_arg_out = node.MutableOutputDefs()[0];
-  shape_constant.set_name(constant_arg_out->Name());
-  shape_constant.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-  shape_constant.add_dims(slice_length);
-  shape_constant.set_raw_data(dim_values.data() + start, slice_length * sizeof(int64_t));
-  ONNX_NAMESPACE::TensorShapeProto result_shape;
-  result_shape.add_dim()->set_dim_value(slice_length);
-  constant_arg_out->SetShape(result_shape);
-  graph.AddInitializedTensor(shape_constant);
+  AddIntInitializerToGraph(graph, InlinedVector<int64_t>{static_cast<int64_t>(dim_values.size())}, dim_values,
+                           node.MutableOutputDefs()[0]);
   return true;
 }

@@ -88,46 +93,35 @@ static bool ConstantFoldGatherNode(Graph& graph, Node& node) {
   if (node.GetInputEdgesCount() == 0) return false;
   Node& pre_node = *graph.GetNode(node.InputNodesBegin()->Index());
   if (pre_node.OpType() != "Shape" || pre_node.OutputDefs()[0]->Name() != node.InputDefs()[0]->Name()) return false;
-  int64_t start = 0;
-  int64_t end = std::numeric_limits<int64_t>::max();
   InlinedVector<int64_t> dim_values;
-  if (!GetShapeInfo(pre_node, start, end, dim_values)) return false;
+  if (!GetShapeValues(pre_node, dim_values)) return false;

   // Get indices input.
   const ONNX_NAMESPACE::TensorProto* tensor_proto =
       graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name());
   if (!tensor_proto || !utils::HasDataType(*tensor_proto) ||
-      tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto::INT64 || tensor_proto->dims_size() >= 2) {
+      tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto::INT64) {
     return false;
   }
   size_t output_element_count = 1;
-  if (tensor_proto->dims_size() > 0) output_element_count = static_cast<size_t>(tensor_proto->dims()[0]);
-  if (output_element_count == 0) return false;
+  InlinedVector<int64_t> output_dims;
+  for (int i = 0; i < tensor_proto->dims_size(); ++i) {
+    int64_t dim = tensor_proto->dims()[i];
+    output_dims.emplace_back(dim);
+    output_element_count *= static_cast<size_t>(dim);
+  }
   InlinedVector<int64_t> output_values;
   Initializer init_const{*tensor_proto, graph.ModelPath()};
   const int64_t* indices_data = init_const.data<int64_t>();
-  int64_t sliced_rank = end - start;
+  int64_t rank = static_cast<int64_t>(dim_values.size());
   for (size_t i = 0; i < output_element_count; ++i) {
     int64_t index = indices_data[i];
-    if (index < 0) index += sliced_rank;
-    if (index < 0 || index >= sliced_rank || dim_values[static_cast<size_t>(index + start)] == -1) return false;
-    output_values.push_back(dim_values[static_cast<size_t>(index + start)]);
+    if (index < 0) index += rank;
+    if (index < 0 || index >= rank || dim_values[static_cast<size_t>(index)] == -1) return false;
+    output_values.emplace_back(dim_values[static_cast<size_t>(index)]);
   }

-  ONNX_NAMESPACE::TensorProto gather_output_constant;
-  auto* gather_output_arg = node.MutableOutputDefs()[0];
-  gather_output_constant.set_name(gather_output_arg->Name());
-  gather_output_constant.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-  if (tensor_proto->dims_size() == 1) {
-    gather_output_constant.add_dims(static_cast<int64_t>(output_element_count));
-  }
-  gather_output_constant.set_raw_data(output_values.data(), output_element_count * sizeof(int64_t));
-  ONNX_NAMESPACE::TensorShapeProto gather_output_shape;
-  if (tensor_proto->dims_size() == 1) {
-    gather_output_shape.add_dim()->set_dim_value(static_cast<int64_t>(output_element_count));
-  }
-  gather_output_arg->SetShape(gather_output_shape);
-  graph.AddInitializedTensor(gather_output_constant);
+  AddIntInitializerToGraph(graph, output_dims, output_values, node.MutableOutputDefs()[0]);
   return true;
 }

(diff truncated; remaining changes in this file not shown)
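Note: the block below is a minimal standalone sketch of the slicing semantics implemented by the refactored GetShapeValues helper, for readers who want to run the logic outside the graph. It is not ONNX Runtime code: std::vector stands in for InlinedVector, the function name SlicedShapeValues and its explicit start/end parameters are illustrative, and -1 marks a symbolic dimension as in the diff above. The clamping of start is assumed to mirror the clamping of end visible in the hunk, matching the ONNX Shape operator's documented behavior.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Returns the dimension values of input_dims restricted to [start, end), after
// normalizing negative start/end and clamping both to [0, rank].
std::vector<int64_t> SlicedShapeValues(const std::vector<int64_t>& input_dims,
                                       int64_t start, int64_t end) {
  const int64_t rank = static_cast<int64_t>(input_dims.size());
  start = start < 0 ? start + rank : start;
  start = std::clamp<int64_t>(start, 0, rank);
  end = end < 0 ? end + rank : end;
  end = std::clamp<int64_t>(end, 0, rank);
  if (end < start) end = start;  // empty slice rather than an error
  std::vector<int64_t> dim_values;
  for (; start < end; ++start) {
    dim_values.push_back(input_dims[static_cast<size_t>(start)]);
  }
  return dim_values;
}

int main() {
  // Shape(start=1, end=-1) on a [2, 3, -1, 5] input keeps dims 1 and 2: prints "3 -1".
  for (int64_t d : SlicedShapeValues({2, 3, -1, 5}, 1, -1)) std::cout << d << ' ';
  std::cout << '\n';
  // Defaults (start 0, end int64 max) keep every dimension: prints "2 3 -1 5".
  for (int64_t d : SlicedShapeValues({2, 3, -1, 5}, 0, std::numeric_limits<int64_t>::max())) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}

Compiled as C++17, the two calls print "3 -1" and "2 3 -1 5", which is the data ConstantFoldShapeNode would bake into the int64 initializer via AddIntInitializerToGraph.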

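In the same spirit, a standalone sketch of the Gather-over-Shape folding performed by ConstantFoldGatherNode (again with assumed names and standard containers rather than the ONNX Runtime API): constant int64 indices are mapped onto the known dimension values, and the fold is abandoned whenever an index is out of range or selects a symbolic dimension.

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Returns the folded values, or nullopt if the Gather cannot be constant folded.
std::optional<std::vector<int64_t>> FoldGatherOnShape(const std::vector<int64_t>& dim_values,
                                                      const std::vector<int64_t>& indices) {
  const int64_t rank = static_cast<int64_t>(dim_values.size());
  std::vector<int64_t> output_values;
  output_values.reserve(indices.size());
  for (int64_t index : indices) {
    if (index < 0) index += rank;                         // negative indices wrap around
    if (index < 0 || index >= rank) return std::nullopt;  // out of range: give up
    const int64_t dim = dim_values[static_cast<size_t>(index)];
    if (dim == -1) return std::nullopt;                   // symbolic dimension: give up
    output_values.push_back(dim);
  }
  return output_values;
}

int main() {
  // Gather(indices={0, 2}) over the shape of an [8, -1, 768] tensor folds to {8, 768};
  // an index touching the symbolic dimension 1 would make the fold bail out.
  if (auto folded = FoldGatherOnShape({8, -1, 768}, {0, 2})) {
    for (int64_t v : *folded) std::cout << v << ' ';
    std::cout << '\n';
  }
  return 0;
}

The actual code also collects the indices tensor's dims into output_dims so that AddIntInitializerToGraph can give the folded initializer the same shape as the indices, which is what allows this commit to drop the old restriction rejecting indices tensors of rank 2 or higher.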