diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu
index fcd5a87f..914c05a7 100644
--- a/src/devices/cuda/fastllm-cuda.cu
+++ b/src/devices/cuda/fastllm-cuda.cu
@@ -505,9 +505,9 @@ __global__ void SimpleMask(half* a, half *b, half maskValue, int spatial) {
 }
 
 template <int THREAD_PER_BLOCK, typename T>
-__global__ void CausalMask(T* a, T maskValue, int q, int k) {
+__global__ void CausalMask(T* a, T maskValue, int q, int k, int base) {
     a += blockIdx.x * k;
-    for (int i = k - q + blockIdx.x + threadIdx.x + 1; i < k; i += THREAD_PER_BLOCK) {
+    for (int i = base + blockIdx.x + threadIdx.x + 1; i < k; i += THREAD_PER_BLOCK) {
         a[i] = maskValue;
     }
 }
@@ -655,8 +655,35 @@ __global__ void FastllmRotatePosition2DKernel(float *data, float *positionIds, f
     d[i * m + m / 4] = va * curSin + vb * curCos;
 }
 
+__global__ void InitBlockAtten(float *sum0, float *max0, float *sum1, float *max1, int len) {
+    int i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i < len) {
+        sum0[i] = sum1[i] = 0.0f;
+        max0[i] = max1[i] = -10000.0f;
+    }
+}
+
 template <int THREAD_PER_BLOCK>
-__device__ void FastllmSoftmaxKernelInner1Func(float *input, float *output, int channels) {
+__global__ void AttnBlockUpdate(half *data, int n, int m, float *lastMax, float *lastSum, float *curMax, float *curSum) {
+    __shared__ float scale;
+    unsigned int tid = threadIdx.x;
+    unsigned int bid = blockIdx.x;
+
+    if (tid == 0) {
+        float oldSum = lastSum[bid] * exp(lastMax[bid] - curMax[bid]);
+        scale = oldSum / curSum[bid];
+        lastSum[bid] = curSum[bid];
+        lastMax[bid] = curMax[bid];
+    }
+    __syncthreads();
+
+    for (int i = tid; i < m; i += THREAD_PER_BLOCK) {
+        data[bid * m + i] = (half)((float)data[bid * m + i] * scale);
+    }
+}
+
+template <int THREAD_PER_BLOCK>
+__device__ void FastllmSoftmaxKernelInner1Func(float *input, float *output, int channels, float *maxp, float *sump) {
     __shared__ float sdata[THREAD_PER_BLOCK];
     __shared__ float maxV;
 
@@ -680,6 +707,9 @@ __device__ void FastllmSoftmaxKernelInner1Func(float *input, float *output, int
     // 3. record max
     if (tid == 0) {
         maxV = sdata[0];
+        if (maxp != nullptr) {
+            maxp[0] = sdata[0];
+        }
     }
     __syncthreads();
 
@@ -700,7 +730,10 @@ __device__ void FastllmSoftmaxKernelInner1Func(float *input, float *output, int
     }
     if (tid == 0) {
         if (fabs(sdata[0]) < 1e-6) {
-            sdata[0] = 0.1;
+            sdata[0] = 0.0001;
+        }
+        if (sump != nullptr) {
+            sump[0] = sdata[0];
         }
     }
     __syncthreads();
@@ -723,7 +756,7 @@ __device__ half FastllmHalfMaxFunc(const __half a, const __half b) {
 }
 
 template <int THREAD_PER_BLOCK>
-__device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int channels) {
+__device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int channels, float *maxp, float *sump) {
     __shared__ float sdata[THREAD_PER_BLOCK];
 
     // 1. each thread computes a partial result
@@ -744,6 +777,12 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
     }
 
     // 3. record max
+    if (tid == 0) {
+        if (maxp != nullptr) {
+            sdata[0] = max(maxp[0], sdata[0]);
+        }
+    }
+    __syncthreads();
     float maxV = sdata[0];
 
     // 4. compute the sum
@@ -762,7 +801,12 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
     }
     if (tid == 0) {
         if (fabs(sdata[0]) < 1e-6) {
-            sdata[0] = 0.1;
+            sdata[0] = 0.0001;
+        }
+        if (sump != nullptr) {
+            sump[0] = sump[0] * exp(maxp[0] - maxV) + sdata[0];
+            sdata[0] = sump[0];
+            maxp[0] = maxV;
         }
     }
     __syncthreads();
@@ -776,26 +820,38 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
 template <int THREAD_PER_BLOCK>
 __global__ void FastllmSoftmaxKernelInner1(float* input, float *output, int outer, int channels) {
     int o = blockIdx.x;
-    FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, channels);
+    FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, channels, nullptr, nullptr);
 }
 
 template <int THREAD_PER_BLOCK>
 __global__ void FastllmSoftmaxKernelInner1(half* input, half *output, int outer, int channels) {
     int o = blockIdx.x;
-    FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, channels);
+    FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, channels, nullptr, nullptr);
+}
+
+template <int THREAD_PER_BLOCK>
+__global__ void FastllmSoftmaxKernelInner1(half* input, half *output, int outer, int channels, float *maxp, float *sump) {
+    int o = blockIdx.x;
+    FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, channels, maxp + o, sump + o);
 }
 
 template <int THREAD_PER_BLOCK, typename T>
 __global__ void FastllmSoftmaxKernelInner1WithCausalMask(T* input, T *output, int outer, int channels, int base) {
     int o = blockIdx.x;
-    FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, o + base + 1);
+    FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, o + base + 1, nullptr, nullptr);
+}
+
+template <int THREAD_PER_BLOCK, typename T>
+__global__ void FastllmSoftmaxKernelInner1WithCausalMask(T* input, T *output, int outer, int channels, int base, float *maxp, float *sump) {
+    int o = blockIdx.x;
+    FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, min(channels, o + base + 1), maxp + o, sump + o);
 }
 
 template <int THREAD_PER_BLOCK>
 __global__ void FastllmSoftmaxKernelBatchInner1(uint8_t** pointer) {
     int o = blockIdx.x;
     FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> ((float*)pointer[o * 3], (float*)pointer[o * 3 + 1],
-                                                       (int)((size_t)pointer[o * 3 + 2]));
+                                                       (int)((size_t)pointer[o * 3 + 2]), nullptr, nullptr);
 }
 
 template <int THREAD_PER_BLOCK>
@@ -1620,46 +1676,6 @@ __global__ void FastllmMatMulKernel(uint8_t** pointer, float alpha) {
 */
 }
 
-template <int THREAD_PER_BLOCK>
-__global__ void FastllmAttentionKernel(float *qd, float *kd, float *vd, float *maskd, float *od,
-                                       float scale, int q1, int q2, int k1, int v2,
-                                       int group, int qstride, int kstride, int vstride, int ostride,
-                                       float *qk, float *temp) {
-    int o = blockIdx.x;
-    qd += o * qstride;
-    kd += (o / group) * kstride;
-    vd += (o / group) * vstride;
-    od += o * ostride;
-    qk += o * k1;
-    temp += o * k1;
-    for (int i = 0; i < q1; i++) {
-        for (int j = threadIdx.x; j < k1; j += THREAD_PER_BLOCK) {
-            if (maskd && maskd[i * k1 + j] > 0.99) {
-                qk[j] = -10000;
-                continue;
-            }
-            float sum = 0.0f;
-            float *tempQd = qd + i * q2, *tempKd = kd + j * q2;
-            for (int l = 0; l < q2; l++) {
-                sum += tempQd[l] * tempKd[l];
-            }
-            qk[j] = sum * scale;
-        }
-        __syncthreads();
-        FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (qk, temp, k1);
-        __syncthreads();
-        for (int j = threadIdx.x; j < v2; j += THREAD_PER_BLOCK) {
-            float *curInput1 = vd + j;
-            float sum = 0.0;
-            for (int l = 0; l < k1; l++) {
-                sum += temp[l] * curInput1[l * v2];
-            }
-            od[i * v2 + j] = sum;
-        }
-        __syncthreads();
-    }
-}
-
 template <int THREAD_PER_BLOCK>
 __global__ void FastllmAttentionBatchKernel(float** pointer, float scale, int group) {
     const int params = 16;
@@ -2992,17 +3008,6 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
     float *od = (float*)output.cudaData;
     int batch = (mask.dims.size() == 3) ? mask.dims[0] : 1;
     int maskStride = (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0));
-    if (false) {
-        float *qk = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float));
-        float *temp = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float));
-        FastllmAttentionKernel<256> <<<q0, 256>>>(qd, kd, vd, maskd, od,
-                                                  scale, q1, q2, k1, v2,
-                                                  group, q.strides[0], k.strides[0], v.strides[0], output.strides[0],
-                                                  qk, temp);
-        FastllmCudaFree(qk);
-        FastllmCudaFree(temp);
-        return true;
-    }
 
     if (q1 >= 1024 || (q1 > 1 && q1 != k1)) {
         float *qk = (float *) FastllmCudaMalloc(q1 * k1 * sizeof(float));
@@ -3027,7 +3032,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
         }
 
         if (batch == 1 && maskd == nullptr) {
-            CausalMask<256, float> <<<q1, 256>>>(qk, 0, q1, k1);
+            CausalMask<256, float> <<<q1, 256>>>(qk, 0, q1, k1, k1 - q1);
            FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, k1, k1 - q1);
         } else {
             if (maskd) {
@@ -3143,37 +3148,67 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co
         alignK1 = ((k1 - 1) / 128 + 1) * 128;
     }
 
-    half *qk = (half *) FastllmCudaMalloc(alignQ1 * alignK1 * sizeof(half));
-    cudaMemset(qk, 0, alignQ1 * alignK1 * sizeof(half));
-
+    int part = (alignK1 > 8192 ? 8192 : alignK1);
+    half *qk = (half *) FastllmCudaMalloc(alignQ1 * part * sizeof(half));
+    cudaMemset(qk, 0, alignQ1 * part * sizeof(half));
     auto fastllmCublasHandle = getFastllmCublasHandle();
     cublasStatus_t status;
     for (int i = 0; i < q0; i++) {
         //DeviceSync();
         //auto st = std::chrono::system_clock::now();
         if (useFastAttn) {
-            GpuQK(qd + i * q.Count(1), kd + (i / group) * k.Count(1), qk, alignQ1, alignK1, q2, scale, k1 - q1);
-            FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, alignK1, k1 - q1);
-            status = cublasHgemmStridedBatched(fastllmCublasHandle,
-                                               CUBLAS_OP_N, CUBLAS_OP_N,
-                                               v2, q1, alignK1, &one,
-                                               vd + (i / group) * v.Count(1), v.strides[1], v.Count(1),
-                                               qk, alignK1, alignK1 * alignQ1,
-                                               &beta,
-                                               od + i * v2 * q1, v2, v2 * q1, 1);
-/*
-            int part = (q1 + 1) / 2;
-            for (int l = 0; l < q1; l += part) {
-                int cur = std::min(part, q1 - l);
+            if (alignK1 > 8192) {
+                float *lastSum = (float*)FastllmCudaMalloc(alignQ1 * sizeof(float));
+                float *lastMax = (float*)FastllmCudaMalloc(alignQ1 * sizeof(float));
+                float *currentSum = (float*)FastllmCudaMalloc(alignQ1 * sizeof(float));
+                float *currentMax = (float*)FastllmCudaMalloc(alignQ1 * sizeof(float));
+
+                int threadPerBlock = std::min(256, alignQ1);
+                InitBlockAtten <<< (alignQ1 - 1) / threadPerBlock + 1, threadPerBlock>>> (lastSum, lastMax, currentSum, currentMax, alignQ1);
+
+                int part = 8192;
+                for (int st = 0; st < alignK1; st += part) {
+                    int len = std::min(part, alignK1 - st);
+                    status = cublasHgemm(fastllmCublasHandle,
+                                         CUBLAS_OP_T, CUBLAS_OP_N,
+                                         len, alignQ1, q2, &hscale,
+                                         kd + (i / group) * k.Count(1) + st * k.strides[1], k.strides[1],
+                                         qd + i * q.Count(1), q.strides[1],
+                                         &beta,
+                                         qk, len);
+                    CausalMask<256, half> <<<alignQ1, 256>>>(qk, __float2half_rn(0.0f), alignQ1, len, k1 - q1 - st);
+                    FastllmSoftmaxKernelInner1WithCausalMask<256> <<< q1, 256 >>>(qk, qk, alignQ1, len, k1 - q1 - st, currentMax, currentSum);
+                    if (st > 0) {
+                        AttnBlockUpdate <128> <<< alignQ1, 128 >>> (od + i * v2 * q1, alignQ1, v2, lastMax, lastSum, currentMax, currentSum);
+                    } else {
+                        cudaMemcpy(lastMax, currentMax, alignQ1 * sizeof(float), cudaMemcpyDeviceToDevice);
+                        cudaMemcpy(lastSum, currentSum, alignQ1 * sizeof(float), cudaMemcpyDeviceToDevice);
+                    }
+                    half currentScale = __float2half_rn(st > 0 ? 1.0f : 0.0f);
+                    status = cublasHgemm(fastllmCublasHandle,
+                                         CUBLAS_OP_N, CUBLAS_OP_N,
+                                         v2, alignQ1, len, &one,
+                                         vd + (i / group) * v.Count(1) + st * v.strides[1], v.strides[1],
+                                         qk, len,
+                                         &currentScale,
+                                         od + i * v2 * q1, v2);
+                }
+
+                FastllmCudaFree(lastSum);
+                FastllmCudaFree(lastMax);
+                FastllmCudaFree(currentSum);
+                FastllmCudaFree(currentMax);
+            } else {
+                GpuQK(qd + i * q.Count(1), kd + (i / group) * k.Count(1), qk, alignQ1, alignK1, q2, scale, k1 - q1);
+                FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, alignK1, k1 - q1);
                status = cublasHgemmStridedBatched(fastllmCublasHandle,
                                                   CUBLAS_OP_N, CUBLAS_OP_N,
-                                                   v2, cur, min(k1, (l + cur)), &one,
+                                                   v2, q1, alignK1, &one,
                                                   vd + (i / group) * v.Count(1), v.strides[1], v.Count(1),
-                                                   qk + l * alignK1, alignK1, alignK1 * alignQ1,
+                                                   qk, alignK1, alignK1 * alignQ1,
                                                   &beta,
-                                                   od + i * v2 * q1 + l * v2, v2, v2 * q1, 1);
+                                                   od + i * v2 * q1, v2, v2 * q1, 1);
            }
-*/
        } else {
            status = cublasHgemmStridedBatched(fastllmCublasHandle,
                                               CUBLAS_OP_T, CUBLAS_OP_N,
@@ -3190,7 +3225,7 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co
            }
 
            if (batch == 1 && maskd == nullptr) {
-                CausalMask<256, half> <<<q1, 256>>>(qk, __float2half_rn(0), q1, k1);
+                CausalMask<256, half> <<<q1, 256>>>(qk, __float2half_rn(0), q1, k1, k1 - q1);
                FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, k1, k1 - q1);
            } else {
                SimpleMask<256> <<< (q1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * maskStride, __float2half_rn(-10000), q1 * k1);
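Notes on the change:

CausalMask now takes the first masked column as an explicit base argument instead of deriving it from k - q. For a full q1 x k1 score matrix the call sites pass base = k1 - q1, which reproduces the old behaviour: row r may attend to key indices smaller than k1 - q1 + r + 1, so columns i >= base + r + 1 are overwritten with maskValue. When the score matrix covers only a chunk of keys starting at offset st, chunk column i corresponds to global key st + i, and the condition st + i >= k1 - q1 + r + 1 becomes i >= (k1 - q1 - st) + r + 1, which is why the blocked path passes base = k1 - q1 - st. Worked example: with q1 = 2, k1 = 10 and a chunk starting at st = 8 (len = 2, so base = 0), row 0 may see keys 0..8, so chunk column 0 (key 8) survives and column 1 (key 9) is masked; row 1 may see all 10 keys, so nothing in that chunk is masked for it.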
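The maxp/sump parameters added to FastllmSoftmaxKernelInner1Func, together with InitBlockAtten and AttnBlockUpdate, implement a running ("online") softmax over key chunks: each chunk produces a local maximum and a local exp-sum, the running pair is folded in as newSum = oldSum * exp(oldMax - mergedMax) + chunkSum with mergedMax = max(oldMax, chunkMax), and the output rows accumulated from earlier chunks are rescaled by oldSum * exp(oldMax - mergedMax) / newSum, which is the scale AttnBlockUpdate applies. The host-side C++ sketch below shows only this bookkeeping; RunningSoftmax and mergeBlock are names invented for the example, and it is a reference illustration rather than the kernel code.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct RunningSoftmax {
    float maxV = -10000.0f;   // same initial values that InitBlockAtten writes
    float sumV = 0.0f;
};

// Fold one chunk of scores into the running (max, sum) pair, mirroring the
// maxp/sump handling in the half-precision softmax, and return the factor by
// which output accumulated from earlier chunks must be rescaled (the "scale"
// AttnBlockUpdate applies on the GPU side).
float mergeBlock(RunningSoftmax &st, const std::vector<float> &chunk) {
    float mergedMax = st.maxV;
    for (float x : chunk) mergedMax = std::max(mergedMax, x);
    float chunkSum = 0.0f;
    for (float x : chunk) chunkSum += std::exp(x - mergedMax);
    float oldSum = st.sumV * std::exp(st.maxV - mergedMax);
    float newSum = oldSum + chunkSum;
    float rescale = newSum > 0.0f ? oldSum / newSum : 0.0f;
    st.maxV = mergedMax;
    st.sumV = newSum;
    return rescale;
}

int main() {
    std::vector<float> scores = {0.3f, 2.0f, -1.0f, 4.0f, 0.5f, 1.5f};

    // Denominator of a normal one-pass softmax.
    float m = -10000.0f, s = 0.0f;
    for (float x : scores) m = std::max(m, x);
    for (float x : scores) s += std::exp(x - m);

    // Same denominator computed from two chunks of three scores.
    RunningSoftmax st;
    mergeBlock(st, std::vector<float>(scores.begin(), scores.begin() + 3));
    mergeBlock(st, std::vector<float>(scores.begin() + 3, scores.end()));

    std::printf("one pass: max=%f sum=%f, blocked: max=%f sum=%f\n", m, s, st.maxV, st.sumV);
    return 0;
}

Running it shows the blocked (max, sum) pair matching the one-pass softmax denominator, which is the invariant the chunked CUDA path relies on.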
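In FastllmCudaHalfAttention, when alignK1 exceeds 8192 the qk buffer now holds only one 8192-wide chunk of scores, and the loop over st walks the key dimension chunk by chunk: QK^T for the chunk (first cublasHgemm), CausalMask with base = k1 - q1 - st, softmax with the running (max, sum) statistics, AttnBlockUpdate to rescale the output written so far, and a second cublasHgemm whose beta (currentScale) is 0 for the first chunk and 1 afterwards so that od accumulates each chunk's P*V contribution. Below is a minimal CPU sketch of that loop for a single query row, with a tiny chunk size and no masking, intended only to show the order of rescaling and accumulation; it is an illustrative assumption-laden sketch, not the fastllm implementation.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int K1 = 10, D = 4, PART = 4;       // keys, head dim, chunk size (8192 in the CUDA path)
    std::vector<float> q(D, 0.1f);
    std::vector<float> k(K1 * D), v(K1 * D);
    for (int i = 0; i < K1 * D; i++) { k[i] = 0.01f * i; v[i] = 0.02f * (i % 7); }

    std::vector<float> out(D, 0.0f);          // plays the role of od for one query row
    float runMax = -10000.0f, runSum = 0.0f;  // plays the role of lastMax / lastSum

    for (int st = 0; st < K1; st += PART) {
        int len = std::min(PART, K1 - st);

        // 1. scores for this chunk (QK^T); causal masking omitted in this sketch
        std::vector<float> s(len);
        float chunkMax = runMax;
        for (int j = 0; j < len; j++) {
            float dot = 0.0f;
            for (int d = 0; d < D; d++) dot += q[d] * k[(st + j) * D + d];
            s[j] = dot;
            chunkMax = std::max(chunkMax, s[j]);
        }

        // 2. softmax numerators and local sum for the chunk
        float chunkSum = 0.0f;
        for (int j = 0; j < len; j++) { s[j] = std::exp(s[j] - chunkMax); chunkSum += s[j]; }

        // 3. merge running statistics and rescale the output accumulated so far
        //    (AttnBlockUpdate performs this rescale on the GPU)
        float oldSum = runSum * std::exp(runMax - chunkMax);
        float newSum = oldSum + chunkSum;
        float rescale = newSum > 0.0f ? oldSum / newSum : 0.0f;
        for (int d = 0; d < D; d++) out[d] *= rescale;
        runMax = chunkMax;
        runSum = newSum;

        // 4. accumulate this chunk's P*V contribution, normalized by the running sum
        //    (the second cublasHgemm with currentScale as beta)
        for (int j = 0; j < len; j++) {
            for (int d = 0; d < D; d++) {
                out[d] += s[j] / runSum * v[(st + j) * D + d];
            }
        }
    }

    for (int d = 0; d < D; d++) std::printf("%f ", out[d]);
    std::printf("\n");
    return 0;
}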