Improve response speed for extremely long text
黄宇扬 committed Jul 1, 2024
1 parent dc8e853 commit 3080aa1
Showing 1 changed file with 118 additions and 83 deletions.
201 changes: 118 additions & 83 deletions src/devices/cuda/fastllm-cuda.cu
@@ -505,9 +505,9 @@ __global__ void SimpleMask(half* a, half *b, half maskValue, int spatial) {
}

template <int THREAD_PER_BLOCK, typename T>
__global__ void CausalMask(T* a, T maskValue, int q, int k) {
__global__ void CausalMask(T* a, T maskValue, int q, int k, int base) {
a += blockIdx.x * k;
for (int i = k - q + blockIdx.x + threadIdx.x + 1; i < k; i += THREAD_PER_BLOCK) {
for (int i = base + blockIdx.x + threadIdx.x + 1; i < k; i += THREAD_PER_BLOCK) {
a[i] = maskValue;
}
}
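The extra `base` argument lets the start of the causal mask be shifted when the score matrix only covers a slice of the keys; the launches later in this commit pass base = k1 - q1 - st for a key block starting at offset st. A minimal sketch of the intended indexing (illustration only, not part of the diff):

// For query row `row`, CausalMask writes maskValue to every column
// j >= base + row + 1 of that row. With base = k1 - q1 - st this is
// exactly the set of keys that lie in the future of query row `row`.
inline bool isMaskedColumn(int row, int j, int q1, int k1, int st) {
    int base = k1 - q1 - st;      // value passed to CausalMask in this commit
    return j > base + row;        // the kernel masks columns base + row + 1, base + row + 2, ...
}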
@@ -655,8 +655,35 @@ __global__ void FastllmRotatePosition2DKernel(float *data, float *positionIds, f
d[i * m + m / 4] = va * curSin + vb * curCos;
}

__global__ void InitBlockAtten(float *sum0, float *max0, float *sum1, float *max1, int len) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i < len) {
sum0[i] = sum1[i] = 0.0f;
max0[i] = max1[i] = -10000.0f;
}
}

template <int THREAD_PER_BLOCK>
__device__ void FastllmSoftmaxKernelInner1Func(float *input, float *output, int channels) {
__global__ void AttnBlockUpdate(half *data, int n, int m, float *lastMax, float *lastSum, float *curMax, float *curSum) {
__shared__ float scale;
unsigned int tid = threadIdx.x;
unsigned int bid = blockIdx.x;

if (tid == 0) {
float oldSum = lastSum[bid] * exp(lastMax[bid] - curMax[bid]);
scale = oldSum / curSum[bid];
lastSum[bid] = curSum[bid];
lastMax[bid] = curMax[bid];
}
__syncthreads();

for (int i = tid; i < m; i += THREAD_PER_BLOCK) {
data[bid * m + i] = (half)((float)data[bid * m + i] * scale);
}
}
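AttnBlockUpdate applies the online-softmax correction: rows of the partial output were normalized with the previous statistics (lastMax, lastSum) and have to be rescaled once the running statistics become (curMax, curSum). A scalar sketch of that per-row update (illustration only, same convention as the kernel above):

// Returns the factor by which an already-written output row is multiplied
// after another key block has been folded into the running statistics.
struct RowStats { float maxV, sumV; };

inline float attnRescale(RowStats &last, const RowStats &cur) {
    float oldSum = last.sumV * expf(last.maxV - cur.maxV);  // old sum re-expressed under the new max
    float scale  = oldSum / cur.sumV;                       // applied to the previously accumulated row
    last = cur;                                             // the new statistics become the "last" ones
    return scale;
}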

template <int THREAD_PER_BLOCK>
__device__ void FastllmSoftmaxKernelInner1Func(float *input, float *output, int channels, float *maxp, float *sump) {
__shared__ float sdata[THREAD_PER_BLOCK];
__shared__ float maxV;

@@ -680,6 +707,9 @@ __device__ void FastllmSoftmaxKernelInner1Func(float *input, float *output, int
// 3. Record the max
if (tid == 0) {
maxV = sdata[0];
if (maxp != nullptr) {
maxp[0] = sdata[0];
}
}
__syncthreads();

@@ -700,7 +730,10 @@ __device__ void FastllmSoftmaxKernelInner1Func(float *input, float *output, int
}
if (tid == 0) {
if (fabs(sdata[0]) < 1e-6) {
sdata[0] = 0.1;
sdata[0] = 0.0001;
}
if (sump != nullptr) {
sump[0] = sdata[0];
}
}
__syncthreads();
@@ -723,7 +756,7 @@ __device__ half FastllmHalfMaxFunc(const __half a, const __half b) {
}

template <int THREAD_PER_BLOCK>
__device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int channels) {
__device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int channels, float *maxp, float *sump) {
__shared__ float sdata[THREAD_PER_BLOCK];

// 1. Each thread computes a partial result
@@ -744,6 +777,12 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
}

// 3. Record the max
if (tid == 0) {
if (maxp != nullptr) {
sdata[0] = max(maxp[0], sdata[0]);
}
}
__syncthreads();
float maxV = sdata[0];

// 4. Sum up
@@ -762,7 +801,12 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
}
if (tid == 0) {
if (fabs(sdata[0]) < 1e-6) {
sdata[0] = 0.1;
sdata[0] = 0.0001;
}
if (sump != nullptr) {
sump[0] = sump[0] * exp(maxp[0] - maxV) + sdata[0];
sdata[0] = sump[0];
maxp[0] = maxV;
}
}
__syncthreads();
@@ -776,26 +820,38 @@ __device__ void FastllmSoftmaxKernelInner1Func(half *input, half *output, int ch
template <int THREAD_PER_BLOCK>
__global__ void FastllmSoftmaxKernelInner1(float* input, float *output, int outer, int channels) {
int o = blockIdx.x;
FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, channels);
FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, channels, nullptr, nullptr);
}

template <int THREAD_PER_BLOCK>
__global__ void FastllmSoftmaxKernelInner1(half* input, half *output, int outer, int channels) {
int o = blockIdx.x;
FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, channels);
FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, channels, nullptr, nullptr);
}

template <int THREAD_PER_BLOCK>
__global__ void FastllmSoftmaxKernelInner1(half* input, half *output, int outer, int channels, float *maxp, float *sump) {
int o = blockIdx.x;
FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, channels, maxp + o, sump + o);
}

template <int THREAD_PER_BLOCK, typename T>
__global__ void FastllmSoftmaxKernelInner1WithCausalMask(T* input, T *output, int outer, int channels, int base) {
int o = blockIdx.x;
FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, o + base + 1);
FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, o + base + 1, nullptr, nullptr);
}

template <int THREAD_PER_BLOCK, typename T>
__global__ void FastllmSoftmaxKernelInner1WithCausalMask(T* input, T *output, int outer, int channels, int base, float *maxp, float *sump) {
int o = blockIdx.x;
FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, min(channels, o + base + 1), maxp + o, sump + o);
}
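This overload keeps one running (max, sum) pair per query row: maxp + o and sump + o are merged with the current block inside FastllmSoftmaxKernelInner1Func, and min(channels, o + base + 1) limits the softmax to the columns the causal mask allows. Mathematically the merge amounts to the following (illustrative sketch; the kernel computes the block sum directly against the merged max, which is equivalent):

// Fold one block's softmax statistics into the running per-row statistics.
// blockSum is assumed to be the sum of exp(x - blockMax) over the block's visible columns.
inline void mergeRowStats(float &runMax, float &runSum, float blockMax, float blockSum) {
    float newMax = fmaxf(runMax, blockMax);
    runSum = runSum * expf(runMax - newMax) + blockSum * expf(blockMax - newMax);
    runMax = newMax;
}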

template <int THREAD_PER_BLOCK>
__global__ void FastllmSoftmaxKernelBatchInner1(uint8_t** pointer) {
int o = blockIdx.x;
FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> ((float*)pointer[o * 3], (float*)pointer[o * 3 + 1],
(int)((size_t)pointer[o * 3 + 2]));
(int)((size_t)pointer[o * 3 + 2]), nullptr, nullptr);
}

template <int THREAD_PER_BLOCK>
@@ -1620,46 +1676,6 @@ __global__ void FastllmMatMulKernel(uint8_t** pointer, float alpha) {
*/
}

template <int THREAD_PER_BLOCK>
__global__ void FastllmAttentionKernel(float *qd, float *kd, float *vd, float *maskd, float *od,
float scale, int q1, int q2, int k1, int v2,
int group, int qstride, int kstride, int vstride, int ostride,
float *qk, float *temp) {
int o = blockIdx.x;
qd += o * qstride;
kd += (o / group) * kstride;
vd += (o / group) * vstride;
od += o * ostride;
qk += o * k1;
temp += o * k1;
for (int i = 0; i < q1; i++) {
for (int j = threadIdx.x; j < k1; j += THREAD_PER_BLOCK) {
if (maskd && maskd[i * k1 + j] > 0.99) {
qk[j] = -10000;
continue;
}
float sum = 0.0f;
float *tempQd = qd + i * q2, *tempKd = kd + j * q2;
for (int l = 0; l < q2; l++) {
sum += tempQd[l] * tempKd[l];
}
qk[j] = sum * scale;
}
__syncthreads();
FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (qk, temp, k1);
__syncthreads();
for (int j = threadIdx.x; j < v2; j += THREAD_PER_BLOCK) {
float *curInput1 = vd + j;
float sum = 0.0;
for (int l = 0; l < k1; l++) {
sum += temp[l] * curInput1[l * v2];
}
od[i * v2 + j] = sum;
}
__syncthreads();
}
}

template <int THREAD_PER_BLOCK>
__global__ void FastllmAttentionBatchKernel(float** pointer, float scale, int group) {
const int params = 16;
@@ -2992,17 +3008,6 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
float *od = (float*)output.cudaData;
int batch = (mask.dims.size() == 3) ? mask.dims[0] : 1;
int maskStride = (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0));
if (false) {
float *qk = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float));
float *temp = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float));
FastllmAttentionKernel<256> <<<q0, 256>>>(qd, kd, vd, maskd, od,
scale, q1, q2, k1, v2,
group, q.strides[0], k.strides[0], v.strides[0], output.strides[0],
qk, temp);
FastllmCudaFree(qk);
FastllmCudaFree(temp);
return true;
}

if (q1 >= 1024 || (q1 > 1 && q1 != k1)) {
float *qk = (float *) FastllmCudaMalloc(q1 * k1 * sizeof(float));
@@ -3027,7 +3032,7 @@ }
}

if (batch == 1 && maskd == nullptr) {
CausalMask<256, float> <<<q1, 256>>>(qk, 0, q1, k1);
CausalMask<256, float> <<<q1, 256>>>(qk, 0, q1, k1, k1 - q1);
FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, k1, k1 - q1);
} else {
if (maskd) {
@@ -3143,37 +3148,67 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co
alignK1 = ((k1 - 1) / 128 + 1) * 128;
}

half *qk = (half *) FastllmCudaMalloc(alignQ1 * alignK1 * sizeof(half));
cudaMemset(qk, 0, alignQ1 * alignK1 * sizeof(half));

int part = (alignK1 > 8192 ? 8192 : alignK1);
half *qk = (half *) FastllmCudaMalloc(alignQ1 * part * sizeof(half));
cudaMemset(qk, 0, alignQ1 * part * sizeof(half));
auto fastllmCublasHandle = getFastllmCublasHandle();
cublasStatus_t status;
for (int i = 0; i < q0; i++) {
//DeviceSync();
//auto st = std::chrono::system_clock::now();
if (useFastAttn) {
GpuQK(qd + i * q.Count(1), kd + (i / group) * k.Count(1), qk, alignQ1, alignK1, q2, scale, k1 - q1);
FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, alignK1, k1 - q1);
status = cublasHgemmStridedBatched(fastllmCublasHandle,
CUBLAS_OP_N, CUBLAS_OP_N,
v2, q1, alignK1, &one,
vd + (i / group) * v.Count(1), v.strides[1], v.Count(1),
qk, alignK1, alignK1 * alignQ1,
&beta,
od + i * v2 * q1, v2, v2 * q1, 1);
/*
int part = (q1 + 1) / 2;
for (int l = 0; l < q1; l += part) {
int cur = std::min(part, q1 - l);
if (alignK1 > 8192) {
float *lastSum = (float*)FastllmCudaMalloc(alignQ1 * sizeof(float));
float *lastMax = (float*)FastllmCudaMalloc(alignQ1 * sizeof(float));
float *currentSum = (float*)FastllmCudaMalloc(alignQ1 * sizeof(float));
float *currentMax = (float*)FastllmCudaMalloc(alignQ1 * sizeof(float));

int threadPerBlock = std::min(256, alignQ1);
InitBlockAtten <<< (alignQ1 - 1) / threadPerBlock + 1, threadPerBlock>>> (lastSum, lastMax, currentSum, currentMax, alignQ1);

int part = 8192;
for (int st = 0; st < alignK1; st += part) {
int len = std::min(part, alignK1 - st);
status = cublasHgemm(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
len, alignQ1, q2, &hscale,
kd + (i / group) * k.Count(1) + st * k.strides[1], k.strides[1],
qd + i * q.Count(1), q.strides[1],
&beta,
qk, len);
CausalMask<256, half> <<<q1, 256>>>(qk, __float2half_rn(0.0f), alignQ1, len, k1 - q1 - st);
FastllmSoftmaxKernelInner1WithCausalMask<256> <<< q1, 256 >>>(qk, qk, alignQ1, len, k1 - q1 - st, currentMax, currentSum);
if (st > 0) {
AttnBlockUpdate <128> <<< alignQ1, 128 >>> (od + i * v2 * q1, alignQ1, v2, lastMax, lastSum, currentMax, currentSum);
} else {
cudaMemcpy(lastMax, currentMax, alignQ1 * sizeof(float), cudaMemcpyDeviceToDevice);
cudaMemcpy(lastSum, currentSum, alignQ1 * sizeof(float), cudaMemcpyDeviceToDevice);
}
half currentScale = __float2half_rn(st > 0 ? 1.0f : 0.0f);
status = cublasHgemm(fastllmCublasHandle,
CUBLAS_OP_N, CUBLAS_OP_N,
v2, alignQ1, len, &one,
vd + (i / group) * v.Count(1) + st * v.strides[1], v.strides[1],
qk, len,
&currentScale,
od + i * v2 * q1, v2);
}

FastllmCudaFree(lastSum);
FastllmCudaFree(lastMax);
FastllmCudaFree(currentSum);
FastllmCudaFree(currentMax);
} else {
GpuQK(qd + i * q.Count(1), kd + (i / group) * k.Count(1), qk, alignQ1, alignK1, q2, scale, k1 - q1);
FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, alignK1, k1 - q1);
status = cublasHgemmStridedBatched(fastllmCublasHandle,
CUBLAS_OP_N, CUBLAS_OP_N,
v2, cur, min(k1, (l + cur)), &one,
v2, q1, alignK1, &one,
vd + (i / group) * v.Count(1), v.strides[1], v.Count(1),
qk + l * alignK1, alignK1, alignK1 * alignQ1,
qk, alignK1, alignK1 * alignQ1,
&beta,
od + i * v2 * q1 + l * v2, v2, v2 * q1, 1);
od + i * v2 * q1, v2, v2 * q1, 1);
}
*/
} else {
status = cublasHgemmStridedBatched(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
@@ -3190,7 +3225,7 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co
}

if (batch == 1 && maskd == nullptr) {
CausalMask<256, half> <<<q1, 256>>>(qk, __float2half_rn(0), q1, k1);
CausalMask<256, half> <<<q1, 256>>>(qk, __float2half_rn(0), q1, k1, k1 - q1);
FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, k1, k1 - q1);
} else {
SimpleMask<256> <<< (q1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * maskStride, __float2half_rn(-10000), q1 * k1);
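
Taken together, the long-context path in FastllmCudaHalfAttention now walks the keys in blocks of 8192 columns instead of materializing the full alignQ1 x alignK1 score matrix: each block is scored with cublasHgemm, causal-masked, softmaxed against the running per-row (max, sum), accumulated into the output with beta = 1 (beta = 0 for the first block), and the rows written by earlier blocks are rescaled by AttnBlockUpdate. A simplified single-row CPU sketch of that flow (hypothetical reference with float accumulation, not part of the diff), showing why the blockwise result matches a monolithic softmax:

#include <algorithm>
#include <cmath>
#include <vector>

// scores[j] : scaled dot product of the query row with key j, j = 0 .. k1 - 1
// v[j]      : value vector for key j (v2 entries)
// visible   : causal limit for this row (only columns < visible participate)
// blockLen  : number of key columns processed per block (8192 in the kernel path)
std::vector<float> blockwiseAttentionRow(const std::vector<float> &scores,
                                         const std::vector<std::vector<float>> &v,
                                         int visible, int blockLen) {
    int v2 = (int)v[0].size();
    std::vector<float> out(v2, 0.0f);
    float runMax = -10000.0f, runSum = 0.0f;              // same init as InitBlockAtten

    for (int st = 0; st < visible; st += blockLen) {
        int len = std::min(blockLen, visible - st);

        // Block statistics, merged with the running ones.
        float blockMax = -10000.0f;
        for (int j = 0; j < len; j++) blockMax = std::max(blockMax, scores[st + j]);
        float newMax = std::max(runMax, blockMax);
        float blockSum = 0.0f;
        for (int j = 0; j < len; j++) blockSum += std::exp(scores[st + j] - newMax);
        float newSum = runSum * std::exp(runMax - newMax) + blockSum;

        // Rescale what earlier blocks wrote (AttnBlockUpdate's job) ...
        float rescale = runSum * std::exp(runMax - newMax) / newSum;
        for (int d = 0; d < v2; d++) out[d] *= rescale;
        // ... then add this block's normalized contribution (the beta-accumulating GEMM).
        for (int j = 0; j < len; j++) {
            float w = std::exp(scores[st + j] - newMax) / newSum;
            for (int d = 0; d < v2; d++) out[d] += w * v[st + j][d];
        }
        runMax = newMax; runSum = newSum;
    }
    return out;
}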