Skip to content

Commit

Permalink
Merge pull request #331 from fluxlinkage/master
Browse files Browse the repository at this point in the history
增加基本的GLM模型(例如glm-large-chinese)支持
  • Loading branch information
ztxz16 authored Sep 27, 2023
2 parents e9b2f90 + 06987b2 commit 80a3917
Show file tree
Hide file tree
Showing 8 changed files with 691 additions and 7 deletions.
12 changes: 11 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@ option(USE_CUDA "use cuda" OFF)
option(PY_API "python api" OFF)
option(USE_MMAP "use mmap" OFF)

option(USE_SENTENCEPIECE "use sentencepiece" OFF)

message(STATUS "USE_CUDA: ${USE_CUDA}")

message(STATUS "PYTHON_API: ${PY_API}")

message(STATUS "USE_SENTENCEPIECE: ${USE_SENTENCEPIECE}")

set(CMAKE_BUILD_TYPE "Release")

if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
Expand All @@ -25,7 +29,7 @@ endif()
message(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS})
set(FASTLLM_CXX_SOURCES src/fastllm.cpp src/device.cpp src/model.cpp src/executor.cpp
src/devices/cpu/cpudevice.cpp src/devices/cpu/cpudevicebatch.cpp
src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp)
src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp src/models/glm.cpp)

include_directories(include)
include_directories(include/utils)
Expand All @@ -35,6 +39,12 @@ if (USE_MMAP)
add_compile_definitions(USE_MMAP)
endif()

if (USE_SENTENCEPIECE)
set(CMAKE_CXX_STANDARD 17)
add_compile_definitions(USE_SENTENCEPIECE)
set(FASTLLM_LINKED_LIBS ${FASTLLM_LINKED_LIBS} sentencepiece)
endif()

if (USE_CUDA)
enable_language(CUDA)
add_compile_definitions(USE_CUDA)
Expand Down
10 changes: 9 additions & 1 deletion include/fastllm.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
#include <memory>
#include "devices/cpu/cputhreadpool.h"

#ifdef USE_SENTENCEPIECE
#include <sentencepiece_processor.h>
#endif

