Skip to content

Commit

Permalink
moe llama deepseekv2使用统一的fillinput
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed May 18, 2024
1 parent e2d9388 commit 5a2cd63
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 109 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ token
/cmake-build-debug/
/build-tfacc/
/build-android/
/build-x86/
/build-py/
/build/
/pyfastllm/build/
Expand Down
5 changes: 0 additions & 5 deletions include/models/deepseekv2.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,6 @@ namespace fastllm {
const std::vector <GenerationConfig> &generationConfigs,
const LastTokensManager &lastTokens = LastTokensManager(),
std::vector <std::vector <float>*> *logits = nullptr);

// 根据输入的tokens生成LLM推理的输入
virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
const std::map <std::string, int> &params,
Data &inputIds, Data &attentionMask, Data &positionIds);

// 根据输入的tokens生成LLM推理的输入
virtual void FillLLMInputsBatch(std::vector <std::vector <float> > &inputTokens,
Expand Down
5 changes: 0 additions & 5 deletions include/models/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,6 @@ namespace fastllm {
const std::vector <GenerationConfig> &generationConfigs,
const LastTokensManager &lastTokens = LastTokensManager(),
std::vector <std::vector <float>*> *logits = nullptr);

// 根据输入的tokens生成LLM推理的输入
virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
const std::map <std::string, int> &params,
Data &inputIds, Data &attentionMask, Data &positionIds);

// 根据输入的tokens生成LLM推理的输入
virtual void FillLLMInputsBatch(std::vector <std::vector <float> > &inputTokens,
Expand Down
5 changes: 0 additions & 5 deletions include/models/moe.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ namespace fastllm {
const LastTokensManager &lastTokens = LastTokensManager(),
std::vector <std::vector <float>*> *logits = nullptr);

// 根据输入的tokens生成LLM推理的输入
virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
const std::map <std::string, int> &params,
Data &inputIds, Data &attentionMask, Data &positionIds);

// 根据输入的tokens生成LLM推理的输入
virtual void FillLLMInputsBatch(std::vector <std::vector <float> > &inputTokens,
const std::vector <std::map <std::string, int> > &params,
Expand Down
26 changes: 26 additions & 0 deletions src/models/basellm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,32 @@ printf("tot = %d\n", tot);
void basellm::FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
                            const std::map <std::string, int> &params,
                            Data &inputIds, Data &attentionMask, Data &positionIds) {
    // Build the single-sequence inference inputs (token ids, causal attention
    // mask, position ids) from the raw tokens. All three tensors are prepared
    // on the CPU; the caller moves them to the compute device later.
    inputIds.ToDevice(DataDevice::CPU);
    attentionMask.ToDevice(DataDevice::CPU);
    positionIds.ToDevice(DataDevice::CPU);

    // "index" is the decoding step (0 == prefill); "promptLen" is the length
    // of the original prompt. Both keys are assumed present — TODO confirm callers.
    const int index = params.find("index")->second;
    const int promptLen = params.find("promptLen")->second;

    if (index != 0) {
        // Decode step: a single new token attends to everything cached, so the
        // mask is left empty and the position is promptLen + index - 1.
        inputIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, inputTokens[0]));
        attentionMask = Data();
        positionIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, {(float) promptLen + index - 1}));
        return;
    }

    // Prefill: build a seqLen x seqLen causal mask where entry (i, j) == 1
    // marks a masked (future) position j > i, and positions run 0..seqLen-1.
    const int seqLen = (int) inputTokens[0].size();
    std::vector <float> causalMask(seqLen * seqLen, 0.0f);
    std::vector <float> positions(seqLen, 0.0f);
    for (int row = 0; row < seqLen; row++) {
        positions[row] = (float) row;
        for (int col = 0; col < seqLen; col++) {
            causalMask[row * seqLen + col] = (col > row) ? 1.0f : 0.0f;
        }
    }
    inputIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, inputTokens[0]));
    attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, seqLen}, causalMask));
    positionIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, positions));
}

// 根据输入的tokens生成LLM推理的输入
Expand Down
31 changes: 0 additions & 31 deletions src/models/deepseekv2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,37 +773,6 @@ namespace fastllm {
return lastRet;
}

void DeepSeekV2Model::FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
                                    const std::map <std::string, int> &params,
                                    Data &inputIds, Data &attentionMask, Data &positionIds) {
    // Prepare the model inputs (ids, causal mask, position ids) for one
    // sequence, staged on the CPU.
    inputIds.ToDevice(DataDevice::CPU);
    attentionMask.ToDevice(DataDevice::CPU);
    positionIds.ToDevice(DataDevice::CPU);

    // index == 0 means prefill over the whole prompt; otherwise we are
    // decoding token number (promptLen + index - 1).
    const int step = params.find("index")->second;
    const int promptLen = params.find("promptLen")->second;

    if (step == 0) {
        const int len = (int) inputTokens[0].size();

        // Causal mask: mask[i][j] == 1 for future positions (j > i).
        std::vector <float> mask(len * len, 0.0f);
        std::vector <float> pids(len, 0.0f);
        for (int i = 0; i < len; i++) {
            pids[i] = (float) i;
            for (int j = 0; j < len; j++) {
                mask[i * len + j] = (j > i) ? 1.0f : 0.0f;
            }
        }
        inputIds.CopyFrom(Data(DataType::FLOAT32, {1, len}, inputTokens[0]));
        attentionMask.CopyFrom(Data(DataType::FLOAT32, {len, len}, mask));
        positionIds.CopyFrom(Data(DataType::FLOAT32, {1, len}, pids));
    } else {
        // Single-token decode: no mask required, position follows the cache.
        inputIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, inputTokens[0]));
        attentionMask = Data();
        positionIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, {(float) promptLen + step - 1}));
    }
}

