From b1e6c8e9459cad97ab596d2afae29e18e5fbcc1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Thu, 27 Jun 2024 09:00:22 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A1=A5=E4=B8=80=E4=BA=9B=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/graph.h | 2 ++ src/graph.cpp | 32 +++++++++++++++++++++++++++++++- src/models/basellm.cpp | 3 ++- src/models/graphllm.cpp | 13 +++++++++++-- 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/include/graph.h b/include/graph.h index 27bf717b..f3d3a04f 100644 --- a/include/graph.h +++ b/include/graph.h @@ -52,6 +52,7 @@ namespace fastllm { void Update(); void AddTo(ComputeGraphNode &input0, ComputeGraphNode &input1, float alpha = 1.0); // input0 += input1 * alpha + void DataTypeAs(ComputeGraphNode &input, ComputeGraphNode &input1); // make the dataType of input match input1 (only FLOAT32 / FLOAT16 targets are converted) void Embedding(ComputeGraphNode &input, ComputeGraphNode &weight, ComputeGraphNode &output); void ExpandHead(ComputeGraphNode &input, int headDim); void FusedAttention(ComputeGraphNode &q, ComputeGraphNode &k, ComputeGraphNode &v, @@ -65,6 +66,7 @@ namespace fastllm { void Silu(ComputeGraphNode &input, ComputeGraphNode &output); void Split(ComputeGraphNode &input, int axis, int start, int end, ComputeGraphNode &output); void SplitLastTokenStates(ComputeGraphNode &input, ComputeGraphNode &output); + void Swiglu(ComputeGraphNode &input, ComputeGraphNode &output); // 以下op用于调试 void Exit(); // 退出 diff --git a/src/graph.cpp b/src/graph.cpp index 1862afee..e6cba90c 100644 --- a/src/graph.cpp +++ b/src/graph.cpp @@ -36,6 +36,20 @@ namespace fastllm { auto data = allDatas[op.datas.find("input")->second]; data->ToDevice(DataDevice::CPU); data->Print(); + } else if (op.type == "DataTypeAs") { // run ToFloat32 / ToFloat16 on input when its dataType differs from input1; other target types are left unchanged + auto input = allDatas[op.datas.find("input")->second]; + DataType dataType = allDatas[op.datas.find("input1")->second]->dataType; + if (input->dataType != dataType) { + if (dataType == 
DataType::FLOAT32) { + excutor.Run("ToFloat32", { + {"input", input} + }, {}, {}); + } else if (dataType == DataType::FLOAT16) { + excutor.Run("ToFloat16", { + {"input", input} + }, {}, {}); + } + } } else if (op.type == "ExpandHeads") { auto data = allDatas[op.datas.find("input")->second]; int headDim = op.intParams.find("headDim")->second; @@ -207,6 +221,14 @@ namespace fastllm { ); } + void ComputeGraph::DataTypeAs(ComputeGraphNode &input, ComputeGraphNode &input1) { + this->ops.push_back ( + ComputeGraphOp("DataTypeAs", + {{"input", input.name}, {"input1", input1.name}}, + {}, {}) + ); + } + void ComputeGraph::MulTo(ComputeGraphNode &input0, ComputeGraphNode &input1) { this->ops.push_back ( ComputeGraphOp("MulTo", @@ -218,7 +240,15 @@ namespace fastllm { void ComputeGraph::Silu(ComputeGraphNode &input, ComputeGraphNode &output) { this->ops.push_back ( ComputeGraphOp("Silu", - {{"input", "w1"}, {"output", "w1"}}, + {{"input", input.name}, {"output", output.name}}, + {}, {}) + ); + } + + void ComputeGraph::Swiglu(ComputeGraphNode &input, ComputeGraphNode &output) { + this->ops.push_back ( + ComputeGraphOp("Swiglu", + {{"input", input.name}, {"output", output.name}}, {}, {}) ); } diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp index d8070e34..d2986045 100644 --- a/src/models/basellm.cpp +++ b/src/models/basellm.cpp @@ -896,7 +896,8 @@ printf("len = %d, spend = %f s. 
tokens / s = %f\n", (int)total, spend, (float)to } else if (dataType == DataType::FLOAT16) { AssertInFastLLM(this->model_struct == "chatglm" || - this->model_struct == "llama", + this->model_struct == "llama" || + this->model_struct == "graph", this->model_struct + " doesn't support float16"); } else { ErrorInFastLLM("SetDataType Error: datatype should be float32 or float16"); diff --git a/src/models/graphllm.cpp b/src/models/graphllm.cpp index 1166af92..1f31a5f6 100644 --- a/src/models/graphllm.cpp +++ b/src/models/graphllm.cpp @@ -114,10 +114,12 @@ namespace fastllm { for (auto &it : weight.weight) { weightDicts[it.first] = &it.second; } + Data atype = Data(this->dataType); std::map inputs = { {"inputIds", (Data*)&inputIds}, {"positionIds", (Data*)&positionIds}, {"attentionMask", (Data*)&attentionMask}, + {"atype", (Data*)&atype}, {"sin", &sinData}, {"cos", &cosData} }; for (int i = 0; i < block_cnt; i++) { @@ -250,10 +252,11 @@ namespace fastllm { for (auto &it : model->weight.weight) { wNodes[it.first] = ComputeGraphNode(it.first); } - ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), sin("sin"), cos("cos"); + ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), atype("atype"), sin("sin"), cos("cos"); ComputeGraphNode hiddenStates("hiddenStates"), attenInput("attenInput"), attenOutput("attenOutput"), attenLastOutput("attenLastOutput"); ComputeGraphNode q("q"), k("k"), v("v"), w1("w1"), w2("w2"), w3("w3"), lastTokensStates("lastTokensStates"), logits("logits"); graph.Embedding(inputIds, wNodes["model.embed_tokens.weight"], hiddenStates); + graph.DataTypeAs(hiddenStates, atype); for (int i = 0; i < model->block_cnt; i++) { std::string pre = "model.layers." 
+ std::to_string(i); ComputeGraphNode pastKey("pastKey_" + std::to_string(i)), pastValue("pastValue_" + std::to_string(i)); @@ -289,6 +292,11 @@ namespace fastllm { model->max_positions = atoi(model->weight.dicts["seq_length"].c_str()); model->rope_base = 10000 * pow(3, ((float)model->rotary_dim / (model->rotary_dim - 2))); model->rope_factor = 1.0; + + model->pre_prompt = ""; + model->user_role = "<_user>"; + model->bot_role = "<_bot>"; + model->history_sep = ""; } std::map > > @@ -331,10 +339,11 @@ namespace fastllm { for (auto &it : model->weight.weight) { wNodes[it.first] = ComputeGraphNode(it.first); } - ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), atype("atype"), sin("sin"), cos("cos"); + ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), atype("atype"), sin("sin"), cos("cos"); ComputeGraphNode hiddenStates("hiddenStates"), attenInput("attenInput"), attenOutput("attenOutput"), attenLastOutput("attenLastOutput"); ComputeGraphNode q("q"), kv("kv"), k("k"), v("v"), w1("w1"), w2("w2"), w3("w3"), lastTokensStates("lastTokensStates"), logits("logits"); graph.Embedding(inputIds, wNodes["transformer.word_embeddings.weight"], hiddenStates); + graph.DataTypeAs(hiddenStates, atype); for (int i = 0; i < model->block_cnt; i++) { std::string pre = "transformer.h." + std::to_string(i); ComputeGraphNode pastKey("pastKey_" + std::to_string(i)), pastValue("pastValue_" + std::to_string(i));