namespace fastllm {
void SetDeviceMap(const std::map <std::string, int> &deviceMap);
std::map <std::string, int> GetDeviceMap();
Expand Down Expand Up @@ -308,7 +312,8 @@ namespace fastllm {
enum TokenizerType {
BPE = 0,
NORMAL = 1,
QWEN = 2
QWEN = 2,
GLM = 3
};

struct TrieNode {
Expand Down Expand Up @@ -359,6 +364,9 @@ namespace fastllm {
std::unordered_map <int, std::string> tokenToStringDict;
std::unordered_map <int, float> tokenToScoreDict;
std::unordered_map <std::string, int> stringToTokenDict;
#ifdef USE_SENTENCEPIECE
std::unique_ptr<sentencepiece::SentencePieceProcessor> spProcessor;
#endif

Tokenizer ();

Expand Down
62 changes: 62 additions & 0 deletions include/models/glm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//
// Created by huangyuyang on 5/11/23.
//

#ifndef FASTLLM_GLM_H
#define FASTLLM_GLM_H

#include "basellm.h"
// Standard header: use the angle-bracket form, not the quoted form,
// so the preprocessor does not search project directories first.
#include <cmath>

#include <iostream>

namespace fastllm {
    // GLM model (e.g. glm-large-chinese), implemented on top of the shared
    // basellm interface.
    class GLMModel: public basellm {
    public:
        GLMModel (); // Constructor

        // Inference: one forward step for a single sequence.
        // pastKeyValues carries the per-layer key/value cache across steps;
        // if `logits` is non-null the raw logits are written into it.
        // Returns the next token id.
        virtual int Forward(
                const Data &inputIds,
                const Data &attentionMask,
                const Data &positionIds,
                std::vector <std::pair <Data, Data> > &pastKeyValues,
                const GenerationConfig &generationConfig = GenerationConfig(),
                const LastTokensManager &lastTokens = LastTokensManager(),
                std::vector <float> *logits = nullptr);

        // Batched inference: same as Forward but for `batch` sequences at
        // once; returns one next-token id per sequence.
        std::vector <int> ForwardBatch(
                int batch,
                const Data &inputIds,
                const Data &attentionMask,
                const Data &positionIds,
                std::vector <std::pair <Data, Data> > &pastKeyValues,
                const GenerationConfig &generationConfig = GenerationConfig(),
                const LastTokensManager &lastTokens = LastTokensManager(),
                std::vector <std::vector <float>*> *retLogits = nullptr);

        // Builds the LLM inference inputs (input ids, attention mask,
        // position ids) from the tokens generated so far.
        virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
                                   const std::map <std::string, int> &params,
                                   Data &inputIds, Data &attentionMask, Data &positionIds);

        virtual void InitParams(); // Initializes model hyper-parameters
        virtual void WarmUp();     // Warm-up (dummy forward pass)

        // Builds the prompt from the chat history and the current input.
        virtual std::string MakeInput(const std::string &history, int round, const std::string &input);

        // Updates the history string with the current round's reply.
        virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output);

    private:
        // Attention score scaling factor; presumably 1/scale applied to
        // attention logits — set in glm.cpp, confirm there.
        float scale_attn_1;

        // Special token ids of the GLM vocabulary.
        static constexpr int eot_token_id = 50000;   // <|endoftext|>
        static constexpr int cls_token_id = 50002;   // [CLS]
        static constexpr int mask_token_id = 50003;  // [MASK]
        static constexpr int smask_token_id = 50008; // [sMASK]
        static constexpr int gmask_token_id = 50009; // [gMASK]
    };
}

#endif //FASTLLM_GLM_H
128 changes: 127 additions & 1 deletion src/fastllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,132 @@ namespace fastllm {
}
}
return Data (DataType::FLOAT32, {1, (int)v.size()}, v);
} else if (this->type == TokenizerType::GLM) {
const std::map<std::string, int> specialTokens = {{"[MASK]", 50003}, {"[sMASK]", 50008}, {"[gMASK]", 50009}};
std::string blank = "";
blank += 226, blank += 150, blank += 129;
std::string s = blank;
for (int i = 0; i < ori.size(); i++) {
if (ori[i] == ' ') {
if (i != 0 && ori[i - 1] != ' ') {
s += blank;
}
} else {
s += ori[i];
}
}
std::vector<float> v;
int findPos=0;
while(findPos<s.length()){
int nextSpecialToken=-1;
int nextSpecialTokenPos=-1;
int nextSpecialTokenLen=-1;
for(auto p:specialTokens){
int ind=s.find(p.first,findPos);
if(ind>=0&&(nextSpecialTokenPos<0||ind<nextSpecialTokenPos)){
nextSpecialTokenPos=ind;
nextSpecialToken=p.second;
nextSpecialTokenLen=p.first.length();
}
}
std::string subStr;
if(nextSpecialTokenPos<0){
subStr=s.substr(findPos);
findPos=s.length();
}else{
subStr=s.substr(findPos,nextSpecialTokenPos-findPos);
findPos=nextSpecialTokenPos+nextSpecialTokenLen;
}
if(subStr.length()>0){
#ifdef USE_SENTENCEPIECE
if(spProcessor!=nullptr){
std::vector<int> ids;
spProcessor->Encode(subStr,&ids);
for(int id:ids){
v.push_back(id);
}
}else{
#endif
std::vector<Symbol> symbols;
for (int i = 0; i < subStr.size(); i++) {
int tokenId = -999999, pos = i - 1;
TrieNode *now = this->root;
for (int j = i; j < subStr.size(); j++) {
if (now->next.find(subStr[j]) != now->next.end()) {
now = now->next[subStr[j]];
if (now->tokenId != -999999) {
tokenId = now->tokenId;
pos = j;
break;
}
} else {
break;
}
}
if (pos >= i) {
symbols.push_back(Symbol(now, (char *) subStr.data(), i, pos - i + 1, (int) symbols.size() - 1,
(int) symbols.size() + 1, -999999));
i = pos;
} else {
symbols.push_back(Symbol(nullptr, (char *) subStr.data(), i, 0, (int) symbols.size() - 1,
(int) symbols.size() + 1, -999999));
}
}
symbols.back().next = -1;

std::priority_queue<SymbolPairs> workQueue;
for (int i = 1; i < symbols.size(); i++) {
TryMergePairs(symbols, i - 1, i, workQueue);
}

while (!workQueue.empty()) {
auto top = workQueue.top();
workQueue.pop();
if (symbols[top.l].len == 0 || symbols[top.r].len == 0 ||
symbols[top.l].len + symbols[top.r].len != top.size) {
continue;
}

for (int i = symbols[top.r].pos; i < symbols[top.r].pos + symbols[top.r].len; i++) {
symbols[top.l].node = symbols[top.l].node->next[symbols[top.r].s[i]];
}
symbols[top.l].len += symbols[top.r].len;
symbols[top.r].len = 0;
symbols[top.l].next = symbols[top.r].next;
if (symbols[top.r].next >= 0) {
symbols[symbols[top.r].next].prev = top.l;
}

TryMergePairs(symbols, symbols[top.l].prev, top.l, workQueue);
TryMergePairs(symbols, top.l, symbols[top.l].next, workQueue);
}
for (int i = 0; i < symbols.size(); i++) {
if (symbols[i].len > 0) {
v.push_back(symbols[i].node->tokenId);
} else if (symbols[i].node == nullptr) {
if (symbols[i].fixId != -999999) {
v.push_back(symbols[i].fixId);
} else {
// 未识别的字符
uint8_t c = (uint8_t) (symbols[i].s[symbols[i].pos]);
std::string now = "<0x00>";
now[3] = (c / 16 > 9 ? ('A' + c / 16 - 10) : ('0' + c / 16));
now[4] = (c % 16 > 9 ? ('A' + c % 16 - 10) : ('0' + c % 16));
if (stringToTokenDict.find(now) != stringToTokenDict.end()) {
v.push_back(stringToTokenDict[now]);
}
}
}
}
#ifdef USE_SENTENCEPIECE
}
#endif
}
if(nextSpecialTokenPos>=0){
v.push_back(nextSpecialToken);
}
}
return Data (DataType::FLOAT32, {1, (int)v.size()}, v);
} else if (this->type == TokenizerType::QWEN) {
std::map<std::string, int> specialTokens = {{"<|im_start|>", 151644}, {"<|im_end|>", 151645}, {"<|endoftext|>", 151643}};

Expand Down Expand Up @@ -1960,4 +2086,4 @@ namespace fastllm {
std::map <std::string, int> GetDeviceMap() {
return defaultDeviceMap;
}
}
}
15 changes: 12 additions & 3 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "moss.h"
#include "llama.h"
#include "qwen.h"
#include "glm.h"

