diff --git a/.gitignore b/.gitignore
index bcc9c47..de26454 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,4 +33,5 @@ token
 /localtest/
 /third_party/tfacc/driver/tfacc2/result
 /.chainlit
-/.files
\ No newline at end of file
+/.files
+*.o
\ No newline at end of file
diff --git a/include/graph.h b/include/graph.h
index 8b5dc3a..f8606ed 100644
--- a/include/graph.h
+++ b/include/graph.h
@@ -79,10 +79,13 @@ namespace fastllm {
     // 执行计算图
     void RunComputeGraph (const ComputeGraph &graph,
-                        const std::map &deviceMap,
-                        std::map inputs,
-                        std::map weights,
-                        std::map outputs);
+                        const std::map &deviceMap,
+                        const std::map &inputs,
+                        const std::map &weights,
+                        const std::map &outputs,
+                        std::vector > &pastKeys,
+                        std::vector > &pastValues,
+                        std::vector &masks);
 }
 #endif //FASTLLM_GRAPH_H
diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp
index fefd18f..b1b8912 100644
--- a/src/devices/cpu/cpudevice.cpp
+++ b/src/devices/cpu/cpudevice.cpp
@@ -1030,7 +1030,8 @@ namespace fastllm {
         AssertInFastLLM(weight.dims.size() == 2, "Embedding's weight's dim should be 2.\n");
         AssertInFastLLM(weight.dataType == DataType::FLOAT32 ||
-                        weight.dataType == DataType::BFLOAT16, "Embedding's weight's type should be float32 or bfloat16.\n");
+                        weight.dataType == DataType::FLOAT16 ||
+                        weight.dataType == DataType::BFLOAT16, "Embedding's weight's type should be float32 or float16 or bfloat16.\n");
         AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
                         input.dataType == DataType::FLOAT16, "Embedding's input's type should be float32 or float16.\n");
@@ -1041,6 +1042,9 @@ namespace fastllm {
         dims.push_back(embSize);
         output.dataType = input.dataType;
+        if (weight.dataType == DataType::FLOAT16) {
+            output.dataType = DataType::FLOAT16;
+        }
         output.Resize(dims);
     }
diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu
index 5e0947c..071670c 100644
--- a/src/devices/cuda/fastllm-cuda.cu
+++ b/src/devices/cuda/fastllm-cuda.cu
@@ -286,8 +286,8 @@ void GpuQK(half *q, half *k, half *qk, int qlen, int klen, int dim, float scale,
     HalfFC <<>> (q, k, qk, qlen, dim, klen, (half)scale, base);
 }
-template
-__global__ void FastllmCudaFloatEmbeddingKernel(float *input, float *weight, float *output, int embSize) {
+template
+__global__ void FastllmCudaFloatEmbeddingKernel(float *input, T *weight, T *output, int embSize) {
     input += blockIdx.x;
     output += blockIdx.x * embSize;
     int token = (int)(input[0] + 1e-5);
@@ -467,8 +467,8 @@ __global__ void FastllmSiluKernel(float* a, float *b, int len) {
 __global__ void FastllmSiluKernel(half* a, half *b, int len) {
     int idx = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx < len) {
-        float x = __half2float(a[idx]);
-        b[idx] = __float2half(x / (1.0 + expf(-x)));
+        half x = a[idx];
+        b[idx] = __hdiv(x, __hadd(__float2half(1.0), hexp(-x)));
     }
 }
@@ -477,7 +477,7 @@ __global__ void FastllmSwigluKernel(float* a, float *b, int len, int spatial, in
     if (idx < len) {
         int id = idx / mid * spatial + idx % mid;
         float x = a[id], y = a[id + mid];
-        b[idx] = (x / (1.0 + expf(-x))) * y;
+        b[idx] = (x / (1.0f + expf(-x))) * y;
     }
 }
@@ -3134,13 +3134,13 @@ void FastllmCudaMemcpyBetweenDevices(int dstId, void *dst, int srcId, void *src,
         delete[] cpuData;
     }
     checkCudaErrors("Error: CUDA error when copy Between GPUs!", state);
-    //cudaDeviceSynchronize();
+    DeviceSync();
 }

 void FastllmCudaMemcpy2DDeviceToDevice(void * dst, size_t dpitch, const void * src, size_t spitch, size_t width, size_t height) {
     cudaMemcpy2D(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice);
-    //cudaDeviceSynchronize();
+    DeviceSync();
 }

 template
@@ -3223,7 +3223,7 @@ bool FastllmCudaSilu(const fastllm::Data &input, fastllm::Data &output) {
     int len = input.Count(0);
     float *cudaInput = (float *) FastllmCudaPrepareInput(input);
     float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
-    int threadPerBlock = std::min(256, len);
+    int threadPerBlock = std::min(1024, len);
     if (input.dataType == fastllm::DataType::FLOAT32) {
         FastllmSiluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len);
     } else if (input.dataType == fastllm::DataType::FLOAT16) {
@@ -3240,7 +3240,7 @@ bool FastllmCudaSwiglu(const fastllm::Data &input, fastllm::Data &output) {
     float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);

     int spatial = input.Count(input.dims.size() - 1), mid = spatial / 2;
-    int threadPerBlock = std::min(256, len);
+    int threadPerBlock = std::min(1024, len);
     if (input.dataType == fastllm::DataType::FLOAT32) {
         FastllmSwigluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len, spatial, mid);
     } else if (input.dataType == fastllm::DataType::FLOAT16) {
@@ -3274,7 +3274,7 @@ bool FastllmCudaAddTo(fastllm::Data &input0, const fastllm::Data &input1, float
     float *cudaData = (float *) FastllmCudaPrepareInput(input0);
     float *input1Data = (float *) FastllmCudaPrepareInput(input1);

-    int threadPerBlock = std::min(256, len);
+    int threadPerBlock = std::min(1024, len);
     if (input0.dataType == fastllm::DataType::FLOAT32) {
         FastllmAddToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaData, input1Data, alpha, len);
     } else if (input0.dataType == fastllm::DataType::FLOAT16) {
@@ -3584,24 +3584,28 @@ bool FastllmCudaPermute(fastllm::Data &input, const std::vector &axis) {
     }
     FastllmCudaFree(tempData);
+    DeviceSync();
     return true;
 }

 bool FastllmFloatToHalf(void *a, void *b, int len) {
     int threadPerBlock = std::min(256, len);
     FastllmCudaFloat2HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((float*)a, (half*)b, len);
+    DeviceSync();
     return true;
 }

 bool FastllmHalfToFloat(void *a, void *b, int len) {
     int threadPerBlock = std::min(256, len);
     FastllmCudaHalf2FloatKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((half*)a, (float*)b, len);
+    DeviceSync();
     return true;
 }

 bool FastllmBF16ToFloat(void *a, void *b, int len) {
     int threadPerBlock = std::min(256, len);
     FastllmCudaBF162FloatKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint16_t*)a, (float*)b, len);
+    DeviceSync();
     return true;
 }

