Skip to content

Commit

Permalink
[XPU] support token-slice for encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
TR666 committed Mar 15, 2024
1 parent 3c61295 commit 8c22d8f
Show file tree
Hide file tree
Showing 7 changed files with 371 additions and 28 deletions.
1 change: 1 addition & 0 deletions lite/api/paddle_use_passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ USE_MIR_PASS(__xpu__spatial_transformer_resblock_fuse_pass);
USE_MIR_PASS(__xpu__matmul_scale_softmax_v1_fuse_pass);
USE_MIR_PASS(__xpu__up_decoder_fuse_pass);
USE_MIR_PASS(__xpu__multi_up_decoder_fuse_pass);
USE_MIR_PASS(__xpu__remove_mask_slice_pass);
USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_v2_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_v3_fuse_pass);
Expand Down
196 changes: 169 additions & 27 deletions lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ class XPUSingleEncoderFuser : public FuseBase {
bool norm_before = false,
const std::string& relative_type = "",
bool with_mask = true,
bool smooth_quant = false)
bool smooth_quant = false,
bool with_token_slice = false)
: act_type_(act_type),
input_pos_(input_pos),
qkv_ln_2_out_pos_(qkv_ln_2_out_pos),
Expand All @@ -66,7 +67,8 @@ class XPUSingleEncoderFuser : public FuseBase {
norm_before_(norm_before),
relative_emb_type_(relative_type),
with_mask_(with_mask),
smooth_quant_(smooth_quant) {}
smooth_quant_(smooth_quant),
with_token_slice_(with_token_slice) {}

void BuildPattern() override {
PMNode* input = nullptr;
Expand All @@ -79,11 +81,39 @@ class XPUSingleEncoderFuser : public FuseBase {
PMNode* ln_before_out = nullptr;
PMNode* ln_before_mean = nullptr;
PMNode* ln_before_var = nullptr;
PMNode* input_slice = nullptr;
PMNode* input_slice_out = nullptr;
if (smooth_quant_ && !norm_before_) {
VLOG(3) << "build first smooth_quant_scale";
input = VarNode("input")
->assert_is_op_input("elementwise_mul", "X")
->AsInput();
if (with_token_slice_) {
VLOG(3) << "build input_slice";
input =
VarNode("input")->assert_is_op_input("slice", "Input")->AsInput();
input_slice = OpNode("input_slice", "slice")
->assert_op_attr_satisfied<std::vector<int>>(
"axes",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 1;
})
->assert_op_attr_satisfied<std::vector<int>>(
"starts",
[](const std::vector<int>& attr) {
return attr.size() == 1;
})
->assert_op_attr_satisfied<std::vector<int>>(
"ends",
[](const std::vector<int>& attr) {
return attr.size() == 1;
})
->AsIntermediate();
input_slice_out = VarNode("input_slice_out")
->assert_is_op_output("slice", "Out")
->assert_is_op_input("elementwise_mul", "X");
} else {
input = VarNode("input")
->assert_is_op_input("elementwise_mul", "X")
->AsInput();
}
smooth_scale_1_weight = VarNode("smooth_scale_1_weight")
->assert_is_op_input("elementwise_mul", "Y")
->AsInput();
Expand All @@ -92,6 +122,26 @@ class XPUSingleEncoderFuser : public FuseBase {
smooth_scale_1_out = VarNode("smooth_scale_1_out")
->assert_is_op_output("elementwise_mul", "Out")
->AsIntermediate();
} else if (with_token_slice_) {
VLOG(3) << "build input_slice";
input = VarNode("input")->assert_is_op_input("slice", "Input")->AsInput();
input_slice =
OpNode("input_slice", "slice")
->assert_op_attr_satisfied<std::vector<int>>(
"axes",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 1;
})
->assert_op_attr_satisfied<std::vector<int>>(
"starts",
[](const std::vector<int>& attr) { return attr.size() == 1; })
->assert_op_attr_satisfied<std::vector<int>>(
"ends",
[](const std::vector<int>& attr) { return attr.size() == 1; })
->AsIntermediate();
input_slice_out = VarNode("input_slice_out")
->assert_is_op_output("slice", "Out")
->assert_is_op_input("elementwise_add", input_pos_);
} else {
input = VarNode("input")
->assert_is_op_input("elementwise_add", input_pos_)
Expand Down Expand Up @@ -311,11 +361,40 @@ class XPUSingleEncoderFuser : public FuseBase {
VarNode("qkv_transpose2_xshape")
->assert_is_op_output("transpose2", "XShape")
->AsIntermediate();
PMNode* qkv_slice = nullptr;
PMNode* qkv_slice_out = nullptr;
PMNode* qkv_reshape2_out = nullptr;
auto* qkv_reshape2 = OpNode("qkv_reshape2", "reshape2")->AsIntermediate();
auto* qkv_reshape2_out = VarNode("qkv_reshape2_out")
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input(mul_type_, "X")
->AsIntermediate();
if (with_token_slice_) {
qkv_reshape2_out = VarNode("qkv_reshape2_out")
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("slice", "Input")
->AsIntermediate();
VLOG(3) << "build qkv_slice";
qkv_slice =
OpNode("qkv_slice", "slice")
->assert_op_attr_satisfied<std::vector<int>>(
"axes",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 1;
})
->assert_op_attr_satisfied<std::vector<int>>(
"starts",
[](const std::vector<int>& attr) { return attr.size() == 1; })
->assert_op_attr_satisfied<std::vector<int>>(
"ends",
[](const std::vector<int>& attr) { return attr.size() == 1; })
->AsIntermediate();
qkv_slice_out = VarNode("qkv_slice_out")
->assert_is_op_output("slice", "Out")
->assert_is_op_input(mul_type_, "X")
->AsIntermediate();
} else {
qkv_reshape2_out = VarNode("qkv_reshape2_out")
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input(mul_type_, "X")
->AsIntermediate();
}
auto* qkv_reshape2_xshape = VarNode("qkv_reshape2_xshape")
->assert_is_op_output("reshape2", "XShape")
->AsIntermediate();
Expand Down Expand Up @@ -531,18 +610,33 @@ class XPUSingleEncoderFuser : public FuseBase {
*v_transpose2 >> *v_transpose2_xshape;

*qkv_matmul >> *qkv_matmul_out >> *qkv_transpose2 >> *qkv_transpose2_out >>
*qkv_reshape2 >> *qkv_reshape2_out >> *qkv_mul >> *qkv_mul_out >>
*qkv_add >> *qkv_add_out >> *qkv_add_2;
*qkv_reshape2 >> *qkv_reshape2_out;
if (with_token_slice_) {
*qkv_reshape2_out >> *qkv_slice >> *qkv_slice_out >> *qkv_mul;
} else {
*qkv_reshape2_out >> *qkv_mul;
}
*qkv_mul >> *qkv_mul_out >> *qkv_add >> *qkv_add_out >> *qkv_add_2;
*qkv_transpose2 >> *qkv_transpose2_xshape;
*qkv_reshape2 >> *qkv_reshape2_xshape;
*qkv_mul_y >> *qkv_mul;
*qkv_add_y >> *qkv_add;
if (smooth_quant_ && !norm_before_) {
*smooth_scale_1_weight >> *smooth_scale_1;
*input >> *smooth_scale_1 >> *smooth_scale_1_out >> *qkv_add_2 >>
*qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out;
if (with_token_slice_) {
*input >> *input_slice >> *input_slice_out >> *smooth_scale_1;
} else {
*input >> *smooth_scale_1;
}
*smooth_scale_1 >> *smooth_scale_1_out >> *qkv_add_2 >> *qkv_add_2_out >>
*qkv_ln_2 >> *qkv_ln_2_out;
} else {
*input >> *qkv_add_2 >> *qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out;
if (with_token_slice_) {
*input >> *input_slice >> *input_slice_out >> *qkv_add_2;
} else {
*input >> *qkv_add_2;
}
*qkv_add_2 >> *qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out;
}
*qkv_ln_2_scale >> *qkv_ln_2;
*qkv_ln_2_bias >> *qkv_ln_2;
Expand Down Expand Up @@ -619,6 +713,8 @@ class XPUSingleEncoderFuser : public FuseBase {
// the model is smooth_quant or not, we don't need to do anything
// so, set is_smooth_quant as false.
op_desc.SetAttr<bool>("is_smooth_quant", false);
// Token slice is not yet supported in the pre-layernorm case.
op_desc.SetAttr<bool>("with_token_slice", false);
} else {
op_desc.SetInput("LNScale",
{
Expand All @@ -643,6 +739,33 @@ class XPUSingleEncoderFuser : public FuseBase {
} else {
op_desc.SetAttr<bool>("is_smooth_quant", false);
}
if (with_token_slice_) {
op_desc.SetAttr<bool>("with_token_slice", true);
int token_sliced_length = -1;
auto* qkv_slice_op_info = matched.at("qkv_slice")->stmt()->op_info();
auto* input_slice_op_info =
matched.at("input_slice")->stmt()->op_info();
if (qkv_slice_op_info->HasAttr("starts") &&
qkv_slice_op_info->HasAttr("ends") &&
input_slice_op_info->HasAttr("starts") &&
input_slice_op_info->HasAttr("ends")) {
auto qkv_slice_starts =
qkv_slice_op_info->GetAttr<std::vector<int>>("starts");
auto qkv_slice_ends =
qkv_slice_op_info->GetAttr<std::vector<int>>("ends");
auto input_slice_starts =
input_slice_op_info->GetAttr<std::vector<int>>("starts");
auto input_slice_ends =
input_slice_op_info->GetAttr<std::vector<int>>("ends");
CHECK_EQ(qkv_slice_starts.size(), input_slice_starts.size());
CHECK_EQ(qkv_slice_ends.size(), input_slice_ends.size());
CHECK_EQ(qkv_slice_starts[0], input_slice_starts[0]);
CHECK_EQ(qkv_slice_ends[0], input_slice_ends[0]);
token_sliced_length = qkv_slice_ends[0] - qkv_slice_starts[0];
CHECK_GT(token_sliced_length, 0);
op_desc.SetAttr<int>("token_sliced_length", token_sliced_length);
}
}
}
// XXX: keep these to fool SubgraphOp::AttachImpl()
op_desc.SetAttr<int>("sub_block", 0);
Expand Down Expand Up @@ -792,6 +915,7 @@ class XPUSingleEncoderFuser : public FuseBase {
const std::string relative_emb_type_;
bool with_mask_;
bool smooth_quant_;
bool with_token_slice_;
// quant_info: mul input_max, output_max * 6 + matmul x_max:y_max,
// output_max
// * 2
Expand Down Expand Up @@ -1845,6 +1969,8 @@ class XPUMultiEncoderFuser {
std::vector<float> fc_input_max;
std::vector<float> softmax_max;
std::vector<std::string> quant_types;
bool has_token_sliced_layer = false;
std::vector<int> token_sliced_length(all_encoders.size(), -1);

for (size_t i = 0; i < all_encoders.size(); ++i) {
Node* cur_encoder = all_encoders[i];
Expand Down Expand Up @@ -1889,6 +2015,12 @@ class XPUMultiEncoderFuser {
}
}

if (op_info->HasAttr("with_token_slice") &&
op_info->HasAttr("token_sliced_length")) {
has_token_sliced_layer = true;
token_sliced_length[i] = op_info->GetAttr<int>("token_sliced_length");
}

auto* cur_out =
graph->RetrieveArgument(op_info->Output("Outputs").front());
if (all_encoders.size() == 1) {
Expand Down Expand Up @@ -1930,6 +2062,9 @@ class XPUMultiEncoderFuser {
}
op_desc.SetOutput("Output", {out_name});
op_desc.SetAttr<int>("xpu", 1);
op_desc.SetAttr<std::vector<int>>("token_sliced_length",
token_sliced_length);
op_desc.SetAttr<bool>("has_token_sliced_layer", has_token_sliced_layer);
op_desc.SetAttr<bool>(
"is_smooth_quant",
first_encoder_op_info->GetAttr<bool>("is_smooth_quant"));
Expand Down Expand Up @@ -2325,6 +2460,7 @@ class XPUMultiEncoderFusePass : public ProgramPass {
std::vector<std::string> relative_embedding_type{
"", "__xpu__roformer_relative_embedding"};
std::vector<bool> with_smooth_quant{true, false};
std::vector<bool> with_token_slice{true, false};

std::string fc_precision;
bool adaptive_seqlen = false;
Expand Down Expand Up @@ -2384,19 +2520,25 @@ class XPUMultiEncoderFusePass : public ProgramPass {
// so remove one
continue;
}
fusion::XPUSingleEncoderFuser single_encoder_fuser(
act_type,
input_pos,
qkv_ln_2_out_pos,
matmul_type,
matmul2_type,
mul_type,
with_q_scale,
norm_before,
relative_type,
mask,
smooth_quant);
single_encoder_fuser(graph.get());
for (auto token_slice : with_token_slice) {
fusion::XPUSingleEncoderFuser single_encoder_fuser(
act_type,
input_pos,
qkv_ln_2_out_pos,
matmul_type,
matmul2_type,
mul_type,
with_q_scale,
norm_before,
relative_type,
mask,
smooth_quant,
token_slice);
single_encoder_fuser(graph.get());
}
// The single-encoder fuser must have run for both cases (with and
// without token_slice) before the multi-encoder fuser is applied.
fusion::XPUMultiEncoderFuser multi_encoder_fuser(
fc_precision, adaptive_seqlen);
multi_encoder_fuser(graph.get());
Expand Down
Loading

0 comments on commit 8c22d8f

Please sign in to comment.