namespace fastllm {
void basellm::LoadFromFile(const std::string &fileName) {
Expand All @@ -16,15 +17,21 @@ namespace fastllm {

void basellm::InitParams() {
if (this->weight.dicts.find("bos_token_id") != this->weight.dicts.end()) {
this->bos_token_id = atoi(this->weight.dicts["bos_token_id"].c_str());
this->eos_token_id = atoi(this->weight.dicts["eos_token_id"].c_str());
if(this->weight.dicts["bos_token_id"]!="None"){
this->bos_token_id = atoi(this->weight.dicts["bos_token_id"].c_str());
}
if(this->weight.dicts["eos_token_id"]!="None"){
this->eos_token_id = atoi(this->weight.dicts["eos_token_id"].c_str());
}
}
if (this->weight.dicts.find("im_start_id") != this->weight.dicts.end()) {
this->bos_token_id = atoi(this->weight.dicts["im_start_id"].c_str());
this->eos_token_id = atoi(this->weight.dicts["im_end_id"].c_str());
}
if (this->weight.dicts.find("num_hidden_layers") != this->weight.dicts.end()) {
block_cnt = atoi(this->weight.dicts["num_hidden_layers"].c_str());
}else if (this->weight.dicts.find("num_layers") != this->weight.dicts.end()) {
block_cnt = atoi(this->weight.dicts["num_layers"].c_str());
}
if (this->weight.dicts.find("hidden_size") != this->weight.dicts.end()) {
embed_dim = atoi(this->weight.dicts["hidden_size"].c_str());
Expand Down Expand Up @@ -77,6 +84,8 @@ namespace fastllm {
} else if (modelType == "qwen") {
model = (basellm *) (new QWenModel());
model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
} else if (modelType == "glm") {
model = (basellm*)(new GLMModel());
} else {
ErrorInFastLLM("Unkown model type: " + modelType);
}
Expand All @@ -95,4 +104,4 @@ namespace fastllm {
basellm *model = CreateModelWithType(modelType);
return std::unique_ptr<fastllm::basellm> (model);
}
}
}
2 changes: 1 addition & 1 deletion src/models/basellm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -777,4 +777,4 @@ printf("tot = %d\n", tot);
void basellm::DisableAdapter() {
adapterName = "";
}
}
}
Loading

0 comments on commit 80a3917

Please sign in to comment.