From 37853af5c7b9db4b3eb011752c0fa15d3c9d4930 Mon Sep 17 00:00:00 2001
From: 黄宇扬
Date: Tue, 28 May 2024 18:25:22 +0800
Subject: [PATCH] llama: support setting intermediate results to float16
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/devices/cpu/cpudevice.cpp | 58 +++++++++++++++++++++++------------
 src/models/basellm.cpp        | 11 +++++--
 src/models/llama.cpp          |  9 ++++--
 3 files changed, 54 insertions(+), 24 deletions(-)

diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp
index f88572b5..df96da52 100644
--- a/src/devices/cpu/cpudevice.cpp
+++ b/src/devices/cpu/cpudevice.cpp
@@ -3691,42 +3691,62 @@ namespace fastllm {
     }

     struct MultiThreadLlamaRotatePosition2DFloatOp : MultiThreadBaseOp {
+        DataType dataType;
         float *data, *positionIds, *sinData, *cosData;
         int bs, len, n, m, stride, spatial, posDim, rotaryDim;
         int st, end;

         MultiThreadLlamaRotatePosition2DFloatOp
-                (float *data, float *positionIds, float *sinData, float *cosData,
+                (DataType dataType, float *data, float *positionIds, float *sinData, float *cosData,
                  int bs, int len, int n, int m, int stride, int spatial, int posDim, int rotaryDim,
                  int st, int end) :
-                data(data), positionIds(positionIds), sinData(sinData), cosData(cosData),
+                dataType(dataType), data(data), positionIds(positionIds), sinData(sinData), cosData(cosData),
                 bs(bs), len(len), n(n), m(m), stride(stride), spatial(spatial), posDim(posDim), rotaryDim(rotaryDim),
                 st(st), end(end) {}

         void Run() {
-            for (int idx = st; idx < end; idx++) {
-                int b = idx / len;
-                int l = idx % len;
-                int index = (int) ((float *) positionIds)[b * posDim + l];
-                float *sin = ((float *) sinData) + stride * index;
-                float *cos = ((float *) cosData) + stride * index;
-                float *d = (float *) data + (b * len + l) * spatial;
-                for (int i = 0; i < n; i++) {
-                    for (int j = 0; j < rotaryDim && j < m / 2; j++) {
-                        float a = d[j], b = d[j + m / 2];
-                        d[j] = a * cos[j] - b * sin[j];
-                        d[j + m / 2] = a * sin[j] + b * cos[j];
+            if (dataType == DataType::FLOAT32) {
+                for (int idx = st; idx < end; idx++) {
+                    int b = idx / len;
+                    int l = idx % len;
+                    int index = (int) ((float *) positionIds)[b * posDim + l];
+                    float *sin = ((float *) sinData) + stride * index;
+                    float *cos = ((float *) cosData) + stride * index;
+                    float *d = (float *) data + (b * len + l) * spatial;
+                    for (int i = 0; i < n; i++) {
+                        for (int j = 0; j < rotaryDim && j < m / 2; j++) {
+                            float a = d[j], b = d[j + m / 2];
+                            d[j] = a * cos[j] - b * sin[j];
+                            d[j + m / 2] = a * sin[j] + b * cos[j];
+                        }
+                        d += m;
+                    }
+                }
+            } else {
+                for (int idx = st; idx < end; idx++) {
+                    int b = idx / len;
+                    int l = idx % len;
+                    int index = (int) ((float *) positionIds)[b * posDim + l];
+                    float *sin = ((float *) sinData) + stride * index;
+                    float *cos = ((float *) cosData) + stride * index;
+                    uint16_t *d = (uint16_t *) data + (b * len + l) * spatial;
+                    for (int i = 0; i < n; i++) {
+                        for (int j = 0; j < rotaryDim && j < m / 2; j++) {
+                            float a = fp16tofp32.dict[d[j]], b = fp16tofp32.dict[d[j + m / 2]];
+                            d[j] = float_to_half(a * cos[j] - b * sin[j]);
+                            d[j + m / 2] = float_to_half(a * sin[j] + b * cos[j]);
+                        }
+                        d += m;
                     }
-                    d += m;
                 }
             }
         }
     };

-    static void RunMultiThreadLlamaRotatePosition2DFloat(float *data, float *positionIds, float *sinData, float *cosData, int
+    static void RunMultiThreadLlamaRotatePosition2DFloat(DataType dataType, float *data, float *positionIds, float *sinData, float *cosData, int
         bs, int len, int n, int m, int stride, int spatial, int posDim, int rotaryDim, AliveThreadPool *pool) {
         if (bs * len == 1) {
-            (MultiThreadLlamaRotatePosition2DFloatOp(data, positionIds, sinData, cosData, bs, len, n, m, stride, spatial, posDim, rotaryDim, 0, bs * len)).Run();
+            (MultiThreadLlamaRotatePosition2DFloatOp(dataType, data, positionIds, sinData, cosData, bs, len, n, m, stride, spatial, posDim, rotaryDim, 0, bs * len)).Run();
             return;
         }
@@ -3737,7 +3757,7 @@ namespace fastllm {
         for (int i = 0; i < threadNum; i++) {
             int end = (i == threadNum - 1 ? (bs * len) : cur + per + (cur + per * (threadNum - i) < (bs * len)));
             ops.push_back(new MultiThreadLlamaRotatePosition2DFloatOp(
-                data, positionIds, sinData, cosData, bs, len, n, m, stride, spatial, posDim, rotaryDim, cur, end));
+                dataType, data, positionIds, sinData, cosData, bs, len, n, m, stride, spatial, posDim, rotaryDim, cur, end));
             cur = end;
         }
         for (int i = 0; i < threadNum; i++) {
@@ -3761,7 +3781,7 @@ namespace fastllm {
         int spatial = data.Count(2);
         int n = data.dims[2], m = data.dims[3];
         int stride = (int)sinData.dims[1];
-        RunMultiThreadLlamaRotatePosition2DFloat((float*)data.cpuData, (float*)positionIds.cpuData,
+        RunMultiThreadLlamaRotatePosition2DFloat(data.dataType, (float*)data.cpuData, (float*)positionIds.cpuData,
             (float*)sinData.cpuData, (float*)cosData.cpuData, bs, len, n, m, stride, spatial, positionIds.dims.back(), rotaryDim, GetAlivePool());
     }

diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index 9b90f2de..71fd7f79 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -115,6 +115,7 @@ namespace fastllm {
         int add_special_tokens = generationConfig.add_special_tokens? 1: 0;
         FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
                       inputIds, attentionMask, positionIds);
+        ToDataType(attentionMask, this->dataType);
         while (true) {
             auto st = std::chrono::system_clock::now();
             int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
@@ -149,6 +150,7 @@ namespace fastllm {
             inputTokens[0] = std::vector <float> {(float)ret};
             FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
                           inputIds, attentionMask, positionIds);
+            ToDataType(attentionMask, this->dataType);
             if (index == generationConfig.output_token_limit) {
                 break;
             }
@@ -230,6 +232,7 @@ namespace fastllm {
         LastTokensManager tokensManager (batch, generationConfig.last_n);
         std::vector <bool> isEnding = std::vector <bool> (batch, false);
         FillLLMInputsBatch(inputTokens, params, inputIds, attentionMask, positionIds);
+        ToDataType(attentionMask, this->dataType);
         while (true) {
             auto st = std::chrono::system_clock::now();
             std::vector <int> ret = ForwardBatch(batch, inputIds, attentionMask, positionIds, pastKeyValues,
@@ -295,6 +298,7 @@ namespace fastllm {
             index++;
             params[0]["index"] = index;
             FillLLMInputsBatch(inputTokens, params, inputIds, attentionMask, positionIds);
+            ToDataType(attentionMask, this->dataType);
             // printf("len = %d, spend %f s.\n", len, GetSpan(st, std::chrono::system_clock::now()));

             if (index == generationConfig.output_token_limit) {
@@ -636,8 +640,9 @@ printf("%d / %d\n", endingCount, batch);
                     for (int i: it.second->currentTokens) {
                         tokens[0].push_back(i);
                     }
-                    model->FillLLMInputs(tokens, it.second->intParams, inputIds, attentionMask,
-                                         curPositionIds);
+                    model->FillLLMInputs(tokens, it.second->intParams, inputIds, attentionMask, curPositionIds);
+                    ToDataType(attentionMask, model->dataType);
+
                     seqLens.push_back(inputIds.Count(0));
                     for (int i = 0; i < inputIds.Count(0); i++) {
                         ids.push_back(((float *) inputIds.cpuData)[i]);
@@ -870,7 +875,7 @@ printf("tot = %d\n", tot);
         if (dataType == DataType::FLOAT32) {

         } else if (dataType == DataType::FLOAT16) {
-            AssertInFastLLM(this->model_type == "chatglm",
+            AssertInFastLLM(this->model_type == "chatglm" || this->model_type == "llama",
                             this->model_type + " doesn't support float16");
         } else {
             ErrorInFastLLM("SetDataType Error: datatype should be float32 or float16");
diff --git a/src/models/llama.cpp b/src/models/llama.cpp
index e05b1a0e..33684196 100644
--- a/src/models/llama.cpp
+++ b/src/models/llama.cpp
@@ -259,6 +259,8 @@ namespace fastllm {
         Data* cosDataPtr = &cosData;

         Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
+        ToDataType(hiddenStates, this->dataType);
+
         int seqlen = hiddenStates.dims[1];
         for (int i = 0; i < block_cnt; i++) {
             ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
@@ -436,7 +438,7 @@ namespace fastllm {
         auto &hiddenStates = *lastHiddenStates;
         RMSNorm(hiddenStates, weight["model.norm.weight"], rms_norm_eps, hiddenStates);
         Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
-
+        ToDataType(logits, DataType::FLOAT32);
         if (generationConfig.output_logits && retLogits != nullptr) {
             int size = logits.dims.back();
             logits.ToDevice(DataDevice::CPU);
@@ -526,6 +528,8 @@ namespace fastllm {
         CatBatch(contexts, 1, allPositionIds);

         Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
+        ToDataType(hiddenStates, this->dataType);
+
         int seqlen = hiddenStates.dims[1];
         for (int i = 0; i < block_cnt; i++) {
             ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
@@ -602,7 +606,7 @@ namespace fastllm {
                 fastllm::LlamaRotatePosition2D(k, allPositionIds, *sinDataPtr, *cosDataPtr, rotary_dim);
             }

-            Data attenOutput = Data(DataType::FLOAT32);
+            Data attenOutput = Data(this->dataType);
             int total = 0;
             if (all1 && batch > 1) {
                 q.Reshape({-1, q.dims[2], q.dims[3]});
@@ -767,6 +771,7 @@ namespace fastllm {

         RMSNorm(hiddenStates, weight["model.norm.weight"], rms_norm_eps, hiddenStates);
         Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
+        ToDataType(logits, DataType::FLOAT32);

         std::vector <int> lastRet;
         int total = 0;
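
Note: the following is a small, self-contained sketch (not part of the patch) of what the new float16 branch in MultiThreadLlamaRotatePosition2DFloatOp does: each half-precision element is expanded to float32, rotated against the float32 sin/cos tables, and written back as float16. The HalfToFloat/FloatToHalf helpers are simplified stand-ins for fastllm's fp16tofp32 lookup table and float_to_half (NaN payloads and ties-to-even rounding are not handled), not the library's actual API.

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    // Simplified fp16 -> fp32 conversion (zero, subnormals and inf handled; NaN folded into inf-like bits).
    static float HalfToFloat(uint16_t h) {
        uint32_t sign = (uint32_t)(h & 0x8000) << 16;
        uint32_t exp  = (h >> 10) & 0x1F;
        uint32_t mant = h & 0x3FF;
        uint32_t bits;
        if (exp == 0) {
            if (mant == 0) {
                bits = sign;                                    // signed zero
            } else {
                exp = 127 - 15 + 1;                             // normalize a subnormal
                while ((mant & 0x400) == 0) { mant <<= 1; exp--; }
                bits = sign | (exp << 23) | ((mant & 0x3FF) << 13);
            }
        } else if (exp == 0x1F) {
            bits = sign | 0x7F800000 | (mant << 13);            // inf / NaN
        } else {
            bits = sign | ((exp - 15 + 127) << 23) | (mant << 13);
        }
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }

    // Simplified fp32 -> fp16 conversion (round half up, subnormal results flushed to zero).
    static uint16_t FloatToHalf(float f) {
        uint32_t bits;
        std::memcpy(&bits, &f, sizeof(bits));
        uint16_t sign = (uint16_t)((bits >> 16) & 0x8000);
        int32_t  exp  = (int32_t)((bits >> 23) & 0xFF) - 127 + 15;
        uint32_t mant = bits & 0x7FFFFF;
        if (exp >= 0x1F) return (uint16_t)(sign | 0x7C00);      // overflow -> inf
        if (exp <= 0)    return sign;                           // underflow -> signed zero
        uint16_t h = (uint16_t)(sign | (exp << 10) | (mant >> 13));
        if (mant & 0x1000) h++;                                 // round up; a carry into the exponent stays valid
        return h;
    }

    // Mirror of the patch's float16 path: rotate the first and second half of each
    // m-wide row of half-precision values against float32 sin/cos tables, in place.
    static void RotateHalfRows(uint16_t *d, const float *sinRow, const float *cosRow,
                               int n, int m, int rotaryDim) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < rotaryDim && j < m / 2; j++) {
                float a = HalfToFloat(d[j]);
                float b = HalfToFloat(d[j + m / 2]);
                d[j]         = FloatToHalf(a * cosRow[j] - b * sinRow[j]);
                d[j + m / 2] = FloatToHalf(a * sinRow[j] + b * cosRow[j]);
            }
            d += m;                                             // advance to the next row / head
        }
    }

    int main() {
        // One row of 4 fp16 values rotated with sin = 0, cos = 1 (an identity rotation).
        uint16_t row[4] = { FloatToHalf(1.0f), FloatToHalf(2.0f), FloatToHalf(3.0f), FloatToHalf(4.0f) };
        float sinRow[2] = { 0.0f, 0.0f }, cosRow[2] = { 1.0f, 1.0f };
        RotateHalfRows(row, sinRow, cosRow, /*n=*/1, /*m=*/4, /*rotaryDim=*/2);
        for (int j = 0; j < 4; j++) {
            std::printf("%g ", HalfToFloat(row[j]));            // prints: 1 2 3 4
        }
        std::printf("\n");
        return 0;
    }

The same pattern explains the rest of the patch: attention masks, embeddings and attention outputs are converted with ToDataType so the whole forward pass can stay in float16, and the logits are converted back to FLOAT32 before sampling.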