diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu index 3308f157..3e377767 100644 --- a/src/devices/cuda/fastllm-cuda.cu +++ b/src/devices/cuda/fastllm-cuda.cu @@ -128,7 +128,7 @@ template __global__ void HalfFC( half * __restrict__ a, half * __restrict__ b, half * __restrict__ c, const int N, const int M, const int K, - half scale) { + half scale, const int base) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 // support tensor core int tid = threadIdx.x; int bx = blockIdx.x; @@ -140,7 +140,7 @@ __global__ void HalfFC( int wrap0 = wid >> 1; int wrap1 = wid & 1; - if (stN + BN <= stK) { + if (base + stN + BN <= stK) { return; } @@ -199,7 +199,7 @@ __global__ void HalfFC( __syncthreads(); for (int i = 0; i < BN; i++) { - if (stN + i < stK + tid) { + if (base + stN + i < stK + tid) { cur[i][tid] = (half)0; } } @@ -210,13 +210,13 @@ __global__ void HalfFC( #endif } -void GpuQK(half *q, half *k, half *qk, int qlen, int klen, int dim, float scale) { +void GpuQK(half *q, half *k, half *qk, int qlen, int klen, int dim, float scale, int base) { const int BQ = 128, BK = 128, DIM = 128; dim3 blockDim(128); int BX = (qlen + BQ - 1) / BQ; int BY = (klen + BK - 1) / BK; dim3 gridDim(BX, BY); - HalfFC <<>> (q, k, qk, qlen, dim, klen, (half)scale); + HalfFC <<>> (q, k, qk, qlen, dim, klen, (half)scale, base); } template @@ -786,9 +786,9 @@ __global__ void FastllmSoftmaxKernelInner1(half* input, half *output, int outer, } template -__global__ void FastllmSoftmaxKernelInner1WithCausalMask(T* input, T *output, int outer, int channels) { +__global__ void FastllmSoftmaxKernelInner1WithCausalMask(T* input, T *output, int outer, int channels, int base) { int o = blockIdx.x; - FastllmSoftmaxKernelInner1Func (input + o * channels, output + o * channels, o + 1); + FastllmSoftmaxKernelInner1Func (input + o * channels, output + o * channels, o + base + 1); } template @@ -3005,7 +3005,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const return true; } - if (q1 >= 1024) { + if (q1 >= 1024 || (q1 > 1 && q1 != k1)) { float *qk = (float *) FastllmCudaMalloc(q1 * k1 * sizeof(float)); float beta = 0, one = 1; auto fastllmCublasHandle = getFastllmCublasHandle(); @@ -3029,7 +3029,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const if (batch == 1 && maskd == nullptr) { CausalMask<256, float> <<>>(qk, 0, q1, k1); - FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, k1); + FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, k1, k1 - q1); } else { if (maskd) { SimpleMask<256> <<< (q1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * maskStride, -10000, q1 * k1); @@ -3136,8 +3136,7 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co int maskStride = (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0)); half beta = __float2half_rn(0.0f), one = __float2half_rn(1.0f), hscale = __float2half_rn(scale); - - if (q1 >= 1024) { + if (q1 >= 1024 || (q1 > 1 && q1 != k1)) { int alignQ1 = q1, alignK1 = k1; bool useFastAttn = getCudaInfos()->hasTensorCore && batch == 1 && (q2 == 128 && v2 == 128); if (useFastAttn) { @@ -3154,8 +3153,8 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co //DeviceSync(); //auto st = std::chrono::system_clock::now(); if (useFastAttn) { - GpuQK(qd + i * q.Count(1), kd + (i / group) * k.Count(1), qk, alignQ1, alignK1, q2, scale); - FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, alignK1); + GpuQK(qd + i * q.Count(1), kd + (i / group) * k.Count(1), qk, alignQ1, alignK1, q2, scale, k1 - q1); + FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, alignK1, k1 - q1); status = cublasHgemmStridedBatched(fastllmCublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, v2, q1, alignK1, &one, @@ -3193,7 +3192,7 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co if (batch == 1 && maskd == nullptr) { CausalMask<256, half> <<>>(qk, __float2half_rn(0), q1, k1); - FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, k1); + FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, k1, k1 - q1); } else { SimpleMask<256> <<< (q1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * maskStride, __float2half_rn(-10000), q1 * k1); int outer = q1; diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp index c5739fa2..ee112750 100644 --- a/src/models/basellm.cpp +++ b/src/models/basellm.cpp @@ -769,6 +769,20 @@ auto st = std::chrono::system_clock::now(); ret = model->ForwardBatch(seqLens.size(), inputIds, attentionMasks, positionIds, seqLens, pastKeyValues, generationConfigs, tokensManager, &logits); + } else { + if (seqLens[0] > 8192) { + int len = seqLens[0]; + int first = 8192, part = 2048; + for (int st = 0; st < len; ) { + int curLen = std::min(st == 0 ? first : part, len - st); + Data curInput, curPositionIds; + Split(inputIds, 1, st, st + curLen, curInput); + Split(*positionIds[0], 1, st, st + curLen, curPositionIds); + + ret = std::vector {model->Forward(curInput, Data(), curPositionIds, + *pastKeyValue1, generationConfigs[0], tokensManager, logits[0])}; + st += curLen; + } } else { ret = std::vector {model->Forward(inputIds, attentionMasks[0] == nullptr ? Data() : *attentionMasks[0], diff --git a/src/models/chatglm.cpp b/src/models/chatglm.cpp index 7136b44b..7c6a272d 100644 --- a/src/models/chatglm.cpp +++ b/src/models/chatglm.cpp @@ -87,15 +87,15 @@ namespace fastllm { this->gmask_token_id = 64790; this->bos_token_id = 64792; } - if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) { - UpdateRotaryPosEmb(atof(this->weight.dicts["rope_ratio"].c_str())); - } if (this->weight.dicts.find("layernorm_epsilon") != this->weight.dicts.end()) { this->layernorm_epsilon = atof(this->weight.dicts["layernorm_epsilon"].c_str()); } if (this->weight.dicts.find("seq_length") != this->weight.dicts.end()) { max_positions = atoi(this->weight.dicts["seq_length"].c_str()); } + if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) { + UpdateRotaryPosEmb(atof(this->weight.dicts["rope_ratio"].c_str())); + } } int ChatGLMModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask, @@ -110,13 +110,14 @@ namespace fastllm { std::vector ChatGLMModel::ForwardBatch( int batch, const Data &inputIds, - const Data &attentionMaskOri, - const Data &positionIdsOri, + const Data &attentionMask, + const Data &positionIds, std::vector > &pastKeyValues, const GenerationConfig &generationConfig, const LastTokensManager &lastTokens, std::vector *> *retLogits) { int maxLen = inputIds.dims[1]; + Data inputEmbeddings; Data attenInput; Data qkv, q, k, v; Data attnProbs; @@ -138,274 +139,232 @@ namespace fastllm { } // ChatGLM2 - Data inputIdsPermute, curInputIdsPermute; + Data inputIdsPermute; Permute(inputIds, {1, 0}, inputIdsPermute); + Embedding(inputIdsPermute, this->weight["transformer" + std::string((version == 2 ? ".embedding" : "")) + + ".word_embeddings.weight"], inputEmbeddings); + ToDataType(inputEmbeddings, this->dataType); - int seqLen = inputIdsPermute.dims[0]; - int perSeqLen = std::min(seqLen, 1024); - - for (int seqStart = 0; seqStart < seqLen; seqStart += perSeqLen) { - Data inputEmbeddings; - int curSeqLen = std::min(perSeqLen, seqLen - seqStart); - Data *positionIdsPointer, *attentionMaskPointer; - Data positionIdsNew, attentionMaskNew; - if (curSeqLen == seqLen) { - Embedding(inputIdsPermute, this->weight["transformer" + std::string((version == 2 ? ".embedding" : "")) + - ".word_embeddings.weight"], inputEmbeddings); - attentionMaskPointer = (Data*)&attentionMaskOri; - positionIdsPointer = (Data*)&positionIdsOri; - } else { - std::vector curMasks = std::vector (curSeqLen * (seqStart + curSeqLen)); - int idx = 0; - for (int i = 0; i < curSeqLen; i++) { - for (int j = 0; j < seqStart + curSeqLen; j++) { - curMasks[idx++] = (seqStart + i) < j; - } - } - attentionMaskNew.CopyFrom(Data(FLOAT32, {curSeqLen, seqStart + curSeqLen}, curMasks)); - ToDataType(attentionMaskNew, this->dataType); - Split(positionIdsOri, -1, seqStart, seqStart + curSeqLen, positionIdsNew); - positionIdsPointer = &positionIdsNew; - attentionMaskPointer = &attentionMaskNew; - - Split(inputIdsPermute, 0, seqStart, seqStart + curSeqLen, curInputIdsPermute); - Embedding(curInputIdsPermute, this->weight["transformer" + std::string((version == 2 ? ".embedding" : "")) + - ".word_embeddings.weight"], inputEmbeddings); + Data &hiddenStates = inputEmbeddings; + for (int i = 0; i < block_cnt; i++) { + ApplyDeviceMap(this->deviceMap, i + 1, block_cnt); + if (version == 1) { + std::string inputLNWeightName = "transformer.layers." + std::to_string(i) + ".input_layernorm.weight"; + std::string inputLNBiasName = "transformer.layers." + std::to_string(i) + ".input_layernorm.bias"; + LayerNorm(hiddenStates, weight[inputLNWeightName], weight[inputLNBiasName], -1, attenInput); + } else if (version == 2) { + std::string inputRMSWeightName = + "transformer.encoder.layers." + std::to_string(i) + ".input_layernorm.weight"; + RMSNorm(hiddenStates, weight[inputRMSWeightName], layernorm_epsilon, attenInput); } - - Data &positionIds = *positionIdsPointer; - Data &attentionMask = *attentionMaskPointer; - - ToDataType(inputEmbeddings, this->dataType); - - Data &hiddenStates = inputEmbeddings; - for (int i = 0; i < block_cnt; i++) { - ApplyDeviceMap(this->deviceMap, i + 1, block_cnt); - if (version == 1) { - std::string inputLNWeightName = "transformer.layers." + std::to_string(i) + ".input_layernorm.weight"; - std::string inputLNBiasName = "transformer.layers." + std::to_string(i) + ".input_layernorm.bias"; - LayerNorm(hiddenStates, weight[inputLNWeightName], weight[inputLNBiasName], -1, attenInput); - } else if (version == 2) { - std::string inputRMSWeightName = - "transformer.encoder.layers." + std::to_string(i) + ".input_layernorm.weight"; - RMSNorm(hiddenStates, weight[inputRMSWeightName], layernorm_epsilon, attenInput); - } - std::string qkvWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.weight"; - std::string qkvBiasName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.bias"; - if (!adapterName.empty()) { - std::string peftType = weight.peftDict[adapterName]["peft_type"]; - if (peftType == "LORA") { - std::string loraAWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.lora_A." + adapterName + ".weight"; - std::string loraBWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.lora_B." + adapterName + ".weight"; - LoraLayer(attenInput, weight[qkvWeightName], weight[loraAWeightName], weight[loraBWeightName], weight[qkvBiasName], qkv, weight.peftDict[adapterName]); - } else if (peftType == "IA3") { - std::string ia3WeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.ia3_l" + adapterName + ".weight"; - IA3Layer(attenInput, weight[qkvWeightName], weight[ia3WeightName], weight[qkvBiasName], qkv, weight.peftDict[adapterName]); - } - } else { - Linear(attenInput, weight[qkvWeightName], weight[qkvBiasName], qkv); - } - if (version == 1) { - qkv.Reshape({qkv.dims[0], qkv.dims[1], num_attention_heads, -1}); - int per = qkv.dims.back() / 3; - Split(qkv, -1, 0, per, q); - Split(qkv, -1, per, per * 2, k); - Split(qkv, -1, per * 2, per * 3, v); - fastllm::RotatePosition2D(q, positionIds, sinData, cosData, rotary_dim); - fastllm::RotatePosition2D(k, positionIds, sinData, cosData, rotary_dim); - } else if (version == 2) { - int qLen = embed_dim, kvLen = (qkv.dims.back() - embed_dim) / 2; - Split(qkv, -1, 0, qLen, q); - Split(qkv, -1, qLen, qLen + kvLen, k); - Split(qkv, -1, qLen + kvLen, qLen + kvLen + kvLen, v); - q.Reshape({q.dims[0], q.dims[1], -1, embed_dim / num_attention_heads}); - k.Reshape({k.dims[0], k.dims[1], -1, embed_dim / num_attention_heads}); - v.Reshape({v.dims[0], v.dims[1], -1, embed_dim / num_attention_heads}); - fastllm::NearlyRotatePosition2D(q, positionIds, sinData, cosData, rotary_dim); - fastllm::NearlyRotatePosition2D(k, positionIds, sinData, cosData, rotary_dim); + std::string qkvWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.weight"; + std::string qkvBiasName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.bias"; + if (!adapterName.empty()) { + std::string peftType = weight.peftDict[adapterName]["peft_type"]; + if (peftType == "LORA") { + std::string loraAWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.lora_A." + adapterName + ".weight"; + std::string loraBWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.lora_B." + adapterName + ".weight"; + LoraLayer(attenInput, weight[qkvWeightName], weight[loraAWeightName], weight[loraBWeightName], weight[qkvBiasName], qkv, weight.peftDict[adapterName]); + } else if (peftType == "IA3") { + std::string ia3WeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.ia3_l" + adapterName + ".weight"; + IA3Layer(attenInput, weight[qkvWeightName], weight[ia3WeightName], weight[qkvBiasName], qkv, weight.peftDict[adapterName]); } + } else { + Linear(attenInput, weight[qkvWeightName], weight[qkvBiasName], qkv); + } + if (version == 1) { + qkv.Reshape({qkv.dims[0], qkv.dims[1], num_attention_heads, -1}); + int per = qkv.dims.back() / 3; + Split(qkv, -1, 0, per, q); + Split(qkv, -1, per, per * 2, k); + Split(qkv, -1, per * 2, per * 3, v); + fastllm::RotatePosition2D(q, positionIds, sinData, cosData, rotary_dim); + fastllm::RotatePosition2D(k, positionIds, sinData, cosData, rotary_dim); + } else if (version == 2) { + int qLen = embed_dim, kvLen = (qkv.dims.back() - embed_dim) / 2; + Split(qkv, -1, 0, qLen, q); + Split(qkv, -1, qLen, qLen + kvLen, k); + Split(qkv, -1, qLen + kvLen, qLen + kvLen + kvLen, v); + q.Reshape({q.dims[0], q.dims[1], -1, embed_dim / num_attention_heads}); + k.Reshape({k.dims[0], k.dims[1], -1, embed_dim / num_attention_heads}); + v.Reshape({v.dims[0], v.dims[1], -1, embed_dim / num_attention_heads}); + fastllm::NearlyRotatePosition2D(q, positionIds, sinData, cosData, rotary_dim); + fastllm::NearlyRotatePosition2D(k, positionIds, sinData, cosData, rotary_dim); + } - Data &pastKey = pastKeyValues[i].first, &pastValue = pastKeyValues[i].second; - if (GetKVCacheInCPU()) { - pastKey.lockInCPU = true; - pastValue.lockInCPU = true; - } else { - pastKey.ToDevice(DataDevice::CUDA); - pastValue.ToDevice(DataDevice::CUDA); - }; - - k.Resize({k.dims[0], k.dims[1] * k.dims[2], k.dims[3]}); - v.Resize({v.dims[0], v.dims[1] * v.dims[2], v.dims[3]}); + Data &pastKey = pastKeyValues[i].first, &pastValue = pastKeyValues[i].second; + if (GetKVCacheInCPU()) { + pastKey.lockInCPU = true; + pastValue.lockInCPU = true; + } else { + pastKey.ToDevice(DataDevice::CUDA); + pastValue.ToDevice(DataDevice::CUDA); + }; + k.Resize({k.dims[0], k.dims[1] * k.dims[2], k.dims[3]}); + v.Resize({v.dims[0], v.dims[1] * v.dims[2], v.dims[3]}); - PermuteSelf(k, {1, 0, 2}); - PermuteSelf(v, {1, 0, 2}); + PermuteSelf(k, {1, 0, 2}); + PermuteSelf(v, {1, 0, 2}); - int unitLen = 64; - #ifdef USE_CUDA - unitLen = 128; - #endif - while ((pastKey.dims.size() == 0 && - (pastKey.expansionDims.size() == 0 || k.dims[1] > pastKey.expansionDims[1])) - || (pastKey.dims.size() > 0 && (pastKey.expansionDims.size() == 0 || - pastKey.dims[1] + k.dims[1] > pastKey.expansionDims[1]))) { - std::vector newDims; - if (pastKey.Count(0) == 0 || pastKey.dims.size() == 0) { - newDims = std::vector{k.dims[0], ((k.dims[1] - 1) / unitLen + 1) * unitLen, k.dims[2]}; - if (generationConfig.output_token_limit > 0) { - newDims[1] = k.dims[1] + generationConfig.output_token_limit; - } - } else { - newDims = pastKey.dims; - newDims[1] += ((k.dims[1] - 1) / unitLen + 1) * unitLen; + int unitLen = 64; +#ifdef USE_CUDA + unitLen = 128; +#endif + while ((pastKey.dims.size() == 0 && + (pastKey.expansionDims.size() == 0 || k.dims[1] > pastKey.expansionDims[1])) + || (pastKey.dims.size() > 0 && (pastKey.expansionDims.size() == 0 || + pastKey.dims[1] + k.dims[1] > pastKey.expansionDims[1]))) { + std::vector newDims; + if (pastKey.Count(0) == 0 || pastKey.dims.size() == 0) { + newDims = std::vector{k.dims[0], ((k.dims[1] - 1) / unitLen + 1) * unitLen, k.dims[2]}; + if (generationConfig.output_token_limit > 0) { + newDims[1] = k.dims[1] + generationConfig.output_token_limit; } - pastKey.Expansion(newDims); + } else { + newDims = pastKey.dims; + newDims[1] += ((k.dims[1] - 1) / unitLen + 1) * unitLen; } + pastKey.Expansion(newDims); + } - while ((pastValue.dims.size() == 0 && - (pastValue.expansionDims.size() == 0 || v.dims[1] > pastValue.expansionDims[1])) - || (pastValue.dims.size() > 0 && (pastValue.expansionDims.size() == 0 || - pastValue.dims[1] + v.dims[1] > pastValue.expansionDims[1]))) { - std::vector newDims; - if (pastValue.Count(0) == 0 || pastValue.dims.size() == 0) { - newDims = std::vector{v.dims[0], ((v.dims[1] - 1) / unitLen + 1) * unitLen, v.dims[2]}; - if (generationConfig.output_token_limit > 0) { - newDims[1] = k.dims[1] + generationConfig.output_token_limit; - } - } else { - newDims = pastValue.dims; - newDims[1] += ((v.dims[1] - 1) / unitLen + 1) * unitLen; + while ((pastValue.dims.size() == 0 && + (pastValue.expansionDims.size() == 0 || v.dims[1] > pastValue.expansionDims[1])) + || (pastValue.dims.size() > 0 && (pastValue.expansionDims.size() == 0 || + pastValue.dims[1] + v.dims[1] > pastValue.expansionDims[1]))) { + std::vector newDims; + if (pastValue.Count(0) == 0 || pastValue.dims.size() == 0) { + newDims = std::vector{v.dims[0], ((v.dims[1] - 1) / unitLen + 1) * unitLen, v.dims[2]}; + if (generationConfig.output_token_limit > 0) { + newDims[1] = k.dims[1] + generationConfig.output_token_limit; } - pastValue.Expansion(newDims); - } - CatDirect(pastKey, k, 1); - CatDirect(pastValue, v, 1); - std::vector outputSize = {q.dims[1], q.dims[2], q.dims[0], pastKey.dims[1]}; - q.Reshape({q.dims[0], q.dims[1] * q.dims[2], q.dims[3]}); - PermuteSelf(q, {1, 0, 2}); - Attention(q, pastKey, pastValue, attentionMask, contextLayer, q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1); - /* - // 1.2 Attention - // 1.2.0 q * k^T - q.Reshape({pastKey.dims[0], -1, q.dims[2]}); - MatMulTransB(q, pastKey, attnProbs, 1.0 / (scale_attn * (i + 1))); - attnProbs.Reshape(outputSize); - // 1.2.1 Mask - if (attentionMask.dims.size() != 0) { - AttentionMask(attnProbs, attentionMask, -10000); - } - // 1.2.2 softmax - Mul(attnProbs, i + 1, attnProbs); - Softmax(attnProbs, attnProbs, -1); - outputSize = {1, pastValue.dims[0], q.dims[1], pastValue.dims[1]}; - attnProbs.Reshape({outputSize[0] * outputSize[1], outputSize[2], -1}); - // 1.2.3 prob * v - - attnProbs.Reshape({pastValue.dims[0], -1, attnProbs.dims[2]}); - MatMul(attnProbs, pastValue, contextLayer); - */ + } else { + newDims = pastValue.dims; + newDims[1] += ((v.dims[1] - 1) / unitLen + 1) * unitLen; + } + pastValue.Expansion(newDims); + } + CatDirect(pastKey, k, 1); + CatDirect(pastValue, v, 1); + std::vector outputSize = {q.dims[1], q.dims[2], q.dims[0], pastKey.dims[1]}; + q.Reshape({q.dims[0], q.dims[1] * q.dims[2], q.dims[3]}); + PermuteSelf(q, {1, 0, 2}); + Attention(q, pastKey, pastValue, attentionMask, contextLayer, q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1); + +/* + // 1.2 Attention + // 1.2.0 q * k^T + q.Reshape({pastKey.dims[0], -1, q.dims[2]}); + MatMulTransB(q, pastKey, attnProbs, 1.0 / (scale_attn * (i + 1))); + attnProbs.Reshape(outputSize); + // 1.2.1 Mask + if (attentionMask.dims.size() != 0) { + AttentionMask(attnProbs, attentionMask, -10000); + } + // 1.2.2 softmax + Mul(attnProbs, i + 1, attnProbs); + Softmax(attnProbs, attnProbs, -1); + outputSize = {1, pastValue.dims[0], q.dims[1], pastValue.dims[1]}; + attnProbs.Reshape({outputSize[0] * outputSize[1], outputSize[2], -1}); + // 1.2.3 prob * v + + attnProbs.Reshape({pastValue.dims[0], -1, attnProbs.dims[2]}); + MatMul(attnProbs, pastValue, contextLayer); +*/ + + contextLayer.Reshape({batch, num_attention_heads, maxLen, -1}); + PermuteSelf(contextLayer, {2, 0, 1, 3}); + contextLayer.Reshape({contextLayer.dims[0], contextLayer.dims[1], embed_dim}); - contextLayer.Reshape({batch, num_attention_heads, curSeqLen, -1}); - PermuteSelf(contextLayer, {2, 0, 1, 3}); - contextLayer.Reshape({contextLayer.dims[0], contextLayer.dims[1], embed_dim}); + // 1.2.4 dense + std::string denseWeightName = weightPre + std::to_string(i) + weightMiddle + ".dense.weight"; + std::string denseBiasName = weightPre + std::to_string(i) + weightMiddle + ".dense.bias"; + Linear(contextLayer, weight[denseWeightName], weight[denseBiasName], attnOutput); - // 1.2.4 dense - std::string denseWeightName = weightPre + std::to_string(i) + weightMiddle + ".dense.weight"; - std::string denseBiasName = weightPre + std::to_string(i) + weightMiddle + ".dense.bias"; - Linear(contextLayer, weight[denseWeightName], weight[denseBiasName], attnOutput); + // 1.3 + if (GetVersion() == 1) { + float alpha = sqrt(2 * block_cnt); + Mul(attenInput, alpha, hiddenStates); + AddTo(hiddenStates, attnOutput); + std::string postLNWeightName = + "transformer.layers." + std::to_string(i) + ".post_attention_layernorm.weight"; + std::string postLNBiasName = + "transformer.layers." + std::to_string(i) + ".post_attention_layernorm.bias"; + LayerNorm(hiddenStates, weight[postLNWeightName], weight[postLNBiasName], -1, mlpInput); + // 1.4 MLP + std::string fcInKeyName = "transformer.layers." + std::to_string(i) + ".mlp.dense_h_to_4h"; + std::string fcOutKeyName = "transformer.layers." + std::to_string(i) + ".mlp.dense_4h_to_h"; + Linear(mlpInput, weight[fcInKeyName + ".weight"], weight[fcInKeyName + ".bias"], middle); + GeluNew(middle, middle); + Linear(middle, weight[fcOutKeyName + ".weight"], weight[fcOutKeyName + ".bias"], hiddenStates); + AddTo(hiddenStates, mlpInput, alpha); + } else { + AddTo(hiddenStates, attnOutput); + std::string postRMSWeightName = + "transformer.encoder.layers." + std::to_string(i) + ".post_attention_layernorm.weight"; + Mul(hiddenStates, 1.0, temp); + RMSNorm(hiddenStates, weight[postRMSWeightName], this->layernorm_epsilon, mlpInput); + // 1.4 MLP + std::string fcInKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_h_to_4h"; + std::string fcOutKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_4h_to_h"; - // 1.3 - if (GetVersion() == 1) { - float alpha = sqrt(2 * block_cnt); - Mul(attenInput, alpha, hiddenStates); - AddTo(hiddenStates, attnOutput); - std::string postLNWeightName = - "transformer.layers." + std::to_string(i) + ".post_attention_layernorm.weight"; - std::string postLNBiasName = - "transformer.layers." + std::to_string(i) + ".post_attention_layernorm.bias"; - LayerNorm(hiddenStates, weight[postLNWeightName], weight[postLNBiasName], -1, mlpInput); - // 1.4 MLP - std::string fcInKeyName = "transformer.layers." + std::to_string(i) + ".mlp.dense_h_to_4h"; - std::string fcOutKeyName = "transformer.layers." + std::to_string(i) + ".mlp.dense_4h_to_h"; - Linear(mlpInput, weight[fcInKeyName + ".weight"], weight[fcInKeyName + ".bias"], middle); - GeluNew(middle, middle); - Linear(middle, weight[fcOutKeyName + ".weight"], weight[fcOutKeyName + ".bias"], hiddenStates); - AddTo(hiddenStates, mlpInput, alpha); + if (CanRunLinearEx(LinearExType::ExSwiglu)) { + LinearEx(mlpInput, weight[fcInKeyName + ".weight"], weight[fcInKeyName + ".bias"], middle2, LinearExType::ExSwiglu); } else { - AddTo(hiddenStates, attnOutput); - std::string postRMSWeightName = - "transformer.encoder.layers." + std::to_string(i) + ".post_attention_layernorm.weight"; - Mul(hiddenStates, 1.0, temp); - RMSNorm(hiddenStates, weight[postRMSWeightName], this->layernorm_epsilon, mlpInput); - // 1.4 MLP - std::string fcInKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_h_to_4h"; - std::string fcOutKeyName = "transformer.encoder.layers." + std::to_string(i) + ".mlp.dense_4h_to_h"; - - if (CanRunLinearEx(LinearExType::ExSwiglu)) { - LinearEx(mlpInput, weight[fcInKeyName + ".weight"], weight[fcInKeyName + ".bias"], middle2, LinearExType::ExSwiglu); - } else { - Linear(mlpInput, weight[fcInKeyName + ".weight"], weight[fcInKeyName + ".bias"], middle); - Swiglu(middle, middle2); - } - - Linear(middle2, weight[fcOutKeyName + ".weight"], weight[fcOutKeyName + ".bias"], hiddenStates); - AddTo(hiddenStates, temp); + Linear(mlpInput, weight[fcInKeyName + ".weight"], weight[fcInKeyName + ".bias"], middle); + Swiglu(middle, middle2); } + + Linear(middle2, weight[fcOutKeyName + ".weight"], weight[fcOutKeyName + ".bias"], hiddenStates); + AddTo(hiddenStates, temp); } - - if (seqStart + curSeqLen < seqLen) { - continue; - } + } + + Data logits, topk; + Data tempHiddenStates; + Data *lastHiddenStates; + if (maxLen > 1) { + Split(hiddenStates, 0, maxLen - 1, maxLen, tempHiddenStates); + lastHiddenStates = &tempHiddenStates; + } else { + lastHiddenStates = &hiddenStates; + } - Data logits, topk; - Data tempHiddenStates; - Data *lastHiddenStates; - if (curSeqLen > 1) { - Split(hiddenStates, 0, curSeqLen - 1, curSeqLen, tempHiddenStates); - lastHiddenStates = &tempHiddenStates; + { + auto &hiddenStates = *lastHiddenStates; + if (version == 1) { + LayerNorm(hiddenStates, weight["transformer.final_layernorm.weight"], + weight["transformer.final_layernorm.bias"], -1, hiddenStates); + Linear(hiddenStates, weight["lm_head.weight"], Data(), logits); } else { - lastHiddenStates = &hiddenStates; + RMSNorm(hiddenStates, weight["transformer.encoder.final_layernorm.weight"], this->layernorm_epsilon, hiddenStates); + Linear(hiddenStates, weight["transformer.output_layer.weight"], Data(), logits); } - { - auto &hiddenStates = *lastHiddenStates; - if (version == 1) { - LayerNorm(hiddenStates, weight["transformer.final_layernorm.weight"], - weight["transformer.final_layernorm.bias"], -1, hiddenStates); - Linear(hiddenStates, weight["lm_head.weight"], Data(), logits); - } else { - RMSNorm(hiddenStates, weight["transformer.encoder.final_layernorm.weight"], this->layernorm_epsilon, hiddenStates); - Linear(hiddenStates, weight["transformer.output_layer.weight"], Data(), logits); + ToDataType(logits, DataType::FLOAT32); + if (generationConfig.output_logits && retLogits != nullptr) { + int size = logits.dims.back(); + logits.ToDevice(DataDevice::CPU); + for (int b = 0; b < batch; b++) { + int base = b; + (*retLogits)[b]->resize(size); + memcpy((float *) (*retLogits)[b]->data(), ((float *) logits.cpuData) + base * size, + size * logits.unitSize); } - - ToDataType(logits, DataType::FLOAT32); -//logits.ToDevice(DataDevice::CPU); -//logits.Print(); - if (generationConfig.output_logits && retLogits != nullptr) { - int size = logits.dims.back(); - logits.ToDevice(DataDevice::CPU); - for (int b = 0; b < batch; b++) { - int base = b; - (*retLogits)[b]->resize(size); - memcpy((float *) (*retLogits)[b]->data(), ((float *) logits.cpuData) + base * size, - size * logits.unitSize); - } + } + if (generationConfig.IsSimpleGreedy()) { + TopK(logits, topk, 1); + topk.ToDevice(DataDevice::CPU); + for (int b = 0; b < batch; b++) { + int base = b; + lastRet.push_back((int) (((float *) topk.cpuData)[base * 2] + 1e-3)); } - if (generationConfig.IsSimpleGreedy()) { - TopK(logits, topk, 1); - topk.ToDevice(DataDevice::CPU); - for (int b = 0; b < batch; b++) { - int base = b; - lastRet.push_back((int) (((float *) topk.cpuData)[base * 2] + 1e-3)); - } - } else if (!lastTokens.units.empty()) { - for (int b = 0; b < batch; b++) { - int base = b; - lastRet.push_back(LLMSampling(logits, base, generationConfig, lastTokens.units[b])); - } + } else if (!lastTokens.units.empty()) { + for (int b = 0; b < batch; b++) { + int base = b; + lastRet.push_back(LLMSampling(logits, base, generationConfig, lastTokens.units[b])); } } } - return lastRet; }