Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
ruiren committed Jan 22, 2024
1 parent 9d0b66e commit 2edb4b3
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions onnxruntime/core/optimizer/gather_slice_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ namespace onnxruntime {

bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index,
int64_t& axis, int64_t& indices_n_dims) const {

if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
return false;
Expand Down Expand Up @@ -60,7 +59,7 @@ bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node&
int onnx_opset_version = -1;
if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
};
}

// If Slice op of opset version 1
if (onnx_opset_version == 1) {
Expand All @@ -73,7 +72,7 @@ bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node&
if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) {
return false;
}
};
}

// If Slice op of opset version >= 10
if (onnx_opset_version >= 10) {
Expand Down Expand Up @@ -128,7 +127,6 @@ bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node&
return false;
}

// TODO: what does this mean ?
if (axes_init->dims_size() != 1 || static_cast<size_t>(axes_init->dims().Get(0)) != starts.size()) {
return false;
}
Expand Down Expand Up @@ -210,7 +208,6 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
int64_t split_axis = 0;
int64_t indices_n_dims = -1;

// TODO: How to catch up the Slice output value
// 2 Gather, and 1 slice...
InlinedVector<NodeArg*> reshape_outputs;

Expand Down Expand Up @@ -298,9 +295,10 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
if (i == split_axis)
split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL);
else {

Check warning on line 297 in onnxruntime/core/optimizer/gather_slice_fusion.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 If an else has a brace on one side, it should have it on both [readability/braces] [5] Raw Output: onnxruntime/core/optimizer/gather_slice_fusion.cc:297: If an else has a brace on one side, it should have it on both [readability/braces] [5]
*(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast<int>(i));
*(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim())
= shape->dim(static_cast<int>(i));
}
};
}

InlinedVector<NodeArg*> split_outputs;

Expand All @@ -319,8 +317,9 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
split_initializer_proto.add_dims(static_cast<int64_t>(1));
split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);

auto dim_value = shape->dim(static_cast<int>(split_axis))->dim_value();
InlinedVector<int64_t> split_value{{dim_value - gather_node_count, 1, 1}};
auto dim_value = shape->dim(static_cast<int>(split_axis)).dim_value();
int64_t slice_dim = static_cast<int64_t>(dim_value - gather_node_count);
InlinedVector<int64_t> split_value{{slice_dim, 1, 1}};
split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t));
NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto);

Expand Down

0 comments on commit 2edb4b3

Please sign in to comment.