[XPU] support token-slice for encoder #10470

Open · wants to merge 1 commit into base: develop
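The change teaches the XPU encoder fuse pass to recognize encoder blocks whose input (and whose attention output after reshape) passes through a slice op on the token axis (axes = [1]) with a single start/end pair, to absorb those slice ops into the fused encoder, and to record the surviving token count via a with_token_slice / token_sliced_length attribute pair. A minimal sketch of that bookkeeping is below; TokenSlicedLength is a hypothetical helper written only for illustration, mirroring the attribute checks in __xpu__multi_encoder_fuse_pass.cc rather than any real Lite API.

#include <cassert>
#include <vector>

// Hypothetical helper mirroring the fuser's attribute checks: the pattern
// only matches slice ops with axes = [1] (the token axis) and exactly one
// start/end pair, so a single integer describes the slice.
int TokenSlicedLength(const std::vector<int>& axes,
                      const std::vector<int>& starts,
                      const std::vector<int>& ends) {
  assert(axes.size() == 1 && axes[0] == 1);
  assert(starts.size() == 1 && ends.size() == 1);
  const int length = ends[0] - starts[0];
  assert(length > 0);  // mirrors CHECK_GT(token_sliced_length, 0)
  return length;
}

int main() {
  // e.g. keep tokens [0, 16): the fused op records token_sliced_length = 16
  return TokenSlicedLength({1}, {0}, {16}) == 16 ? 0 : 1;
}

Because the pattern asserts single-element starts/ends on axis 1, one integer per encoder layer fully describes the slice.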
1 change: 1 addition & 0 deletions lite/api/paddle_use_passes.h
@@ -99,6 +99,7 @@ USE_MIR_PASS(__xpu__spatial_transformer_resblock_fuse_pass);
USE_MIR_PASS(__xpu__matmul_scale_softmax_v1_fuse_pass);
USE_MIR_PASS(__xpu__up_decoder_fuse_pass);
USE_MIR_PASS(__xpu__multi_up_decoder_fuse_pass);
USE_MIR_PASS(__xpu__remove_mask_slice_pass);
USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_v2_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_v3_fuse_pass);
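This one-line change registers the new __xpu__remove_mask_slice_pass alongside the other XPU passes so its registration is pulled into the final binary. As a rough illustration of the idiom, assuming simplified macro bodies (Lite's real REGISTER/USE definitions live in its pass-registry headers), USE_MIR_PASS-style macros typically reference a per-pass touch symbol so the linker cannot strip the pass's registrar:

// Sketch of the touch-symbol idiom (an assumption for illustration only).
#define SKETCH_REGISTER_MIR_PASS(name) \
  int mir_pass_touch_##name() { return 0; }
#define SKETCH_USE_MIR_PASS(name)     \
  extern int mir_pass_touch_##name(); \
  [[maybe_unused]] static int use_mir_pass_##name = mir_pass_touch_##name();

// In the pass's own .cc file (name stands in for the real pass name):
SKETCH_REGISTER_MIR_PASS(remove_mask_slice_pass)
// In a header like paddle_use_passes.h, forcing the registration to link:
SKETCH_USE_MIR_PASS(remove_mask_slice_pass)

int main() { return 0; }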
196 changes: 169 additions & 27 deletions lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc
@@ -55,7 +55,8 @@ class XPUSingleEncoderFuser : public FuseBase {
bool norm_before = false,
const std::string& relative_type = "",
bool with_mask = true,
bool smooth_quant = false,
bool with_token_slice = false)
: act_type_(act_type),
input_pos_(input_pos),
qkv_ln_2_out_pos_(qkv_ln_2_out_pos),
@@ -66,7 +67,8 @@
norm_before_(norm_before),
relative_emb_type_(relative_type),
with_mask_(with_mask),
smooth_quant_(smooth_quant),
with_token_slice_(with_token_slice) {}

void BuildPattern() override {
PMNode* input = nullptr;
@@ -79,11 +81,39 @@
PMNode* ln_before_out = nullptr;
PMNode* ln_before_mean = nullptr;
PMNode* ln_before_var = nullptr;
PMNode* input_slice = nullptr;
PMNode* input_slice_out = nullptr;
if (smooth_quant_ && !norm_before_) {
VLOG(3) << "build first smooth_quant_scale";
input = VarNode("input")
->assert_is_op_input("elementwise_mul", "X")
->AsInput();
if (with_token_slice_) {
VLOG(3) << "build input_slice";
input =
VarNode("input")->assert_is_op_input("slice", "Input")->AsInput();
input_slice = OpNode("input_slice", "slice")
->assert_op_attr_satisfied<std::vector<int>>(
"axes",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 1;
})
->assert_op_attr_satisfied<std::vector<int>>(
"starts",
[](const std::vector<int>& attr) {
return attr.size() == 1;
})
->assert_op_attr_satisfied<std::vector<int>>(
"ends",
[](const std::vector<int>& attr) {
return attr.size() == 1;
})
->AsIntermediate();
input_slice_out = VarNode("input_slice_out")
->assert_is_op_output("slice", "Out")
->assert_is_op_input("elementwise_mul", "X");
} else {
input = VarNode("input")
->assert_is_op_input("elementwise_mul", "X")
->AsInput();
}
smooth_scale_1_weight = VarNode("smooth_scale_1_weight")
->assert_is_op_input("elementwise_mul", "Y")
->AsInput();
@@ -92,6 +122,26 @@
smooth_scale_1_out = VarNode("smooth_scale_1_out")
->assert_is_op_output("elementwise_mul", "Out")
->AsIntermediate();
} else if (with_token_slice_) {
VLOG(3) << "build input_slice";
input = VarNode("input")->assert_is_op_input("slice", "Input")->AsInput();
input_slice =
OpNode("input_slice", "slice")
->assert_op_attr_satisfied<std::vector<int>>(
"axes",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 1;
})
->assert_op_attr_satisfied<std::vector<int>>(
"starts",
[](const std::vector<int>& attr) { return attr.size() == 1; })
->assert_op_attr_satisfied<std::vector<int>>(
"ends",
[](const std::vector<int>& attr) { return attr.size() == 1; })
->AsIntermediate();
input_slice_out = VarNode("input_slice_out")
->assert_is_op_output("slice", "Out")
->assert_is_op_input("elementwise_add", input_pos_);
} else {
input = VarNode("input")
->assert_is_op_input("elementwise_add", input_pos_)
@@ -311,11 +361,40 @@
VarNode("qkv_transpose2_xshape")
->assert_is_op_output("transpose2", "XShape")
->AsIntermediate();
PMNode* qkv_slice = nullptr;
PMNode* qkv_slice_out = nullptr;
PMNode* qkv_reshape2_out = nullptr;
auto* qkv_reshape2 = OpNode("qkv_reshape2", "reshape2")->AsIntermediate();
if (with_token_slice_) {
qkv_reshape2_out = VarNode("qkv_reshape2_out")
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("slice", "Input")
->AsIntermediate();
VLOG(3) << "build qkv_slice";
qkv_slice =
OpNode("qkv_slice", "slice")
->assert_op_attr_satisfied<std::vector<int>>(
"axes",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 1;
})
->assert_op_attr_satisfied<std::vector<int>>(
"starts",
[](const std::vector<int>& attr) { return attr.size() == 1; })
->assert_op_attr_satisfied<std::vector<int>>(
"ends",
[](const std::vector<int>& attr) { return attr.size() == 1; })
->AsIntermediate();
qkv_slice_out = VarNode("qkv_slice_out")
->assert_is_op_output("slice", "Out")
->assert_is_op_input(mul_type_, "X")
->AsIntermediate();
} else {
qkv_reshape2_out = VarNode("qkv_reshape2_out")
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input(mul_type_, "X")
->AsIntermediate();
}
auto* qkv_reshape2_xshape = VarNode("qkv_reshape2_xshape")
->assert_is_op_output("reshape2", "XShape")
->AsIntermediate();
@@ -531,18 +610,33 @@
*v_transpose2 >> *v_transpose2_xshape;

*qkv_matmul >> *qkv_matmul_out >> *qkv_transpose2 >> *qkv_transpose2_out >>
*qkv_reshape2 >> *qkv_reshape2_out;
if (with_token_slice_) {
*qkv_reshape2_out >> *qkv_slice >> *qkv_slice_out >> *qkv_mul;
} else {
*qkv_reshape2_out >> *qkv_mul;
}
*qkv_mul >> *qkv_mul_out >> *qkv_add >> *qkv_add_out >> *qkv_add_2;
*qkv_transpose2 >> *qkv_transpose2_xshape;
*qkv_reshape2 >> *qkv_reshape2_xshape;
*qkv_mul_y >> *qkv_mul;
*qkv_add_y >> *qkv_add;
if (smooth_quant_ && !norm_before_) {
*smooth_scale_1_weight >> *smooth_scale_1;
if (with_token_slice_) {
*input >> *input_slice >> *input_slice_out >> *smooth_scale_1;
} else {
*input >> *smooth_scale_1;
}
*smooth_scale_1 >> *smooth_scale_1_out >> *qkv_add_2 >> *qkv_add_2_out >>
*qkv_ln_2 >> *qkv_ln_2_out;
} else {
if (with_token_slice_) {
*input >> *input_slice >> *input_slice_out >> *qkv_add_2;
} else {
*input >> *qkv_add_2;
}
*qkv_add_2 >> *qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out;
}
*qkv_ln_2_scale >> *qkv_ln_2;
*qkv_ln_2_bias >> *qkv_ln_2;
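A note on the pattern DSL used throughout BuildPattern above: VarNode and OpNode declare pattern nodes, and the overloaded >> operator chains directed edges between them, so *input >> *input_slice >> *input_slice_out describes a path the matcher must find in the graph. A toy sketch of that idiom, with a hypothetical PatternNode standing in for Lite's PMNode, is:

#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for Lite's PMNode (illustrative assumption, not the real class).
struct PatternNode {
  std::string name;
  std::vector<PatternNode*> outs;
};

// operator>> records a directed edge and returns the right-hand node,
// so chains such as a >> b >> c declare a path through the subgraph.
PatternNode& operator>>(PatternNode& from, PatternNode& to) {
  from.outs.push_back(&to);
  return to;
}

int main() {
  PatternNode input{"input", {}};
  PatternNode slice{"input_slice", {}};
  PatternNode out{"input_slice_out", {}};
  input >> slice >> out;  // cf. *input >> *input_slice >> *input_slice_out
  std::cout << input.outs[0]->name << std::endl;  // prints "input_slice"
  return 0;
}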
@@ -619,6 +713,8 @@
// the model is smooth_quant or not, we don't need to do anything
// so, set is_smooth_quant as false.
op_desc.SetAttr<bool>("is_smooth_quant", false);
// token slice is temporarily unsupported in the pre-layernorm case
op_desc.SetAttr<bool>("with_token_slice", false);
} else {
op_desc.SetInput("LNScale",
{
@@ -643,6 +739,33 @@
} else {
op_desc.SetAttr<bool>("is_smooth_quant", false);
}
if (with_token_slice_) {
op_desc.SetAttr<bool>("with_token_slice", true);
int token_sliced_length = -1;
auto* qkv_slice_op_info = matched.at("qkv_slice")->stmt()->op_info();
auto* input_slice_op_info =
matched.at("input_slice")->stmt()->op_info();
if (qkv_slice_op_info->HasAttr("starts") &&
qkv_slice_op_info->HasAttr("ends") &&
input_slice_op_info->HasAttr("starts") &&
input_slice_op_info->HasAttr("ends")) {
auto qkv_slice_starts =
qkv_slice_op_info->GetAttr<std::vector<int>>("starts");
auto qkv_slice_ends =
qkv_slice_op_info->GetAttr<std::vector<int>>("ends");
auto input_slice_starts =
input_slice_op_info->GetAttr<std::vector<int>>("starts");
auto input_slice_ends =
input_slice_op_info->GetAttr<std::vector<int>>("ends");
CHECK_EQ(qkv_slice_starts.size(), input_slice_starts.size());
CHECK_EQ(qkv_slice_ends.size(), input_slice_ends.size());
CHECK_EQ(qkv_slice_starts[0], input_slice_starts[0]);
CHECK_EQ(qkv_slice_ends[0], input_slice_ends[0]);
token_sliced_length = qkv_slice_ends[0] - qkv_slice_starts[0];
CHECK_GT(token_sliced_length, 0);
op_desc.SetAttr<int>("token_sliced_length", token_sliced_length);
}
}
}
// XXX: keep these to fool SubgraphOp::AttachImpl()
op_desc.SetAttr<int>("sub_block", 0);
@@ -792,6 +915,7 @@
const std::string relative_emb_type_;
bool with_mask_;
bool smooth_quant_;
bool with_token_slice_;
// quant_info: mul input_max, output_max * 6 + matmul x_max:y_max,
// output_max
// * 2
@@ -1845,6 +1969,8 @@ class XPUMultiEncoderFuser {
std::vector<float> fc_input_max;
std::vector<float> softmax_max;
std::vector<std::string> quant_types;
bool has_token_sliced_layer = false;
std::vector<int> token_sliced_length(all_encoders.size(), -1);

for (size_t i = 0; i < all_encoders.size(); ++i) {
Node* cur_encoder = all_encoders[i];
@@ -1889,6 +2015,12 @@
}
}

if (op_info->HasAttr("with_token_slice") &&
op_info->HasAttr("token_sliced_length")) {
has_token_sliced_layer = true;
token_sliced_length[i] = op_info->GetAttr<int>("token_sliced_length");
}

auto* cur_out =
graph->RetrieveArgument(op_info->Output("Outputs").front());
if (all_encoders.size() == 1) {
@@ -1930,6 +2062,9 @@
}
op_desc.SetOutput("Output", {out_name});
op_desc.SetAttr<int>("xpu", 1);
op_desc.SetAttr<std::vector<int>>("token_sliced_length",
token_sliced_length);
op_desc.SetAttr<bool>("has_token_sliced_layer", has_token_sliced_layer);
op_desc.SetAttr<bool>(
"is_smooth_quant",
first_encoder_op_info->GetAttr<bool>("is_smooth_quant"));
@@ -2325,6 +2460,7 @@ class XPUMultiEncoderFusePass : public ProgramPass {
std::vector<std::string> relative_embedding_type{
"", "__xpu__roformer_relative_embedding"};
std::vector<bool> with_smooth_quant{true, false};
std::vector<bool> with_token_slice{true, false};

std::string fc_precision;
bool adaptive_seqlen = false;
@@ -2384,19 +2520,25 @@
// so remove one
continue;
}
for (auto token_slice : with_token_slice) {
fusion::XPUSingleEncoderFuser single_encoder_fuser(
act_type,
input_pos,
qkv_ln_2_out_pos,
matmul_type,
matmul2_type,
mul_type,
with_q_scale,
norm_before,
relative_type,
mask,
smooth_quant,
token_slice);
single_encoder_fuser(graph.get());
}
// both token_slice variants of the single-encoder pattern must be
// matched before the multi-encoder fuser runs.
fusion::XPUMultiEncoderFuser multi_encoder_fuser(
fc_precision, adaptive_seqlen);
multi_encoder_fuser(graph.get());
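When XPUMultiEncoderFuser merges the matched single encoders, it aggregates one token_sliced_length per layer (initialized to -1 for layers without a slice) and a has_token_sliced_layer flag onto the fused op. A hypothetical consumer-side check, not part of this diff (the kernel-side handling lives elsewhere), might read that vector like this:

#include <vector>

// Hypothetical helper: -1 means "this layer has no token slice"; a positive
// value is the number of tokens kept after the layer's slice.
bool HasTokenSlicedLayer(const std::vector<int>& token_sliced_length) {
  for (int len : token_sliced_length) {
    if (len > 0) return true;
  }
  return false;
}

int main() {
  // e.g. a 3-layer encoder where only the middle layer slices down to 16 tokens
  return HasTokenSlicedLayer({-1, 16, -1}) ? 0 : 1;
}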