Commit 288b94b

MoE structure acceleration (moe结构加速)
黄宇扬 committed May 17, 2024
1 parent 81c3f3e commit 288b94b
Showing 3 changed files with 152 additions and 65 deletions.
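What the commit does, in short: for every MoE expert (and the shared experts), the gate_proj (w1) and up_proj (w3) weights are stacked row-wise into a single gateup_proj weight, so the SwiGLU MLP input projection becomes one GEMM whose output is split and activated in place (LinearEx with ExSwiglu, or Linear followed by Swiglu), instead of two GEMMs plus Silu and an elementwise multiply. A minimal dense-float sketch of that equivalence follows; the function names and plain-float math are illustrative only, not fastllm's API.

#include <cmath>
#include <cstdio>
#include <vector>

static float silu(float x) { return x / (1.0f + std::exp(-x)); }

// Split path: y = silu(x * W1^T) .* (x * W3^T), with W1 = gate_proj and W3 = up_proj,
// both of shape [inter, in] in row-major order.
std::vector<float> SwigluSplit(const std::vector<float> &x,
                               const std::vector<float> &w1,
                               const std::vector<float> &w3,
                               int in, int inter) {
    std::vector<float> y(inter);
    for (int o = 0; o < inter; o++) {
        float g = 0.0f, u = 0.0f;
        for (int k = 0; k < in; k++) {
            g += x[k] * w1[o * in + k];
            u += x[k] * w3[o * in + k];
        }
        y[o] = silu(g) * u;
    }
    return y;
}

// Merged path: gateup = [W1; W3] stacked row-wise ([2*inter, in]); one GEMM, then a SwiGLU
// over the concatenated output. Assumption: the gate half comes first, matching the order
// in which the commit copies w1 bytes before w3 bytes.
std::vector<float> SwigluMerged(const std::vector<float> &x,
                                const std::vector<float> &gateup,
                                int in, int inter) {
    std::vector<float> gu(2 * inter, 0.0f), y(inter);
    for (int o = 0; o < 2 * inter; o++) {
        for (int k = 0; k < in; k++) {
            gu[o] += x[k] * gateup[o * in + k];
        }
    }
    for (int o = 0; o < inter; o++) {
        y[o] = silu(gu[o]) * gu[inter + o];
    }
    return y;
}

int main() {
    const int in = 4, inter = 3;
    std::vector<float> x = {0.5f, -1.0f, 2.0f, 0.25f};
    std::vector<float> w1(inter * in), w3(inter * in);
    for (int i = 0; i < inter * in; i++) { w1[i] = 0.01f * i; w3[i] = 0.03f * (i % 5); }
    std::vector<float> gateup = w1;                       // same row-wise stacking the
    gateup.insert(gateup.end(), w3.begin(), w3.end());    // commit performs with memcpy
    std::vector<float> a = SwigluSplit(x, w1, w3, in, inter);
    std::vector<float> b = SwigluMerged(x, gateup, in, inter);
    for (int o = 0; o < inter; o++) {
        std::printf("%d: split=%f merged=%f\n", o, a[o], b[o]);
    }
    return 0;
}

Both paths produce the same vector; the merged form trades two GEMMs over the same activations for one GEMM with twice the output rows, which is what the fused ExSwiglu kernel (or the Swiglu op in the fallback branch) then consumes.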
2 changes: 2 additions & 0 deletions include/models/deepseekv2.h
@@ -99,6 +99,8 @@ namespace fastllm {
float rope_scaling_mscale_all_dim;
float rope_scaling_original_max_position_embeddings;
std::string rope_scaling_type;

bool mergeSwiglu = false;
};
}

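For context, the new member is a latch consulted by the forward pass (see the .cpp changes below): while it is false, Forward scans the layers and tries to fuse each gate_proj/up_proj pair into a gateup_proj weight, and once a merge succeeds the scan is skipped on later calls. A minimal sketch of that pattern, with hypothetical names:

// Hypothetical names; only the latch-plus-branch pattern mirrors the commit.
struct ModelSketch {
    bool mergeSwiglu = false;               // the member added in this header

    // Stand-in for the per-layer/per-expert weight scan added to Forward().
    bool TryMergeGateUpWeights() { return true; }

    void Forward() {
        if (!mergeSwiglu) {                 // re-scan only while no merge has succeeded
            mergeSwiglu = TryMergeGateUpWeights();
        }
        // ...the MoE MLP then branches on mergeSwiglu: fused gateup path vs. split gate/up path...
    }
};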
89 changes: 77 additions & 12 deletions src/models/deepseekv2.cpp
@@ -207,6 +207,53 @@ namespace fastllm {
const fastllm::Data &positionIds, std::vector<std::pair<Data, Data>> &pastKeyValues,
const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
std::vector <std::vector <float>*> *retLogits) {
if (!mergeSwiglu) {
bool canMerge = true;
for (int i = 0; i < block_cnt; i++) {
for (int j = -1; j < this->num_experts; j++) {
std::string w1WeightName = "model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(j) + ".gate_proj.weight";
std::string w3WeightName = "model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(j) + ".up_proj.weight";
std::string swigluWeightName = "model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(j) + ".gateup_proj.weight";
if (j == -1) {
w1WeightName = "model.layers." + std::to_string(i) + ".mlp.shared_experts.gate_proj.weight";
w3WeightName = "model.layers." + std::to_string(i) + ".mlp.shared_experts.up_proj.weight";
swigluWeightName = "model.layers." + std::to_string(i) + ".mlp.shared_experts.gateup_proj.weight";
}

if (weight.weight.find(w1WeightName) == weight.weight.end()) {
continue;
}

Data &w1 = weight.weight[w1WeightName], &w3 = weight.weight[w3WeightName];
if ((w1.dataType == DataType::INT4_GROUP && w1.dims[1] % w1.groupCnt != 0) ||
(w3.dataType == DataType::INT4_GROUP && w3.dims[1] % w3.groupCnt != 0)) {
canMerge = false;
break;
}

weight.weight[swigluWeightName] = Data(w1.dataType, {w1.dims[0] + w3.dims[0], w1.dims[1]});
Data &swiglu = weight.weight[swigluWeightName];
swiglu.name = swigluWeightName;
swiglu.Allocate();
memcpy(swiglu.cpuData, w1.cpuData, w1.GetBytes());
memcpy(swiglu.cpuData + w1.GetBytes(), w3.cpuData, w3.GetBytes());

swiglu.perChannelAxis = w1.perChannelAxis;
swiglu.group = w1.group;
swiglu.groupCnt = w1.groupCnt;
swiglu.perChannelsConfigs = AppendVector(w1.perChannelsConfigs, w3.perChannelsConfigs);
swiglu.zeros = AppendVector(w1.zeros, w3.zeros);
swiglu.scales = AppendVector(w1.scales, w3.scales);
swiglu.mins = AppendVector(w1.mins, w3.mins);

weight.weight.erase(w1WeightName);
weight.weight.erase(w3WeightName);
}

this->mergeSwiglu = canMerge;
}
}

Data alibiData;
if (this->weight.dicts["use_alibi"] == "1") {
std::vector<float> alibi = GetInterleave(num_attention_heads);
@@ -386,26 +433,44 @@ namespace fastllm {
int idx = (int)(gateData[(b * this->num_experts_per_tok + j) * 2] + 1e-1);
float value = gateData[(b * this->num_experts_per_tok + j) * 2 + 1];
value *= routed_scaling_factor;
if (CanRunLinearEx(LinearExType::ExSilu)) {
LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
if (this->mergeSwiglu) {
if (CanRunLinearEx(LinearExType::ExSwiglu)) {
LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gateup_proj.weight"], Data(), w1, LinearExType::ExSwiglu);
} else {
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gateup_proj.weight"], Data(), w3);
Swiglu(w3, w1);
}
} else {
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1);
Silu(w1, w1);
if (CanRunLinearEx(LinearExType::ExSilu)) {
LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
} else {
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1);
Silu(w1, w1);
}
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".up_proj.weight"], Data(), w3);
MulTo(w1, w3);
}
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".up_proj.weight"], Data(), w3);
MulTo(w1, w3);
Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".down_proj.weight"], Data(), w2);
AddTo(moePart, w2, value);
}

if (CanRunLinearEx(LinearExType::ExSilu)) {
LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
if (this->mergeSwiglu) {
if (CanRunLinearEx(LinearExType::ExSwiglu)) {
LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gateup_proj.weight"], Data(), w1, LinearExType::ExSwiglu);
} else {
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gateup_proj.weight"], Data(), w3);
Swiglu(w3, w1);
}
} else {
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gate_proj.weight"], Data(), w1);
Silu(w1, w1);
if (CanRunLinearEx(LinearExType::ExSilu)) {
LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
} else {
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gate_proj.weight"], Data(), w1);
Silu(w1, w1);
}
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.up_proj.weight"], Data(), w3);
MulTo(w1, w3);
}
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.up_proj.weight"], Data(), w3);
MulTo(w1, w3);
Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.down_proj.weight"], Data(), w2);
AddTo(moePart, w2);

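The merge itself (the first hunk of each .cpp file) is a byte-level row concatenation: gateup_proj's raw data is w1's rows followed by w3's rows (two memcpy calls), and the per-channel quantization metadata (perChannelsConfigs, zeros, scales, mins) is appended in the same row order via AppendVector. That only stays consistent while metadata is indexed by output row, which appears to be why INT4_GROUP weights whose input dimension is not a whole number of groups force canMerge = false. A simplified sketch, assuming AppendVector is a plain vector concatenation and using a made-up tensor struct:

#include <cstring>
#include <vector>

// Assumption: fastllm's AppendVector concatenates two vectors; sketch only.
template <typename T>
std::vector<T> AppendVector(const std::vector<T> &a, const std::vector<T> &b) {
    std::vector<T> out = a;
    out.insert(out.end(), b.begin(), b.end());
    return out;
}

// Simplified stand-in for a row-major weight with per-row scales (hypothetical struct).
struct TensorSketch {
    int rows, cols;
    std::vector<unsigned char> bytes;   // rows * bytesPerRow
    std::vector<float> scales;          // one entry per output row
};

// Merge gate (w1) and up (w3) into one [w1.rows + w3.rows, cols] tensor, mirroring
// the memcpy + AppendVector sequence in the diff: row r keeps its original metadata.
TensorSketch MergeGateUp(const TensorSketch &w1, const TensorSketch &w3) {
    TensorSketch gu;
    gu.rows = w1.rows + w3.rows;
    gu.cols = w1.cols;                  // both inputs share the same input dimension
    gu.bytes.resize(w1.bytes.size() + w3.bytes.size());
    std::memcpy(gu.bytes.data(), w1.bytes.data(), w1.bytes.size());
    std::memcpy(gu.bytes.data() + w1.bytes.size(), w3.bytes.data(), w3.bytes.size());
    gu.scales = AppendVector(w1.scales, w3.scales);
    return gu;
}

int main() {
    TensorSketch w1{2, 4, std::vector<unsigned char>(8, 1), {0.5f, 0.5f}};
    TensorSketch w3{2, 4, std::vector<unsigned char>(8, 2), {0.25f, 0.25f}};
    TensorSketch gu = MergeGateUp(w1, w3);
    return (gu.rows == 4 && gu.scales.size() == 4) ? 0 : 1;
}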
126 changes: 73 additions & 53 deletions src/models/moe.cpp
@@ -181,40 +181,51 @@ namespace fastllm {
this->mergeQKV = canMerge;
}

if (!mergeSwiglu && CanRunLinearEx(LinearExType::ExSwiglu) && false) {
if (!mergeSwiglu) {
bool canMerge = true;
for (int i = 0; i < block_cnt; i++) {
std::string w1WeightName = "model.layers." + std::to_string(i) + ".mlp.gate_proj.weight";
std::string w3WeightName = "model.layers." + std::to_string(i) + ".mlp.up_proj.weight";
std::string swigluWeightName = "model.layers." + std::to_string(i) + ".mlp.gateup_proj.weight";
for (int j = -1; j < this->num_experts; j++) {
std::string w1WeightName = "model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(j) + ".gate_proj.weight";
std::string w3WeightName = "model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(j) + ".up_proj.weight";
std::string swigluWeightName = "model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(j) + ".gateup_proj.weight";
if (j == -1) {
w1WeightName = "model.layers." + std::to_string(i) + ".mlp.shared_expert.gate_proj.weight";
w3WeightName = "model.layers." + std::to_string(i) + ".mlp.shared_expert.up_proj.weight";
swigluWeightName = "model.layers." + std::to_string(i) + ".mlp.shared_expert.gateup_proj.weight";
}

Data &w1 = weight.weight[w1WeightName], &w3 = weight.weight[w3WeightName];
if ((w1.dataType == DataType::INT4_GROUP && w1.dims[1] % w1.groupCnt != 0) ||
(w3.dataType == DataType::INT4_GROUP && w3.dims[1] % w3.groupCnt != 0)) {
canMerge = false;
break;
}
if (weight.weight.find(w1WeightName) == weight.weight.end()) {
continue;
}

weight.weight[swigluWeightName] = Data(w1.dataType, {w1.dims[0] + w3.dims[0], w1.dims[1]});
Data &swiglu = weight.weight[swigluWeightName];
swiglu.name = swigluWeightName;
swiglu.Allocate();
memcpy(swiglu.cpuData, w1.cpuData, w1.GetBytes());
memcpy(swiglu.cpuData + w1.GetBytes(), w3.cpuData, w3.GetBytes());

swiglu.perChannelAxis = w1.perChannelAxis;
swiglu.group = w1.group;
swiglu.groupCnt = w1.groupCnt;
swiglu.perChannelsConfigs = AppendVector(w1.perChannelsConfigs, w3.perChannelsConfigs);
swiglu.zeros = AppendVector(w1.zeros, w3.zeros);
swiglu.scales = AppendVector(w1.scales, w3.scales);
swiglu.mins = AppendVector(w1.mins, w3.mins);

weight.weight.erase(w1WeightName);
weight.weight.erase(w3WeightName);
}
Data &w1 = weight.weight[w1WeightName], &w3 = weight.weight[w3WeightName];
if ((w1.dataType == DataType::INT4_GROUP && w1.dims[1] % w1.groupCnt != 0) ||
(w3.dataType == DataType::INT4_GROUP && w3.dims[1] % w3.groupCnt != 0)) {
canMerge = false;
break;
}

weight.weight[swigluWeightName] = Data(w1.dataType, {w1.dims[0] + w3.dims[0], w1.dims[1]});
Data &swiglu = weight.weight[swigluWeightName];
swiglu.name = swigluWeightName;
swiglu.Allocate();
memcpy(swiglu.cpuData, w1.cpuData, w1.GetBytes());
memcpy(swiglu.cpuData + w1.GetBytes(), w3.cpuData, w3.GetBytes());

swiglu.perChannelAxis = w1.perChannelAxis;
swiglu.group = w1.group;
swiglu.groupCnt = w1.groupCnt;
swiglu.perChannelsConfigs = AppendVector(w1.perChannelsConfigs, w3.perChannelsConfigs);
swiglu.zeros = AppendVector(w1.zeros, w3.zeros);
swiglu.scales = AppendVector(w1.scales, w3.scales);
swiglu.mins = AppendVector(w1.mins, w3.mins);

weight.weight.erase(w1WeightName);
weight.weight.erase(w3WeightName);
}

this->mergeSwiglu = canMerge;
this->mergeSwiglu = canMerge;
}
}

Data alibiData;
@@ -238,6 +249,7 @@ namespace fastllm {
for (int i = 0; i < block_cnt; i++) {
ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
bool canRunExSilu = CanRunLinearEx(LinearExType::ExSilu);
bool canRunExSwiglu = CanRunLinearEx(LinearExType::ExSwiglu);

RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".input_layernorm.weight"],
rms_norm_eps, attenInput);
@@ -378,18 +390,7 @@ namespace fastllm {
AddTo(hiddenStates, attenLastOutput);
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], rms_norm_eps, attenInput);
// 2. moe mlp
if (this->mergeSwiglu) {
// plain (non-MoE) MLP path
std::string swigluWeightName = "model.layers." + std::to_string(i) + ".mlp.gateup_proj.weight";
if (CanRunLinearEx(LinearExType::ExSwiglu)) {
LinearEx(attenInput, weight[swigluWeightName], Data(), w1, LinearExType::ExSwiglu);
} else {
Linear(attenInput, weight[swigluWeightName], Data(), w3);
Swiglu(w3, w1);
}
Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
AddTo(hiddenStates, w2);
} else {
{
// MoE MLP path
std::string gateWeightName = "model.layers." + std::to_string(i) + ".mlp.gate.weight";
int batch = attenInput.dims[0], len = attenInput.dims[1];
@@ -398,6 +399,7 @@ namespace fastllm {
Softmax(routerLogits, routerLogits, -1);
TopK(routerLogits, gate, this->num_experts_per_tok);
if (batch * len > 1) {
moeFinal = Data();
moeFinal.Resize({0, attenInput.dims[1]});
moeFinal.Expansion(attenInput.dims);
}
@@ -416,26 +418,44 @@ namespace fastllm {
for (int j = 0; j < this->num_experts_per_tok; j++) {
int idx = (int)(gateData[(b * this->num_experts_per_tok + j) * 2] + 1e-1);
float value = gateData[(b * this->num_experts_per_tok + j) * 2 + 1];
if (canRunExSilu) {
LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
if (this->mergeSwiglu) {
if (canRunExSwiglu) {
LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gateup_proj.weight"], Data(), w1, LinearExType::ExSwiglu);
} else {
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gateup_proj.weight"], Data(), w3);
Swiglu(w3, w1);
}
} else {
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1);
Silu(w1, w1);
if (canRunExSilu) {
LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
} else {
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1);
Silu(w1, w1);
}
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".up_proj.weight"], Data(), w3);
MulTo(w1, w3);
}
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".up_proj.weight"], Data(), w3);
MulTo(w1, w3);
Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".down_proj.weight"], Data(), w2);
AddTo(moePart, w2, value);
}

if (canRunExSilu) {
LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_expert.gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
if (this->mergeSwiglu) {
if (canRunExSwiglu) {
LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_expert.gateup_proj.weight"], Data(), w1, LinearExType::ExSwiglu);
} else {
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_expert.gateup_proj.weight"], Data(), w3);
Swiglu(w3, w1);
}
} else {
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_expert.gate_proj.weight"], Data(), w1);
Silu(w1, w1);
if (canRunExSilu) {
LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_expert.gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
} else {
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_expert.gate_proj.weight"], Data(), w1);
Silu(w1, w1);
}
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_expert.up_proj.weight"], Data(), w3);
MulTo(w1, w3);
}
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_expert.up_proj.weight"], Data(), w3);
MulTo(w1, w3);
Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.shared_expert.down_proj.weight"], Data(), w2);
Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_expert_gate.weight"], Data(), sharedGate);
sharedGate.ToDevice(DataDevice::CPU);

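The routing loop above (and its deepseekv2.cpp counterpart) reads the TopK result as flat (expert index, routing weight) pairs per token, runs the selected expert's MLP, and accumulates each expert's down_proj output into moePart scaled by its weight via AddTo(moePart, w2, value). A small standalone sketch of that gate layout, inferred from the indexing gateData[(b * num_experts_per_tok + j) * 2]; the helper and struct names are hypothetical:

#include <cstdio>
#include <vector>

struct ExpertChoice { int idx; float weight; };

// Assumed layout: for each token, num_experts_per_tok pairs of
// {expert index stored as a float, routing weight}.
std::vector<ExpertChoice> ReadGate(const float *gateData, int token, int numExpertsPerTok) {
    std::vector<ExpertChoice> picks(numExpertsPerTok);
    for (int j = 0; j < numExpertsPerTok; j++) {
        const float *p = gateData + (token * numExpertsPerTok + j) * 2;
        picks[j].idx = (int)(p[0] + 1e-1);   // index stored as float; +0.1 guards the cast
        picks[j].weight = p[1];
    }
    return picks;
}

int main() {
    // Two tokens, two experts per token: token 0 routes to experts 3 (0.7) and 5 (0.3), etc.
    float gateData[] = {3, 0.7f, 5, 0.3f, 1, 0.6f, 2, 0.4f};
    for (int b = 0; b < 2; b++) {
        for (const ExpertChoice &c : ReadGate(gateData, b, 2)) {
            std::printf("token %d -> expert %d * %.2f\n", b, c.idx, c.weight);
        }
    }
    return 0;
}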