@@ -3616,6 +3620,10 @@ bool FastllmCudaEmbedding(const fastllm::Data &input, const fastllm::Data &weigh
         float *outputData = (float *) dstOutputData;
         float *weightData = (float *) weight.cudaData;
         FastllmCudaFloatEmbeddingKernel <128> <<>> (inputData, weightData, outputData, embSize);
+    } else if (weight.dataType == fastllm::DataType::FLOAT16) {
+        half *outputData = (half *) dstOutputData;
+        half *weightData = (half *) weight.cudaData;
+        FastllmCudaFloatEmbeddingKernel <128> <<>> (inputData, weightData, outputData, embSize);
     } else if (weight.dataType == fastllm::DataType::BFLOAT16) {
         std::vector cpuInputData = std::vector (inputLen, 0.0f);
         FastllmCudaCopyFromDeviceToHost(cpuInputData.data(), inputData, cpuInputData.size() * sizeof(float));
@@ -3627,8 +3635,11 @@ bool FastllmCudaEmbedding(const fastllm::Data &input, const fastllm::Data &weigh
             FastllmBF16ToFloat(outputData + i * embSize, weightData + token * embSize, embSize);
         }
     }
+    } else {
+    }
+    DeviceSync();
     return true;
 }
@@ -4137,6 +4148,105 @@ bool FastllmCudaRepeatPenalty (fastllm::Data &input, fastllm::Data &penalty, fas
 template
 bool DoFastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Data **v,
                                  fastllm::Data **mask, fastllm::Data **output, int group, float scale, int batch) {
+    if (false) {
+        half beta = __float2half_rn(0.0f), one = __float2half_rn(1.0f), hscale = __float2half_rn(scale);
+        int q0 = q[0]->dims[0], q1 = q[0]->dims[1], q2 = q[0]->dims[2], k0 = k[0]->dims[0], k1 = k[0]->dims[1], v2 = v[0]->dims[2];
+        for (int i = 0; i < batch; i++) {
+            q1 = std::max(q1, q[i]->dims[1]);
+            k1 = std::max(k1, k[i]->dims[1]);
+        }
+
+        half *allKeys = (half*) FastllmCudaMalloc(batch * k0 * k1 * q2 * sizeof(half));
+        half *allValues = (half*) FastllmCudaMalloc(batch * k0 * k1 * v2 * sizeof(half));
+
+        std::vector dsts, srcs;
+        std::vector dpitchs, spitchs, widths, heights;
+        for (int i = 0; i < batch; i++) {
+            dsts.push_back((uint8_t *) (allKeys + i * k0 * k1 * q2));
+            dpitchs.push_back(k1 * q2 * sizeof(half));
+            srcs.push_back(k[i]->cudaData);
+            spitchs.push_back(k[i]->strides[0] * sizeof(half));
+            widths.push_back(k[i]->dims[1] * q2 * sizeof(half));
+            heights.push_back(k0);
+
+            dsts.push_back((uint8_t *) (allValues + i * k0 * k1 * v2));
+            dpitchs.push_back(k1 * v2 * sizeof(half));
+            srcs.push_back(v[i]->cudaData);
+            spitchs.push_back(v[i]->strides[0] * sizeof(half));
+            widths.push_back(v[i]->dims[1] * v2 * sizeof(half));
+            heights.push_back(k0);
+        }
+        FastllmCudaMemcpy2DDeviceToDeviceBatch(dsts.data(), dpitchs.data(), srcs.data(), spitchs.data(), widths.data(), heights.data(), dsts.size());
+/*
+        for (int i = 0; i < batch; i++) {
+            cudaMemcpy2D(
+                allKeys + i * k0 * k1 * q2, k1 * q2 * sizeof(half),
+                k[i]->cudaData, k[i]->strides[0] * sizeof(half),
+                k[i]->dims[1] * q2 * sizeof(half), k0,
+                cudaMemcpyDeviceToDevice
+            );
+            cudaMemcpy2D(
+                allValues + i * k0 * k1 * v2, k1 * v2 * sizeof(half),
+                v[i]->cudaData, v[i]->strides[0] * sizeof(half),
+                v[i]->dims[1] * v2 * sizeof(half), k0,
+                cudaMemcpyDeviceToDevice
+            );
+        }
+*/
+        half *qd = (half*)q[0]->cudaData;
+        half *od = (half*)output[0]->cudaData;
+        half *qk = (half *) FastllmCudaMalloc(batch * q0 * q1 * k1 * sizeof(half));
+        half *temp = (half *) FastllmCudaMalloc(batch * q0 * q1 * k1 * sizeof(half));
+        auto fastllmCublasHandle = getFastllmCublasHandle();
+        cublasStatus_t status;
+
+        status = cublasHgemmStridedBatched(fastllmCublasHandle,
+                                           CUBLAS_OP_T, CUBLAS_OP_N,
+                                           k1, q1 * group, q2, &hscale,
+                                           allKeys, q2, k1 * q2,
+                                           qd, q2, group * q1 * q2,
+                                           &beta,
+                                           qk, k1, k1 * q1 * group, batch * q0 / group);
+        if (status != CUBLAS_STATUS_SUCCESS) {
+            printf("status = %d\n", (int) status);
+            printf("Error: cublas error during MatMulTransB in Attention operator.\n");
+            throw ("cublas error");
+            exit(0);
+        }
+
+        int outer = batch * q0 * q1;
+        if (k1 < 8) {
+            FastllmSoftmaxKernelInner1<1> <<< outer, 1 >>>(qk, temp, outer, k1);
+        } else if (k1 < 64) {
+            FastllmSoftmaxKernelInner1<8> <<< outer, 8 >>>(qk, temp, outer, k1);
+        } else if (k1 < 512) {
+            FastllmSoftmaxKernelInner1<64> <<< outer, 64 >>>(qk, temp, outer, k1);
+        } else {
+            FastllmSoftmaxKernelInner1<256> <<< outer, 256 >>>(qk, temp, outer, k1);
+        }
+
+        status = cublasHgemmStridedBatched(fastllmCublasHandle,
+                                           CUBLAS_OP_N, CUBLAS_OP_N,
+                                           v2, q1 * group, k1, &one,
+                                           allValues, v2, k1 * v2,
+                                           temp, k1, k1 * q1 * group,
+                                           &beta,
+                                           od, v2, v2 * q1 * group, batch * q0 / group);
+        if (status != CUBLAS_STATUS_SUCCESS) {
+            printf("status = %d\n", (int) status);
operator.\n"); + throw ("cublas error"); + exit(0); + } + + FastllmCudaFree(allKeys); + FastllmCudaFree(allValues); + FastllmCudaFree(qk); + FastllmCudaFree(temp); + DeviceSync(); + return true; + } + int k0 = k[0]->dims[0]; size_t memSum = 0; for (int b = 0; b < batch; b++) { diff --git a/src/fastllm.cpp b/src/fastllm.cpp index f011fa9..9f04556 100644 --- a/src/fastllm.cpp +++ b/src/fastllm.cpp @@ -334,6 +334,7 @@ namespace fastllm { this->Resize(ori.dims); this->Allocate(); } else { + this->expansionDims.clear(); this->Resize(ori.dims); this->Allocate(); } diff --git a/src/graph.cpp b/src/graph.cpp index 7b9420b..e393690 100644 --- a/src/graph.cpp +++ b/src/graph.cpp @@ -76,6 +76,18 @@ namespace fastllm { } ComputeGraphNode input(op.datas["input"]), weight(mergeWeightName), bias(mergeBiasName), mid(outputName); graph.Linear(input, weight, bias, mid); + + // 如果后面接的silu + mul, 那么合并成swiglu + if (j == i + 1) { + if (j + 2 < ops.size() && + ops[j + 1].type == "Silu" && ops[j + 1].datas["input"] == ops[j + 1].datas["output"] && ops[j + 1].datas["input"] == ops[i].datas["output"] && + ops[j + 2].type == "MulTo" && ops[j + 2].datas["input0"] == ops[i].datas["output"] && ops[j + 2].datas["input1"] == ops[j].datas["output"]) { + ComputeGraphNode swigluOutput(ops[i].datas["output"]); + graph.Swiglu(mid, swigluOutput); + i = j + 2; + continue; + } + } offset = 0; for (int l = i; l <= j; l++) { ComputeGraphNode output(ops[l].datas["output"]); @@ -91,17 +103,43 @@ namespace fastllm { } } + void ParseIdsByDots(const std::string &s, std::vector &ids) { + ids.clear(); + int now = 0; + for (int i = 0; i < s.size(); i++) { + if (s[i] == '.') { + if (now >= 0) { + ids.push_back(now); + } + now = 0; + } else if (now >= 0 && s[i] >= '0' && s[i] <= '9') { + now = now * 10 + s[i] - '0'; + } else { + now = -1; + } + } + if (now >= 0) { + ids.push_back(now); + } + } + void RunComputeGraph (const ComputeGraph &graph, const std::map &deviceMap, - std::map inputs, - std::map weights, - std::map outputs) { + const std::map &inputs, + const std::map &weights, + const std::map &outputs, + std::vector > &pastKeys, + std::vector > &pastValues, + std::vector &masks) { Executor &excutor = *((Executor*)GetExecutor()); - std::map tempDatas; - std::map allDatas; + std::unordered_map tempDatas; + std::unordered_map allDatas; + std::vector ids; + std::vector curContextLayer; + std::vector curQs, curKs, curVs, curOutputs; for (auto &it : inputs) { - allDatas[it.first] = it.second; + allDatas[it.first] = it.second; } for (auto &it : weights) { allDatas[it.first] = it.second; @@ -116,7 +154,6 @@ namespace fastllm { } } Data emptyData; - for (int i = 0; i < graph.ops.size(); i++) { auto &op = graph.ops[i]; // 一些没实现的算子 @@ -124,8 +161,10 @@ namespace fastllm { exit(0); } else if (op.type == "Print") { auto data = allDatas[op.datas.find("input")->second]; + auto oriDevice = data->dataDevice; data->ToDevice(DataDevice::CPU); data->Print(); + data->ToDevice(oriDevice); } else if (op.type == "DataTypeAs") { auto input = allDatas[op.datas.find("input")->second]; DataType dataType = allDatas[op.datas.find("input1")->second]->dataType; @@ -149,6 +188,9 @@ namespace fastllm { dims.push_back(headDim); data->Reshape(dims); } else if (op.type == "FusedAttention") { + ParseIdsByDots(op.datas.find("k")->second, ids); + int layerId = ids[0]; + std::vector seqLens; { auto data = allDatas[op.datas.find("seqLens")->second]; @@ -156,7 +198,6 @@ namespace fastllm { seqLens.push_back(((int*)data->cpuData)[i]); } } - if (seqLens.size() == 1) { { 
                        std::vector axis = {0, 2, 1, 3};
@@ -177,7 +218,7 @@
                    int unitLen = op.intParams.find("unitLen")->second;
                    for (int i = 0; i < 2; i++) {
-                       auto cache = allDatas[op.datas.find(i == 0 ? "k" : "v")->second];
+                       auto cache = i == 0 ? pastKeys[layerId][0] : pastValues[layerId][0];
                        auto cur = allDatas[op.datas.find(i == 0 ? "curk" : "curv")->second];
                        while ((cache->dims.size() == 0 && (cache->expansionDims.size() == 0 || cur->dims[1] > cache->expansionDims[1]))
@@ -200,6 +241,10 @@
                    for (auto &it : op.datas) {
                        dataDict[it.first] = allDatas[it.second];
                    }
+                   dataDict["k"] = pastKeys[layerId][0];
+                   dataDict["v"] = pastValues[layerId][0];
+                   dataDict["mask"] = masks[0];
+
                    excutor.Run("Attention", dataDict, op.floatParams, op.intParams);
                    {
                        auto output = allDatas[op.datas.find("output")->second];
@@ -221,16 +266,18 @@
                    }
                } else {
                    int batch = seqLens.size(), total = 0;
-                   bool all1 = true;
+                   bool all1 = true, allSame = true;
                    for (int i = 0; i < seqLens.size(); i++) {
                        if (seqLens[i] != 1) {
                            all1 = false;
-                           break;
                        }
+                       if (seqLens[i] != seqLens[0]) {
+                           allSame = false;
+                       }
+                       total += seqLens[i];
                    }
-
+                   int paddingLen = allDatas[op.datas.find("q")->second]->dims[1];
                    if (all1) {
-                       std::vector curQs, curKs, curVs, curOutputs;
                        curQs.resize(batch);
                        curKs.resize(batch);
                        curVs.resize(batch);
@@ -262,13 +309,17 @@
                            curVs[b].FakeFrom(v, b * v.strides[0] * v.unitSize);
                        }
                        total = batch;
-
                        int unitLen = op.intParams.find("unitLen")->second;
+                       std::vector qs, contexts;
+                       qs.resize(batch);
+                       contexts.resize(batch);
+
                        for (int i = 0; i < 2; i++) {
                            std::vector caches, curs;
                            for (int b = 0; b < batch; b++) {
-                               auto cache = allDatas[op.datas.find(i == 0 ? "k" : "v")->second + "." + std::to_string(b)];
+                               auto cache = i == 0 ? pastKeys[layerId][b] : pastValues[layerId][b];
                                auto cur = i == 0 ? &curKs[b] : &curVs[b];
+                               bool needExpansion = false;
                                while ((cache->dims.size() == 0 && (cache->expansionDims.size() == 0 || cur->dims[1] > cache->expansionDims[1]))
                                    || (cache->dims.size() > 0 && cache->dims[1] + cur->dims[1] > cache->expansionDims[1])) {
                                    std::vector newDims;
@@ -279,38 +330,103 @@
                                        newDims[1] += ((cur->dims[1] - 1) / unitLen + 1) * unitLen;
                                    }
                                    cache->Expansion(newDims);
+                                   needExpansion = true;
                                }
                                caches.push_back(cache);
                                curs.push_back(cur);
                            }
                            CatDirectBatch(caches, curs, 1);
                        }
-
                        auto &attenOutput = *allDatas[op.datas.find("output")->second];
                        attenOutput.dataType = q.dataType;
                        attenOutput.ToDevice(q.dataDevice);
                        attenOutput.Resize({1, batch, embed_dim});
                        attenOutput.Allocate();
-                       std::vector curContextLayer;
-                       std::vector qs, keys, values, masks, contexts;
                        curContextLayer.resize(batch);
-                       qs.resize(batch);
-                       keys.resize(batch);
-                       values.resize(batch);
-                       masks.resize(batch);
-                       contexts.resize(batch);
-
                        for (int b = 0; b < batch; b++) {
-                           std::string sb = "." + std::to_string(b);
                            qs[b] = (&curQs[b]);
-                           keys[b] = allDatas[op.datas.find("k")->second + sb];
-                           values[b] = allDatas[op.datas.find("v")->second + sb];
-                           masks[b] = allDatas[op.datas.find("mask")->second + sb];
                            curContextLayer[b].FakeFrom(attenOutput, b * embed_dim * attenOutput.unitSize);
                            contexts[b] = (&curContextLayer[b]);
                        }
-                       AttentionBatch(qs, keys, values, masks, contexts, qs[0]->dims[0] / values[0]->dims[0], op.floatParams.find("scale")->second, 1);
+                       AttentionBatch(qs, pastKeys[layerId], pastValues[layerId], masks, contexts, qs[0]->dims[0] / pastValues[layerId][0]->dims[0], op.floatParams.find("scale")->second, 1);
+                   } else if (total != paddingLen || allSame) {
+                       int maxLen = seqLens[0];
+                       for (int i = 0; i < seqLens.size(); i++) {
+                           maxLen = std::max(maxLen, seqLens[i]);
+                       }
+                       auto &q = *allDatas[op.datas.find("q")->second];
+                       auto &k = *allDatas[op.datas.find("curk")->second];
+                       auto &v = *allDatas[op.datas.find("curv")->second];
+
+                       std::vector curKs, curVs;
+                       int head_dim = allDatas[op.datas.find("q")->second]->dims.back();
+                       curKs.resize(batch);
+                       curVs.resize(batch);
+                       PermuteSelf(k, {0, 2, 1, 3});
+                       PermuteSelf(v, {0, 2, 1, 3});
+                       k.Reshape({-1, k.dims[2], k.dims[3]});
+                       v.Reshape({-1, v.dims[2], v.dims[3]});
+                       for (int b = 0; b < batch; b++) {
+                           excutor.Run("Split", {
+                               {"input", &k}, {"output", &curKs[b]}
+                           }, {}, {{"axis", 1}, {"start", maxLen * (b + 1) - seqLens[b]}, {"end", maxLen * (b + 1)}});
+                           excutor.Run("Split", {
+                               {"input", &v}, {"output", &curVs[b]}
+                           }, {}, {{"axis", 1}, {"start", maxLen * (b + 1) - seqLens[b]}, {"end", maxLen * (b + 1)}});
+                           total += seqLens[b];
+                       }
+
+                       k.Reshape({1, k.dims[0], k.dims[1], k.dims[2]});
+                       v.Reshape({1, v.dims[0], v.dims[1], v.dims[2]});
+                       PermuteSelf(k, {0, 2, 1, 3});
+                       PermuteSelf(v, {0, 2, 1, 3});
+
+                       std::vector pointersK, pointersV;
+                       int unitLen = op.intParams.find("unitLen")->second;
+                       for (int b = 0; b < batch; b++) {
+                           pointersK.push_back(&curKs[b]);
+                           pointersV.push_back(&curVs[b]);
+                           for (int i = 0; i < 2; i++) {
+                               auto cache = i == 0 ? pastKeys[layerId][b] : pastValues[layerId][b];
+                               auto cur = i == 0 ? &curKs[b] : &curVs[b];
+                               while ((cache->dims.size() == 0 && (cache->expansionDims.size() == 0 || cur->dims[1] > cache->expansionDims[1]))
+                                   || (cache->dims.size() > 0 && cache->dims[1] + cur->dims[1] > cache->expansionDims[1])) {
+                                   std::vector newDims;
+                                   if (cache->Count(0) == 0 || cache->dims.size() == 0) {
+                                       newDims = std::vector {cur->dims[0], ((cur->dims[1] - 1) / unitLen + 1) * unitLen, cur->dims[2]};
+                                   } else {
+                                       newDims = cache->dims;
+                                       newDims[1] += ((cur->dims[1] - 1) / unitLen + 1) * unitLen;
+                                   }
+                                   cache->Expansion(newDims);
+                               }
+                           }
+                       }
+
+                       CatDirectBatch(pastKeys[layerId], pointersK, 1);
+                       CatDirectBatch(pastValues[layerId], pointersV, 1);
+
+                       int q0 = q.dims[2], k0 = k.dims[2], dims = q.dims[3];
+                       q.Reshape({batch, maxLen, q0, dims});
+                       PermuteSelf(q, {0, 2, 1, 3});
+                       q.Reshape({batch * q0, maxLen, -1});
+
+                       k.Reshape({batch, maxLen, k0, dims});
+                       PermuteSelf(k, {0, 2, 1, 3});
+                       k.Reshape({batch * k0, maxLen, -1});
+
+                       v.Reshape({batch, maxLen, k0, dims});
+                       PermuteSelf(v, {0, 2, 1, 3});
+                       v.Reshape({batch * k0, maxLen, -1});
+
+                       auto &attenOutput = *allDatas[op.datas.find("output")->second];
+                       Attention(q, k, v, *allDatas[op.datas.find("mask")->second], attenOutput, q.dims[0] / k.dims[0], 1.0 / sqrt(head_dim), 1);
+                       PermuteSelf(attenOutput, {1, 0, 2});
+                       attenOutput.Reshape({maxLen, batch, -1});
+                       PermuteSelf(attenOutput, {1, 0, 2});
+                       attenOutput.Reshape({1, -1, attenOutput.dims[2]});
                    } else {
+                       total = 0;
                        std::vector curQs, curKs, curVs, curOutputs;
                        curQs.resize(batch);
                        curKs.resize(batch);
                        curVs.resize(batch);
@@ -354,7 +470,7 @@
                        int unitLen = op.intParams.find("unitLen")->second;
                        for (int b = 0; b < batch; b++) {
                            for (int i = 0; i < 2; i++) {
-                               auto cache = allDatas[op.datas.find(i == 0 ? "k" : "v")->second + "." + std::to_string(b)];
+                               auto cache = i == 0 ? pastKeys[layerId][b] : pastValues[layerId][b];
                                auto cur = i == 0 ? &curKs[b] : &curVs[b];
                                while ((cache->dims.size() == 0 && (cache->expansionDims.size() == 0 || cur->dims[1] > cache->expansionDims[1]))
                                    || (cache->dims.size() > 0 && cache->dims[1] + cur->dims[1] > cache->expansionDims[1])) {
@@ -366,7 +482,7 @@
                                        newDims[1] += ((cur->dims[1] - 1) / unitLen + 1) * unitLen;
                                    }
                                    cache->Expansion(newDims);
-                               }
+                               }
                                excutor.Run("CatDirect", {
                                    {"input0", cache}, {"input1", cur}
                                }, {}, {{"axis", 1}});
@@ -374,10 +490,10 @@
                        }
                        for (int b = 0; b < batch; b++) {
-                           std::string sb = "." + std::to_string(b);
-                           Data *k = allDatas[op.datas.find("k")->second + sb];
-                           Data *v = allDatas[op.datas.find("v")->second + sb];
-                           Data *mask = allDatas[op.datas.find("mask")->second + sb];
+                           Data *k = pastKeys[layerId][b];
+                           Data *v = pastValues[layerId][b];
+                           Data *mask = masks[b];
+
                            excutor.Run("Attention", {
                                {"q", (Data*)&curQs[b]}, {"k", k}, {"v", v}, {"mask", mask},
                                {"output", (Data*)&curOutputs[b]}
@@ -401,7 +517,6 @@
                                {"input", output}, {"axis", &axisData}
                            }, {}, {});
                        }
-
                        auto lastOutput = allDatas[op.datas.find("output")->second];
                        for (int b = 0; b < batch; b++) {
                            Data *output = (Data*)&curOutputs[b];
@@ -420,13 +535,16 @@
                    }
                }
            } else if (op.type == "SplitLastTokenStates") {
+               int total = 0, maxLen = 0;
                std::vector seqLens;
                {
                    auto data = allDatas[op.datas.find("seqLens")->second];
                    for (int i = 0; i < data->Count(0); i++) {
                        seqLens.push_back(((int*)data->cpuData)[i]);
+                       total += seqLens.back();
+                       maxLen = std::max(maxLen, seqLens.back());
                    }
-               }
+               }
                auto input = allDatas[op.datas.find("input")->second];
                auto output = allDatas[op.datas.find("output")->second];
                int len = input->dims[1];
@@ -435,24 +553,52 @@
                    output->FakeFrom(*input, 0);
                } else if (input->dims[0] == 1 && seqLens.size() > 1) {
                    auto lastOutput = allDatas[op.datas.find("output")->second];
-                   int total = 0;
-                   for (int b = 0; b < seqLens.size(); b++) {
-                       Data output;
-                       excutor.Run("Split", {
-                           {"input", input}, {"output", (Data*)&output}
-                       }, {}, {{"axis", 1}, {"start", total + seqLens[b] - 1}, {"end", total + seqLens[b]}});
-                       if (b == 0) {
-                           lastOutput->dataType = output.dataType;
-                           std::vector dims = output.dims;
-                           dims[1] = 0;
-                           lastOutput->Resize(dims);
-                           dims[1] = seqLens.size();
-                           lastOutput->Expansion(dims);
+                   if (total != input->dims[1]) {
+                       int total = 0;
+                       for (int b = 0; b < seqLens.size(); b++) {
+                           Data output;
+                           excutor.Run("Split", {
+                               {"input", input}, {"output", (Data*)&output}
+                           }, {}, {{"axis", 1}, {"start", maxLen * (b + 1) - 1}, {"end", maxLen * (b + 1)}});
+                           if (b == 0) {
+                               lastOutput->dataType = output.dataType;
+                               std::vector dims = output.dims;
+                               dims[1] = 0;
+                               lastOutput->Resize(dims);
+                               dims[1] = seqLens.size();
+                               lastOutput->Expansion(dims);
+                           }
+                           excutor.Run("CatDirect", {
+                               {"input0", lastOutput}, {"input1", (Data*)&output}
+                           }, {}, {{"axis", 1}});
+                           total += seqLens[b];
+                       }
+                   } else {
+                       if (total == seqLens.size()) {
+                           excutor.Run("Mul", {
+                               {"input", (Data*)input}, {"output", (Data*)lastOutput}
+                           }, {{"v", 1.0f}}, {});
+                       } else {
+                           int total = 0;
+                           for (int b = 0; b < seqLens.size(); b++) {
+                               Data output;
+                               excutor.Run("Split", {
+                                   {"input", input}, {"output", (Data*)&output}
+                               }, {}, {{"axis", 1}, {"start", total + seqLens[b] - 1}, {"end", total + seqLens[b]}});
+                               if (b == 0) {
+                                   lastOutput->dataType = output.dataType;
+                                   std::vector dims = output.dims;
+                                   dims[1] = 0;
+                                   lastOutput->Resize(dims);
+                                   dims[1] = seqLens.size();
+                                   lastOutput->Expansion(dims);
+                               }
+                               excutor.Run("CatDirect", {
+                                   {"input0", lastOutput}, {"input1", (Data*)&output}
+                               }, {}, {{"axis", 1}});
+                               total += seqLens[b];
+                           }
                        }
-                       excutor.Run("CatDirect", {
-                           {"input0", lastOutput}, {"input1", (Data*)&output}
-                       }, {}, {{"axis", 1}});
-                       total += seqLens[b];
                    }
                } else {
                    excutor.Run("Split", {
diff --git a/src/models/graphllm.cpp b/src/models/graphllm.cpp
index d97e606..7b054b2 100644
--- a/src/models/graphllm.cpp
+++ b/src/models/graphllm.cpp
@@ -116,6 +116,12 @@ namespace fastllm {
        for (auto &it : weight.weight) {
            weightDicts[it.first] = &it.second;
        }
+       std::vector > pastKeys, pastValues;
+       std::vector masks;
+       pastKeys.resize(block_cnt);
+       pastValues.resize(block_cnt);
+       masks.push_back((Data*)&attentionMask);
+
        Data atype = Data(this->dataType);
        std::map inputs = {
            {"inputIds", (Data*)&inputIds},
@@ -126,11 +132,11 @@
            {"seqLens", (Data*)&seqLensData}
        };
        for (int i = 0; i < block_cnt; i++) {
-           inputs.insert({"pastKey." + std::to_string(i), (Data*)&pastKeyValues[i].first});
-           inputs.insert({"pastValue." + std::to_string(i), (Data*)&pastKeyValues[i].second});
+           pastKeys[i].push_back((Data*)&pastKeyValues[i].first);
+           pastValues[i].push_back((Data*)&pastKeyValues[i].second);
        }
        Data logits, topk;
-       RunComputeGraph(graph, this->deviceMap, inputs, weightDicts, {{"logits", (Data*)&logits}});
+       RunComputeGraph(graph, this->deviceMap, inputs, weightDicts, {{"logits", (Data*)&logits}}, pastKeys, pastValues, masks);
        std::vector lastRet;
        {
            ToDataType(logits, DataType::FLOAT32);
@@ -179,10 +185,13 @@
        }
        int seqLen = inputIds.dims[1];
        Data allPositionIds;
-       allPositionIds.CopyFrom(*(Data*)positionIds[0]);
-       allPositionIds.Expansion({1, seqLen});
-       for (int i = 1; i < batch; i++) {
-           CatDirect(allPositionIds, *(Data*)positionIds[i], 1);
+       int pos = 0;
+       allPositionIds.dataType = positionIds[0]->dataType;
+       allPositionIds.Resize({1, seqLen});
+       allPositionIds.Allocate();
+       for (int i = 0; i < batch; i++) {
+           memcpy(allPositionIds.cpuData + pos, positionIds[i]->cpuData, (size_t)positionIds[i]->GetBytes());
+           pos += positionIds[i]->GetBytes();
        }
        std::map weightDicts;
        for (auto &it : weight.weight) {
@@ -196,16 +205,87 @@
            {"sin", &sinData}, {"cos", &cosData}, {"seqLens", &seqLensData}
        };
+
+       std::vector > pastKeys, pastValues;
+       std::vector masks;
+       pastKeys.resize(block_cnt);
+       pastValues.resize(block_cnt);
+       for (int i = 0; i < block_cnt; i++) {
+           pastKeys[i].resize(batch);
+           pastValues[i].resize(batch);
+       }
+       masks.resize(batch);
        for (int b = 0; b < batch; b++) {
-           std::string sb = std::to_string(b);
-           inputs.insert({"attentionMask." + sb, attentionMask[b]});
+           masks[b] = attentionMask[b];
            for (int i = 0; i < block_cnt; i++) {
-               inputs.insert({"pastKey." + std::to_string(i) + "." + sb, pastKeyValues[b * block_cnt + i].first});
-               inputs.insert({"pastValue." + std::to_string(i) + "." + sb, pastKeyValues[b * block_cnt + i].second});
+               pastKeys[i][b] = pastKeyValues[b * block_cnt + i].first;
+               pastValues[i][b] = pastKeyValues[b * block_cnt + i].second;
+           }
+           for (int i = 0; i < block_cnt; i++) {
+               if (GetKVCacheInCPU()) {
+                   pastKeyValues[b * block_cnt + i].first->lockInCPU = true;
+                   pastKeyValues[b * block_cnt + i].second->lockInCPU = true;
+               } else {
+                   if (pastKeyValues[b * block_cnt + i].first->dataDevice == DataDevice::CUDA) {
+                       break;
+                   }
+                   pastKeyValues[b * block_cnt + i].first->ToDevice(DataDevice::CUDA);
+                   pastKeyValues[b * block_cnt + i].second->ToDevice(DataDevice::CUDA);
+               }
+           }
+       }
+
+       // 拼batch, 把短句补长
+       Data curAttentionMask;
+       Data realInputIds;
+       int maxLen = 0, totalLen = 0;
+       if (batch > 1 && seqLen != seqLens.size()) {
+           for (int i = 0; i < batch; i++) {
+               maxLen = std::max(maxLen, seqLens[i]);
            }
+           int totalLen = maxLen * batch;
+           Data tempInputIds;
+           Mul(inputIds, 1.0, tempInputIds);
+           ToDataType(tempInputIds, DataType::FLOAT32);
+           tempInputIds.ToDevice(DataDevice::CPU);
+           allPositionIds.ToDevice(DataDevice::CPU);
+           float *floatInputIds = (float*)tempInputIds.cpuData;
+           float *floatPositionIds = (float*)allPositionIds.cpuData;
+
+           std::vector vmask = std::vector (batch * maxLen * maxLen, 0);
+           std::vector vpids = std::vector (batch * maxLen, 0);
+           std::vector ids = std::vector (batch * maxLen, 0.0);
+           for (int i = 0; i < batch; i++) {
+               int len = seqLens[i], base = maxLen - len;
+               for (int j = 0; j < len; j++) {
+                   ids[i * maxLen + base + j] = (*floatInputIds++);
+                   vpids[i * maxLen + base + j] = (*floatPositionIds++);
+               }
+               std::fill(vmask.data() + i * maxLen * maxLen,
+                         vmask.data() + i * maxLen * maxLen + (maxLen - len) * maxLen, 1.0);
+               for (int j = maxLen - len; j < maxLen; j++) {
+                   std::fill(vmask.data() + i * maxLen * maxLen + j * maxLen,
+                             vmask.data() + i * maxLen * maxLen + j * maxLen + maxLen - len, 1.0);
+               }
+               for (int j = 0; j < len; j++) {
+                   for (int k = j + 1; k < len; k++) {
+                       vmask[i * maxLen * maxLen + (base + j) * maxLen + base + k] = 1;
+                   }
+               }
+           }
+
+           realInputIds.CopyFrom(Data(DataType::FLOAT32, {1, batch * maxLen}, ids));
+           curAttentionMask.CopyFrom(Data(DataType::FLOAT32, {batch, maxLen, maxLen}, vmask));
+           allPositionIds.CopyFrom(Data(DataType::FLOAT32, {1, batch * maxLen}, vpids));
+
+           ToDataType(curAttentionMask, this->dataType);
+           inputs.insert({"attentionMask", &curAttentionMask});
+           inputs["inputIds"] = (Data*)&realInputIds;
+           inputs["positionIds"] = (Data*)&allPositionIds;
        }
+
        Data logits, topk;
-       RunComputeGraph(graph, this->deviceMap, inputs, weightDicts, {{"logits", (Data*)&logits}});
+       RunComputeGraph(graph, this->deviceMap, inputs, weightDicts, {{"logits", (Data*)&logits}}, pastKeys, pastValues, masks);
        ToDataType(logits, DataType::FLOAT32);
        std::vector curLogits;
        curLogits.resize(batch);