Skip to content

Commit

Permalink
update test and format
Browse files Browse the repository at this point in the history
  • Loading branch information
ruiren committed Feb 7, 2024
1 parent abf2994 commit 7870789
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
3 changes: 1 addition & 2 deletions onnxruntime/core/optimizer/gather_slice_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,8 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
for (int64_t i = 0; i < rank; i++) {
if (i == split_axis)
split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL);
else {
else
*(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast<int>(i));
}
}

InlinedVector<NodeArg*> split_output_types;
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7718,7 +7718,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) {
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
};
}
}

TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) {
Expand Down Expand Up @@ -7763,7 +7763,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) {
};

auto pre_graph_checker = [&](Graph& graph) {
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1);
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1);
return Status::OK();
};
Expand All @@ -7778,7 +7778,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) {
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherSliceToSplitFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer),
TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker));
};
}
}

} // namespace test
Expand Down

0 comments on commit 7870789

Please sign in to comment.