diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp
index e03e103d..63a33b9e 100644
--- a/src/devices/cpu/cpudevice.cpp
+++ b/src/devices/cpu/cpudevice.cpp
@@ -275,7 +275,6 @@ namespace fastllm {
         Data &output = *(datas.find("output")->second);
         int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
         float scale = floatParams.find("scale") != floatParams.end() ? floatParams.find("scale")->second : 1.0;
-        output.Allocate();
         int q0 = q.dims[0], q1 = q.dims[1], q2 = q.dims[2], k0 = k.dims[0], k1 = k.dims[1], v2 = v.dims[2];
 
         float *qd = (float*)q.cpuData;
@@ -283,13 +282,15 @@ namespace fastllm {
         float *vd = (float*)v.cpuData;
         float *maskd = (datas.find("mask")->second && mask.dims.size() > 0) ? (float*)mask.cpuData : nullptr;
         float *od = (float*)output.cpuData;
+        int batch = (mask.dims.size() == 3 ? mask.dims[0] : 1);
+        int maskStride = (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0));
         std::fill(od, od + output.Count(0), 0.0f);
         auto pool = GetPool();
         std::vector<std::future<void> > futures;
         for (int o = 0; o < q0; o++) {
             futures.push_back(pool->Submit(SingleAttention,
                                            qd + o * q.strides[0], kd + (o / group) * k.strides[0], vd + (o / group) * v.strides[0],
-                                           maskd ? (maskd + o / (q0 / mask.dims[0])) : maskd, od + o * output.strides[0], scale,
+                                           maskd + (o / (q0 / batch)) * maskStride, od + o * output.strides[0], scale,
                                            q1, q2, k1, v2));
         }
         for (int o = 0; o < futures.size(); o++) {
diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu
index 4ec2362a..181ed02c 100644
--- a/src/devices/cuda/fastllm-cuda.cu
+++ b/src/devices/cuda/fastllm-cuda.cu
@@ -1948,8 +1948,8 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
     float *vd = (float*)v.cudaData;
     float *maskd = mask.dims.size() > 0 ? (float*)mask.cudaData : nullptr;
     float *od = (float*)output.cudaData;
-    int batch = mask.dims.size() > 0 ? mask.dims[0] : 1;
-
+    int batch = (mask.dims.size() == 3) ? mask.dims[0] : 1;
+    int maskStride = (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0));
     if (false) {
         float *qk = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float));
         float *temp = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float));
@@ -1968,6 +1968,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
     auto fastllmCublasHandle = getFastllmCublasHandle();
     cublasStatus_t status;
+
    for (int i = 0; i < q0; i++) {
         status = cublasSgemmStridedBatched(fastllmCublasHandle,
                                            CUBLAS_OP_T, CUBLAS_OP_N,
@@ -1984,7 +1985,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
         }
         if (maskd) {
-            SimpleMask<256> <<< (q1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * q1 * k1, -10000, q1 * k1);
+            SimpleMask<256> <<< (q1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * maskStride, -10000, q1 * k1);
         }
         int outer = q1;
diff --git a/src/models/chatglm.cpp b/src/models/chatglm.cpp
index e35d9a94..0de53194 100644
--- a/src/models/chatglm.cpp
+++ b/src/models/chatglm.cpp
@@ -233,15 +233,14 @@ namespace fastllm {
             std::vector<int> outputSize = {q.dims[1], q.dims[2], q.dims[0], pastKey.dims[1]};
             q.Reshape({q.dims[0], q.dims[1] * q.dims[2], q.dims[3]});
             PermuteSelf(q, {1, 0, 2});
-            //Attention(q, pastKey, pastValue, attentionMask, contextLayer, q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1);
-
+            Attention(q, pastKey, pastValue, attentionMask, contextLayer, q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1);
+/*
             // 1.2 Attention
             // 1.2.0 q * k^T
             q.Reshape({pastKey.dims[0], -1, q.dims[2]});
             MatMulTransB(q, pastKey, attnProbs, 1.0 / (scale_attn * (i + 1)));
             attnProbs.Reshape(outputSize);
-
             // 1.2.1 Mask
             if (attentionMask.dims.size() != 0) {
                 AttentionMask(attnProbs, attentionMask, -10000);
@@ -255,6 +254,8 @@ namespace fastllm {
             attnProbs.Reshape({pastValue.dims[0], -1, attnProbs.dims[2]});
             MatMul(attnProbs, pastValue, contextLayer);
+*/
+
             contextLayer.Reshape({batch, num_attention_heads, maxLen, -1});
             PermuteSelf(contextLayer, {2, 0, 1, 3});
             contextLayer.Reshape({contextLayer.dims[0], contextLayer.dims[1], embed_dim});
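
Note on the mask indexing shared by the CPU and CUDA hunks above: a 3D mask of shape [batch, q1, k1] carries one [q1, k1] slice per batch entry, while a 2D mask is a single slice shared by every attention head, so batch degenerates to 1 and the computed offset is always 0 (which also keeps maskd + 0 well-defined when no mask is supplied). Below is a minimal self-contained sketch of that selection rule, assuming a simplified row-major Tensor as a stand-in for fastllm::Data; Tensor and MaskOffset are illustrative names, not the library's API.

#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Simplified stand-in for fastllm::Data: dims plus row-major strides.
struct Tensor {
    std::vector<int> dims, strides;
    explicit Tensor(std::vector<int> d) : dims(std::move(d)), strides(dims.size()) {
        int s = 1;
        for (int i = (int) dims.size() - 1; i >= 0; i--) {
            strides[i] = s;
            s *= dims[i];
        }
    }
    // Total number of elements from dimension i onward (mirrors Data::Count).
    int Count(int i) const {
        return std::accumulate(dims.begin() + i, dims.end(), 1, std::multiplies<int>());
    }
};

// Element offset into the mask buffer for attention head o of q0 total heads,
// matching the batch/maskStride selection in the patch: a 3D mask indexes one
// [q1, k1] slice per batch entry; a 2D mask is shared, so the offset is 0.
int MaskOffset(int o, int q0, const Tensor &mask) {
    int batch      = (mask.dims.size() == 3 ? mask.dims[0] : 1);
    int maskStride = (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0));
    return (o / (q0 / batch)) * maskStride;  // 2D case: o / q0 == 0, so offset 0
}

int main() {
    Tensor mask2d({8, 8});     // one shared [q1, k1] slice
    Tensor mask3d({2, 8, 8});  // one slice per batch entry
    printf("%d\n", MaskOffset(5, 32, mask2d));   // 0
    printf("%d\n", MaskOffset(20, 32, mask3d));  // 64
    return 0;
}

Compiled standalone, this prints 0 and 64: with a [2, 8, 8] mask and q0 = 32 heads, heads 0-15 read slice 0 and heads 16-31 read slice 1 (one 8x8 slice apart), while the [8, 8] mask resolves every head to offset 0. This is also what the old CPU line got wrong: maskd + o / (q0 / mask.dims[0]) omitted the multiplication by the slice size, so different batch entries read overlapping mask elements.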