Zhijxu/fix conv1d replacement (#19758)
Remove the constraint "group number should be less than 3";
add more conditions to ensure the conv1d replacement only happens on
conv1d, not on conv2d/conv3d;
add more tests.
zhijxu-MS authored Mar 5, 2024
1 parent 0cdf36f commit 2a5c9b8
Showing 2 changed files with 96 additions and 31 deletions.
63 changes: 40 additions & 23 deletions orttraining/orttraining/core/optimizer/conv1d_replacement.cc
@@ -42,30 +42,45 @@
  */
 namespace onnxruntime {
 bool NodeCanBeReplacedByMatmul(const Node& node) {
-  // If node type is Conv, and attr "dilations" is 1, "kernel_shape" is 1, "stride" is 1, group is 1 or 2,
-  // then it can be replaced by MatMul
-  // Kernel_shape is 1 means it is conv1d
+  /*
+    If the node type is Conv and all of the following conditions hold, it can be replaced by MatMul:
+    - no bias input, which means the node has only 2 inputs: input and weight
+    - "dilations" should be [1]
+      (size 1 means conv1d)
+    - "strides" should be [1]
+    - "pads" should be [0,0]
+    - "auto_pad" should be "NOTSET"
+    - "kernel_shape" should be [1]
+  */
   if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11})) {
     return false;
   }
-  const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations");
-  const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape");
-  const auto* stride = graph_utils::GetNodeAttribute(node, "strides");
-  const auto* group = graph_utils::GetNodeAttribute(node, "group");
-  if (dilations == nullptr || kernel_shape == nullptr || stride == nullptr || group == nullptr) {
+
+  // TODO: bias input can also be supported if needed
+  if (node.InputDefs().size() != 2) {
     return false;
   }
-  if ((dilations->ints_size() && dilations->ints(0) != 1) ||
-      (kernel_shape->ints_size() && kernel_shape->ints(0) != 1) ||
-      (stride->ints_size() && stride->ints(0) != 1) ||
-      group->i() >= 3) {
+
+  const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations");
+  const auto* strides = graph_utils::GetNodeAttribute(node, "strides");
+  const auto* pads = graph_utils::GetNodeAttribute(node, "pads");
+  const auto* autopad = graph_utils::GetNodeAttribute(node, "auto_pad");
+  const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape");
+  if (dilations == nullptr || strides == nullptr || pads == nullptr || autopad == nullptr || kernel_shape == nullptr) {
     return false;
   }
 
-  return true;
+  if ((dilations->ints_size() == 1 && dilations->ints(0) == 1) &&
+      (strides->ints_size() == 1 && strides->ints(0) == 1) &&
+      (autopad->s() == "NOTSET") &&
+      (pads->ints_size() == 2 && pads->ints(0) == 0 && pads->ints(1) == 0) &&
+      (kernel_shape->ints_size() == 1 && kernel_shape->ints(0) == 1)) {
+    return true;
+  }
+  return false;
 }
 
-void Conv1dToMatmul(Graph& graph, Node& conv) {
+void Conv1dToMatmul(Graph& graph, Node& conv, const std::string transformer_name) {
   // Shape of conv1d input: [batch_size, in_channels, in_length]
   // Shape of conv1d weight:[output_channels, input_channels/group, kernel_shape], kernel_shape is 1
   // We need to split the input into "group", and squeeze&split the weight, and then do MatMul
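Why the replacement is valid: with kernel_shape = 1, stride 1, and no padding or dilation, each output position of Conv is a weighted sum over the input channels of its own group, i.e. a per-group matrix product W_g [out_channels/group, in_channels/group] × X_g [in_channels/group, in_length]. The standalone sketch below is illustrative only (it is not part of this commit, and all names in it are made up); it checks that equivalence with plain loops:

```cpp
// Illustrative sketch (not part of this commit): verify that a grouped Conv1d
// with kernel_shape=[1], strides=[1], pads=[0,0], dilations=[1] matches the
// per-group MatMul that the transformer emits.
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <vector>

// Plain grouped conv1d: input [B, IC, L], weight [OC, IC/G, K], no pad/dilation.
std::vector<float> Conv1d(const std::vector<float>& x, const std::vector<float>& w,
                          int B, int IC, int L, int OC, int G, int K) {
  const int Lo = L - K + 1, ICg = IC / G, OCg = OC / G;
  std::vector<float> y(static_cast<size_t>(B) * OC * Lo, 0.f);
  for (int b = 0; b < B; ++b)
    for (int oc = 0; oc < OC; ++oc)
      for (int lo = 0; lo < Lo; ++lo)
        for (int k = 0; k < K; ++k)
          for (int c = 0; c < ICg; ++c) {
            const int ic = (oc / OCg) * ICg + c;  // input channel inside this group
            y[(b * OC + oc) * Lo + lo] +=
                w[(oc * ICg + c) * K + k] * x[(b * IC + ic) * L + lo + k];
          }
  return y;
}

int main() {
  const int B = 2, IC = 4, OC = 6, L = 5, G = 2, K = 1;  // K = 1 -> conv1d is a matmul
  const int ICg = IC / G, OCg = OC / G;
  std::vector<float> x(static_cast<size_t>(B) * IC * L), w(static_cast<size_t>(OC) * ICg * K);
  for (auto& v : x) v = static_cast<float>(std::rand()) / RAND_MAX;
  for (auto& v : w) v = static_cast<float>(std::rand()) / RAND_MAX;

  const std::vector<float> y = Conv1d(x, w, B, IC, L, OC, G, K);

  // MatMul view: for each group g, W_g [OC/G, IC/G] times X_g [IC/G, L],
  // which is exactly what Split -> Squeeze -> MatMul -> Concat reassembles.
  for (int b = 0; b < B; ++b)
    for (int oc = 0; oc < OC; ++oc)
      for (int l = 0; l < L; ++l) {
        const int g = oc / OCg;
        float m = 0.f;
        for (int c = 0; c < ICg; ++c)
          m += w[oc * ICg + c] * x[(b * IC + g * ICg + c) * L + l];
        assert(std::fabs(m - y[(b * OC + oc) * L + l]) < 1e-5f);
      }
  return 0;
}
```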
@@ -83,7 +98,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) {
     conv1d_input_splitted_outputs.push_back(&graph.GetOrCreateNodeArg(
         graph.GenerateNodeArgName("input_split_output"), nullptr));
   }
-  auto& input_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, {conv1d_input},
+  auto& input_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description, {conv1d_input},
                                     {conv1d_input_splitted_outputs});
   input_split.SetExecutionProviderType(execution_provider_type);
   input_split.AddAttribute("axis", int64_t(1));
@@ -93,23 +108,25 @@ void Conv1dToMatmul(Graph& graph, Node& conv) {
   }
   // 2. Squeeze conv weight
   auto conv1d_weight = conv.MutableInputDefs()[1];
-  // auto con1d_bias = xx;
   auto weight_squeeze_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("weight_squeeze_output"), nullptr);
-  auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName("WeightSqueeze"), "Squeeze",
+  auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName(transformer_name + "WeightSqueeze"), "Squeeze",
                                        node_description, {conv1d_weight}, {weight_squeeze_output});
+  int64_t weight_squeeze_axis = 2;
   if (onnx_opset_version > 12) {
     // After onnx version 12, squeeze node has axes as input instead of attribute
     ONNX_NAMESPACE::TensorProto initializer_proto;
-    initializer_proto.set_name(graph.GenerateNodeName("ConstAsInitializer"));
+    initializer_proto.set_name(graph.GenerateNodeName(transformer_name + "ConstAsInitializer"));
     initializer_proto.add_dims(static_cast<int64_t>(1));
     initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-    InlinedVector<int64_t> initializer_proto_value{2};
+    InlinedVector<int64_t> initializer_proto_value{weight_squeeze_axis};
     initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t));
     auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto);
     // Squeeze node doesn't have opschema here, so we need to set input args count manually
     weight_squeeze.MutableInputArgsCount().resize(2);
     graph_utils::AddNodeInput(weight_squeeze, 1, axes_input);
   } else {
-    weight_squeeze.AddAttribute("axes", std::vector<int64_t>{2});
+    weight_squeeze.AddAttribute("axes", std::vector<int64_t>{weight_squeeze_axis});
   }
   weight_squeeze.SetExecutionProviderType(execution_provider_type);
   // 3. Split conv weight
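The opset split handled above (Squeeze takes `axes` as an attribute up to opset 12 and as an int64 tensor input from opset 13 on) could be isolated into a small helper. A hypothetical helper (illustration only, mirroring the exact calls in the hunk above; `MakeInt64Initializer` is not part of the commit):

```cpp
// Hypothetical helper (illustration only): build a 1-D int64 TensorProto such
// as the "axes" input required by Squeeze in opset 13+. raw_data holds the
// byte image of the int64 values, as in the transformer code above.
ONNX_NAMESPACE::TensorProto MakeInt64Initializer(const std::string& name,
                                                 const InlinedVector<int64_t>& values) {
  ONNX_NAMESPACE::TensorProto t;
  t.set_name(name);
  t.add_dims(static_cast<int64_t>(values.size()));
  t.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
  t.set_raw_data(values.data(), values.size() * sizeof(int64_t));
  return t;
}
```

With such a helper, the opset > 12 branch reduces to registering `MakeInt64Initializer(..., {weight_squeeze_axis})` as the Squeeze node's second input, while the opset <= 12 branch keeps the `axes` attribute.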
@@ -118,7 +135,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) {
     conv1d_weight_splitted_outputs.push_back(&graph.GetOrCreateNodeArg(
         graph.GenerateNodeArgName("weight_split_output"), nullptr));
   }
-  auto& weight_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description,
+  auto& weight_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description,
                                      {weight_squeeze_output}, {conv1d_weight_splitted_outputs});
   weight_split.AddAttribute("axis", int64_t(0));
   weight_split.SetExecutionProviderType(execution_provider_type);
@@ -130,13 +147,13 @@ void Conv1dToMatmul(Graph& graph, Node& conv) {
   for (int i = 0; i < group_num; i++) {
     auto matmul_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("matmul_output"), nullptr);
     matmul_outputs.push_back(matmul_output);
-    auto& matmul = graph.AddNode(graph.GenerateNodeName("Matmul"), "MatMul", node_description,
+    auto& matmul = graph.AddNode(graph.GenerateNodeName(transformer_name + "Matmul"), "MatMul", node_description,
                                  {conv1d_weight_splitted_outputs[i], conv1d_input_splitted_outputs[i]},
                                  {matmul_output});
     matmul.SetExecutionProviderType(execution_provider_type);
   }
   // 5. Concat matmul outputs
-  auto& concat_node = graph.AddNode(graph.GenerateNodeName("Concat"), "Concat", node_description,
+  auto& concat_node = graph.AddNode(graph.GenerateNodeName(transformer_name + "Concat"), "Concat", node_description,
                                     matmul_outputs, {});
   concat_node.SetExecutionProviderType(execution_provider_type);
   concat_node.AddAttribute("axis", int64_t(1));
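A worked shape example for the steps above, using the shapes from the function's header comment: input [8, 16, 128], weight [64, 8, 1], group = 2. Split(input, axis=1) gives two [8, 8, 128]; Squeeze(weight, axes=[2]) gives [64, 8]; Split(weight, axis=0) gives two [32, 8]; each MatMul([32, 8], [8, 8, 128]) broadcasts the 2-D weight over the batch dimension and yields [8, 32, 128]; Concat(axis=1) restores [8, 64, 128], exactly the Conv output shape [batch_size, output_channels, in_length].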
@@ -155,7 +172,7 @@ Status Conv1dReplacement::ApplyImpl(Graph& graph, bool& modified, int graph_level
     ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
     if (NodeCanBeReplacedByMatmul(node)) {
       LOGS(logger, VERBOSE) << "lora conv1d replacement, node name: " + node.Name();
-      Conv1dToMatmul(graph, node);
+      Conv1dToMatmul(graph, node, Name());
       modified = true;
     }
   }
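The new argument threads the transformer's Name() into every GenerateNodeName call above, so the Split, Squeeze, MatMul, and Concat nodes created by this pass carry a prefix identifying their origin; this looks intended to make optimized graphs easier to attribute and debug when several transformers emit nodes of the same op type.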
64 changes: 56 additions & 8 deletions orttraining/orttraining/test/optimizer/graph_transform_test.cc
@@ -1200,15 +1200,15 @@ TEST_P(QDQFusionTestsParameterized, CheckModelComposition) {
   ASSERT_EQ(op_to_count_post_fusion["com.microsoft.FakeQuant"], 1);
 }
 
-TEST_F(GraphTransformationTests, Conv1dReplacement) {
+TEST_F(GraphTransformationTests, Conv1dReplacement_TakeEffect) {
   auto pre_graph_checker = [&](Graph& graph) {
     auto op_count_map = CountOpsInGraph(graph);
     TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1);
     return Status::OK();
   };
 
   for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) {
-    for (auto group : {1, 2}) {
+    for (auto group : {1, 2, 4}) {
       auto build_test_case = [&](ModelTestBuilder& builder) {
         auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128);
         auto out_channel = 64;
@@ -1222,6 +1222,8 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) {
         conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{1});
         conv_node.AddAttribute("strides", std::vector<int64_t>{1});
         conv_node.AddAttribute("group", static_cast<int64_t>(group));
+        conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0});
+        conv_node.AddAttribute("auto_pad", "NOTSET");
       };
 
       auto post_graph_checker = [&](Graph& graph) {
@@ -1243,37 +1245,81 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) {
   }
 }
 
-TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) {
+// node has bias input so conv not replaced
+TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect1) {
   auto pre_graph_checker = [&](Graph& graph) {
     auto op_count_map = CountOpsInGraph(graph);
     TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1);
     return Status::OK();
   };
 
-  // "group" is 3 so conv not replaced
   for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128);
       auto out_channel = 64;
       auto* data_arg = builder.MakeInput<float>({{batch_size, in_channel, in_length}});
 
-      auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel / 3, 1}, {-1.0f, 1.0f});
+      auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel, 1}, {-1.0f, 1.0f});
+      auto* bias_arg = builder.MakeInitializer<float>({out_channel}, {-1.0f, 1.0f});
       auto* conv_output = builder.MakeOutput();
 
+      auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg, bias_arg}, {conv_output});
+      conv_node.AddAttribute("dilations", std::vector<int64_t>{1});
+      conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{1});
+      conv_node.AddAttribute("strides", std::vector<int64_t>{1});
+      conv_node.AddAttribute("group", static_cast<int64_t>(1));
+      conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0});
+      conv_node.AddAttribute("auto_pad", "NOTSET");
+    };
+
+    std::unique_ptr<GraphTransformer> transformer = std::make_unique<Conv1dReplacement>();
+    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer),
+                                          TransformerLevel::Level1, 1,
+                                          pre_graph_checker, pre_graph_checker));
+  }
+}
+
+// "auto_pad" is not NOTSET so conv not replaced
+TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect2) {
+  auto pre_graph_checker = [&](Graph& graph) {
+    auto op_count_map = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1);
+    return Status::OK();
+  };
+
+  for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) {
+    auto build_test_case = [&](ModelTestBuilder& builder) {
+      auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128);
+      auto out_channel = 64;
+      auto* data_arg = builder.MakeInput<float>({{batch_size, in_channel, in_length}});
+
+      auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel, 1}, {-1.0f, 1.0f});
+      auto* conv_output = builder.MakeOutput();
+
       auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output});
       conv_node.AddAttribute("dilations", std::vector<int64_t>{1});
       conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{1});
       conv_node.AddAttribute("strides", std::vector<int64_t>{1});
-      conv_node.AddAttribute("group", static_cast<int64_t>(3));
+      conv_node.AddAttribute("group", static_cast<int64_t>(1));
+      conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0});
+      conv_node.AddAttribute("auto_pad", "VALID");
     };
 
     std::unique_ptr<GraphTransformer> transformer = std::make_unique<Conv1dReplacement>();
    ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer),
                                          TransformerLevel::Level1, 1,
                                          pre_graph_checker, pre_graph_checker));
   }
+}
 
-  // "kernel_shape" is not 1 so conv not replaced
+// pads is not all zero, so conv not replaced
+TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect3) {
+  auto pre_graph_checker = [&](Graph& graph) {
+    auto op_count_map = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1);
+    return Status::OK();
+  };
+
   for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) {
     auto build_test_case = [&](ModelTestBuilder& builder) {
       auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128);
@@ -1285,9 +1331,11 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) {
 
       auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output});
       conv_node.AddAttribute("dilations", std::vector<int64_t>{1});
-      conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{2});
+      conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{1});
       conv_node.AddAttribute("strides", std::vector<int64_t>{1});
       conv_node.AddAttribute("group", static_cast<int64_t>(1));
+      conv_node.AddAttribute("pads", std::vector<int64_t>{1, 0});
+      conv_node.AddAttribute("auto_pad", "NOTSET");
     };
 
     std::unique_ptr<GraphTransformer> transformer = std::make_unique<Conv1dReplacement>();
