Commit

llama: support setting intermediate results to float16
黄宇扬 committed May 28, 2024
1 parent 177e9cf commit 37853af
Showing 3 changed files with 54 additions and 24 deletions.
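
In short: this commit threads the model-level dataType through the llama forward pass and the CPU rotary-position kernel so intermediate activations can stay in float16, converts attention masks with ToDataType, and brings logits back to float32 before sampling. A minimal usage sketch follows (the loader call and model file name are illustrative assumptions; only SetDataType and DataType::FLOAT16 appear in this diff):

    // Sketch: enable float16 intermediate results on a llama model.
    // The file name is made up; CreateLLMModelFromFile is assumed to be the usual fastllm loader.
    auto model = fastllm::CreateLLMModelFromFile("llama2-7b-chat.flm");
    model->SetDataType(fastllm::DataType::FLOAT16);   // now permitted for model_type == "llama"

After this call the hidden states and attention temporaries of the llama CPU forward pass run in float16, as set up by the changes below.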
58 changes: 39 additions & 19 deletions src/devices/cpu/cpudevice.cpp
@@ -3691,42 +3691,62 @@ namespace fastllm {
    }

    struct MultiThreadLlamaRotatePosition2DFloatOp : MultiThreadBaseOp {
+        DataType dataType;
        float *data, *positionIds, *sinData, *cosData;
        int bs, len, n, m, stride, spatial, posDim, rotaryDim;
        int st, end;

        MultiThreadLlamaRotatePosition2DFloatOp
-        (float *data, float *positionIds, float *sinData, float *cosData,
+        (DataType dataType, float *data, float *positionIds, float *sinData, float *cosData,
        int bs, int len, int n, int m, int stride, int spatial, int posDim, int rotaryDim,
        int st, int end) :
-            data(data), positionIds(positionIds), sinData(sinData), cosData(cosData),
+            dataType(dataType), data(data), positionIds(positionIds), sinData(sinData), cosData(cosData),
            bs(bs), len(len), n(n), m(m), stride(stride), spatial(spatial), posDim(posDim), rotaryDim(rotaryDim),
            st(st), end(end) {}

        void Run() {
-            for (int idx = st; idx < end; idx++) {
-                int b = idx / len;
-                int l = idx % len;
-                int index = (int) ((float *) positionIds)[b * posDim + l];
-                float *sin = ((float *) sinData) + stride * index;
-                float *cos = ((float *) cosData) + stride * index;
-                float *d = (float *) data + (b * len + l) * spatial;
-                for (int i = 0; i < n; i++) {
-                    for (int j = 0; j < rotaryDim && j < m / 2; j++) {
-                        float a = d[j], b = d[j + m / 2];
-                        d[j] = a * cos[j] - b * sin[j];
-                        d[j + m / 2] = a * sin[j] + b * cos[j];
+            if (dataType == DataType::FLOAT32) {
+                for (int idx = st; idx < end; idx++) {
+                    int b = idx / len;
+                    int l = idx % len;
+                    int index = (int) ((float *) positionIds)[b * posDim + l];
+                    float *sin = ((float *) sinData) + stride * index;
+                    float *cos = ((float *) cosData) + stride * index;
+                    float *d = (float *) data + (b * len + l) * spatial;
+                    for (int i = 0; i < n; i++) {
+                        for (int j = 0; j < rotaryDim && j < m / 2; j++) {
+                            float a = d[j], b = d[j + m / 2];
+                            d[j] = a * cos[j] - b * sin[j];
+                            d[j + m / 2] = a * sin[j] + b * cos[j];
+                        }
+                        d += m;
+                    }
+                }
+            } else {
+                for (int idx = st; idx < end; idx++) {
+                    int b = idx / len;
+                    int l = idx % len;
+                    int index = (int) ((float *) positionIds)[b * posDim + l];
+                    float *sin = ((float *) sinData) + stride * index;
+                    float *cos = ((float *) cosData) + stride * index;
+                    uint16_t *d = (uint16_t *) data + (b * len + l) * spatial;
+                    for (int i = 0; i < n; i++) {
+                        for (int j = 0; j < rotaryDim && j < m / 2; j++) {
+                            float a = fp16tofp32.dict[d[j]], b = fp16tofp32.dict[d[j + m / 2]];
+                            d[j] = float_to_half(a * cos[j] - b * sin[j]);
+                            d[j + m / 2] = float_to_half(a * sin[j] + b * cos[j]);
+                        }
+                        d += m;
                    }
-                    d += m;
                }
            }
        }
    };

-    static void RunMultiThreadLlamaRotatePosition2DFloat(float *data, float *positionIds, float *sinData, float *cosData,
+    static void RunMultiThreadLlamaRotatePosition2DFloat(DataType dataType, float *data, float *positionIds, float *sinData, float *cosData,
        int bs, int len, int n, int m, int stride, int spatial, int posDim, int rotaryDim, AliveThreadPool *pool) {
        if (bs * len == 1) {
-            (MultiThreadLlamaRotatePosition2DFloatOp(data, positionIds, sinData, cosData, bs, len, n, m, stride, spatial, posDim, rotaryDim, 0, bs * len)).Run();
+            (MultiThreadLlamaRotatePosition2DFloatOp(dataType, data, positionIds, sinData, cosData, bs, len, n, m, stride, spatial, posDim, rotaryDim, 0, bs * len)).Run();
            return;
        }

@@ -3737,7 +3757,7 @@ namespace fastllm {
        for (int i = 0; i < threadNum; i++) {
            int end = (i == threadNum - 1 ? (bs * len) : cur + per + (cur + per * (threadNum - i) < (bs * len)));
            ops.push_back(new MultiThreadLlamaRotatePosition2DFloatOp(
-                data, positionIds, sinData, cosData, bs, len, n, m, stride, spatial, posDim, rotaryDim, cur, end));
+                dataType, data, positionIds, sinData, cosData, bs, len, n, m, stride, spatial, posDim, rotaryDim, cur, end));
            cur = end;
        }
        for (int i = 0; i < threadNum; i++) {
@@ -3761,7 +3781,7 @@ namespace fastllm {
        int spatial = data.Count(2);
        int n = data.dims[2], m = data.dims[3];
        int stride = (int)sinData.dims[1];
-        RunMultiThreadLlamaRotatePosition2DFloat((float*)data.cpuData, (float*)positionIds.cpuData,
+        RunMultiThreadLlamaRotatePosition2DFloat(data.dataType, (float*)data.cpuData, (float*)positionIds.cpuData,
            (float*)sinData.cpuData, (float*)cosData.cpuData, bs, len, n, m, stride, spatial,
            positionIds.dims.back(), rotaryDim, GetAlivePool());
    }
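The new Run() body keeps the original float32 path and adds a float16 branch: position indices and the sin/cos tables stay in float32, while each q/k element is widened from fp16 (via the fp16tofp32 lookup table) before the half-rotation and narrowed back with float_to_half. A standalone sketch of that core rotation on one head slice, float32 path only and with made-up values (not the library's internals):

    #include <cmath>
    #include <cstdio>

    int main() {
        const int m = 8;                    // head dimension of one q/k slice
        float d[m];                         // the slice being rotated in place
        float sinv[m / 2], cosv[m / 2];     // per-position tables, as passed to the kernel
        for (int j = 0; j < m; j++) d[j] = 0.1f * j;
        for (int j = 0; j < m / 2; j++) {
            float theta = 1.0f / std::pow(10000.0f, 2.0f * j / m);   // rotary angle for position 1
            sinv[j] = std::sin(theta);
            cosv[j] = std::cos(theta);
        }
        // Rotate the pairs (d[j], d[j + m/2]) exactly as the kernel's inner loop does;
        // the FLOAT16 branch additionally converts each element fp16 -> fp32 before and fp32 -> fp16 after.
        for (int j = 0; j < m / 2; j++) {
            float a = d[j], b = d[j + m / 2];
            d[j]         = a * cosv[j] - b * sinv[j];
            d[j + m / 2] = a * sinv[j] + b * cosv[j];
        }
        for (int j = 0; j < m; j++) std::printf("d[%d] = %f\n", j, d[j]);
        return 0;
    }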
11 changes: 8 additions & 3 deletions src/models/basellm.cpp
@@ -115,6 +115,7 @@ namespace fastllm {
        int add_special_tokens = generationConfig.add_special_tokens? 1: 0;
        FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
                      inputIds, attentionMask, positionIds);
+        ToDataType(attentionMask, this->dataType);
        while (true) {
            auto st = std::chrono::system_clock::now();
            int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
@@ -149,6 +150,7 @@ namespace fastllm {
            inputTokens[0] = std::vector<float> {(float)ret};
            FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
                          inputIds, attentionMask, positionIds);
+            ToDataType(attentionMask, this->dataType);
            if (index == generationConfig.output_token_limit) {
                break;
            }
@@ -230,6 +232,7 @@ namespace fastllm {
        LastTokensManager tokensManager (batch, generationConfig.last_n);
        std::vector <bool> isEnding = std::vector <bool> (batch, false);
        FillLLMInputsBatch(inputTokens, params, inputIds, attentionMask, positionIds);
+        ToDataType(attentionMask, this->dataType);
        while (true) {
            auto st = std::chrono::system_clock::now();
            std::vector <int> ret = ForwardBatch(batch, inputIds, attentionMask, positionIds, pastKeyValues,
@@ -295,6 +298,7 @@ namespace fastllm {
            index++;
            params[0]["index"] = index;
            FillLLMInputsBatch(inputTokens, params, inputIds, attentionMask, positionIds);
+            ToDataType(attentionMask, this->dataType);
            // printf("len = %d, spend %f s.\n", len, GetSpan(st, std::chrono::system_clock::now()));

            if (index == generationConfig.output_token_limit) {
@@ -636,8 +640,9 @@ printf("%d / %d\n", endingCount, batch);
                for (int i: it.second->currentTokens) {
                    tokens[0].push_back(i);
                }
-                model->FillLLMInputs(tokens, it.second->intParams, inputIds, attentionMask,
-                                     curPositionIds);
+                model->FillLLMInputs(tokens, it.second->intParams, inputIds, attentionMask, curPositionIds);
+                ToDataType(attentionMask, model->dataType);
+
                seqLens.push_back(inputIds.Count(0));
                for (int i = 0; i < inputIds.Count(0); i++) {
                    ids.push_back(((float *) inputIds.cpuData)[i]);
@@ -870,7 +875,7 @@ printf("tot = %d\n", tot);
        if (dataType == DataType::FLOAT32) {

        } else if (dataType == DataType::FLOAT16) {
-            AssertInFastLLM(this->model_type == "chatglm",
+            AssertInFastLLM(this->model_type == "chatglm" || this->model_type == "llama",
                            this->model_type + " doesn't support float16");
        } else {
            ErrorInFastLLM("SetDataType Error: datatype should be float32 or float16");
9 changes: 7 additions & 2 deletions src/models/llama.cpp
@@ -259,6 +259,8 @@ namespace fastllm {
        Data* cosDataPtr = &cosData;

        Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
+        ToDataType(hiddenStates, this->dataType);
+
        int seqlen = hiddenStates.dims[1];
        for (int i = 0; i < block_cnt; i++) {
            ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
@@ -436,7 +438,7 @@ namespace fastllm {
        auto &hiddenStates = *lastHiddenStates;
        RMSNorm(hiddenStates, weight["model.norm.weight"], rms_norm_eps, hiddenStates);
        Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
-
+        ToDataType(logits, DataType::FLOAT32);
        if (generationConfig.output_logits && retLogits != nullptr) {
            int size = logits.dims.back();
            logits.ToDevice(DataDevice::CPU);
@@ -526,6 +528,8 @@ namespace fastllm {
        CatBatch(contexts, 1, allPositionIds);

        Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
+        ToDataType(hiddenStates, this->dataType);
+
        int seqlen = hiddenStates.dims[1];
        for (int i = 0; i < block_cnt; i++) {
            ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
@@ -602,7 +606,7 @@ namespace fastllm {
                fastllm::LlamaRotatePosition2D(k, allPositionIds, *sinDataPtr, *cosDataPtr, rotary_dim);
            }

-            Data attenOutput = Data(DataType::FLOAT32);
+            Data attenOutput = Data(this->dataType);
            int total = 0;
            if (all1 && batch > 1) {
                q.Reshape({-1, q.dims[2], q.dims[3]});
@@ -767,6 +771,7 @@ namespace fastllm {

        RMSNorm(hiddenStates, weight["model.norm.weight"], rms_norm_eps, hiddenStates);
        Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
+        ToDataType(logits, DataType::FLOAT32);
        std::vector <int> lastRet;
        int total = 0;

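On the llama side the pattern is consistent: hidden states are converted to the model's dataType right after the embedding lookup, the per-layer attention output buffer is allocated in that same type, and the final logits are converted back to FLOAT32 before token selection, because the sampling code reads plain floats. A small illustrative sketch of that last step (the ToDataType call in the diff performs the conversion; the helper below is hypothetical):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Greedy pick over logits that have already been converted back to float32.
    static int GreedyToken(const std::vector<float> &logits) {
        return (int)(std::max_element(logits.begin(), logits.end()) - logits.begin());
    }

    int main() {
        std::vector<float> logits = {0.1f, 2.5f, -0.3f, 1.7f};   // made-up values
        std::printf("next token id = %d\n", GreedyToken(logits));
        return 0;
    }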