void DeepSeekV2Model::FillLLMInputsBatch(std::vector<std::vector<float>> &inputTokens,
const std::vector<std::map<std::string, int>> &params,
fastllm::Data &inputIds, fastllm::Data &attentionMask,
Expand Down
33 changes: 1 addition & 32 deletions src/models/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ namespace fastllm {
this->mergeQKV = canMerge;
}

if (!mergeSwiglu && CanRunLinearEx(LinearExType::ExSwiglu)) {
if (!mergeSwiglu) {
bool canMerge = true;
for (int i = 0; i < block_cnt; i++) {
std::string w1WeightName = "model.layers." + std::to_string(i) + ".mlp.gate_proj.weight";
Expand Down Expand Up @@ -706,37 +706,6 @@ namespace fastllm {
return lastRet;
}

void LlamaModel::FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
                               const std::map <std::string, int> &params,
                               Data &inputIds, Data &attentionMask, Data &positionIds) {
    // Turn the input tokens into the three inference tensors: token ids,
    // causal attention mask and position ids. Everything is built on the CPU.
    inputIds.ToDevice(DataDevice::CPU);
    attentionMask.ToDevice(DataDevice::CPU);
    positionIds.ToDevice(DataDevice::CPU);

    const int curIndex = params.find("index")->second;
    const int promptLen = params.find("promptLen")->second;

    if (curIndex == 0) {
        // Prefill pass over the full prompt.
        const int len = (int) inputTokens[0].size();
        std::vector <float> causalMask(len * len, 0.0f);
        std::vector <float> positions(len, 0.0f);
        for (int row = 0; row < len; row++) {
            positions[row] = (float) row;
            // Mark every column strictly after the diagonal as masked (1).
            int col = row + 1;
            while (col < len) {
                causalMask[row * len + col] = 1.0f;
                col++;
            }
        }
        inputIds.CopyFrom(Data(DataType::FLOAT32, {1, len}, inputTokens[0]));
        attentionMask.CopyFrom(Data(DataType::FLOAT32, {len, len}, causalMask));
        positionIds.CopyFrom(Data(DataType::FLOAT32, {1, len}, positions));
    } else {
        // Incremental decode: one token, empty mask, next position id.
        inputIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, inputTokens[0]));
        attentionMask = Data();
        positionIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, {(float) promptLen + curIndex - 1}));
    }
}

void LlamaModel::FillLLMInputsBatch(std::vector<std::vector<float>> &inputTokens,
const std::vector<std::map<std::string, int>> &params,
fastllm::Data &inputIds, fastllm::Data &attentionMask,
Expand Down
31 changes: 0 additions & 31 deletions src/models/moe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -768,37 +768,6 @@ namespace fastllm {
return lastRet;
}

void MoeModel::FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
                             const std::map <std::string, int> &params,
                             Data &inputIds, Data &attentionMask, Data &positionIds) {
    // Assemble token ids, causal attention mask and position ids for a single
    // sequence. Tensors are staged on the CPU for the caller to dispatch.
    inputIds.ToDevice(DataDevice::CPU);
    attentionMask.ToDevice(DataDevice::CPU);
    positionIds.ToDevice(DataDevice::CPU);

    const int decodeIndex = params.find("index")->second;
    const int promptLen = params.find("promptLen")->second;

    if (decodeIndex != 0) {
        // Decoding one token: empty mask, position = promptLen + index - 1.
        inputIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, inputTokens[0]));
        attentionMask = Data();
        positionIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, {(float) promptLen + decodeIndex - 1}));
        return;
    }

    // Prefill: walk the flattened seqLen x seqLen mask once, setting 1 on
    // every cell above the diagonal (future positions are masked out).
    const int seqLen = (int) inputTokens[0].size();
    const int total = seqLen * seqLen;
    std::vector <float> mask(total, 0.0f);
    std::vector <float> pids(seqLen, 0.0f);
    for (int pos = 0; pos < total; pos++) {
        const int row = pos / seqLen;
        const int col = pos % seqLen;
        if (col > row) {
            mask[pos] = 1.0f;
        }
    }
    for (int i = 0; i < seqLen; i++) {
        pids[i] = (float) i;
    }
    inputIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, inputTokens[0]));
    attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, seqLen}, mask));
    positionIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, pids));
}

void MoeModel::FillLLMInputsBatch(std::vector<std::vector<float>> &inputTokens,
const std::vector<std::map<std::string, int>> &params,
fastllm::Data &inputIds, fastllm::Data &attentionMask,
Expand Down

0 comments on commit 5a2cd63

Please sign in to comment.