长文本优化 (Long text optimization)
黄宇扬 committed Jun 6, 2024
1 parent bde1985 commit 15bcc21
Showing 2 changed files with 303 additions and 250 deletions.
86 changes: 44 additions & 42 deletions src/devices/cuda/fastllm-cuda.cu
@@ -2782,54 +2782,56 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
         }
 
         if (q1 > 1024) {
-            float *qk = (float *) FastllmCudaMalloc(q1 * k1 * sizeof(float));
+            int perQ1 = std::min(1024, q1);
+            float *qk = (float *) FastllmCudaMalloc(perQ1 * k1 * sizeof(float));
             float beta = 0, one = 1;
             auto fastllmCublasHandle = getFastllmCublasHandle();
             cublasStatus_t status;
 
             for (int i = 0; i < q0; i++) {
-                status = cublasSgemmStridedBatched(fastllmCublasHandle,
-                                                   CUBLAS_OP_T, CUBLAS_OP_N,
-                                                   k1, q1, q2, &scale,
-                                                   kd + (i / group) * k.Count(1), k.strides[1], k.Count(1),
-                                                   qd + i * q.Count(1), q.strides[1], q.Count(1),
-                                                   &beta,
-                                                   qk, k1, k1 * q1, 1);
-                if (status != CUBLAS_STATUS_SUCCESS) {
-                    printf("status = %d\n", (int) status);
-                    printf("Error: cublas error during MatMulTransB in Attention operator.\n");
-                    throw ("cublas error");
-                    exit(0);
-                }
-
-                if (maskd) {
-                    SimpleMask<256> <<< (q1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * maskStride, -10000, q1 * k1);
-                }
-
-                int outer = q1;
-                if (k1 < 8) {
-                    FastllmSoftmaxKernelInner1<1> <<< outer, 1 >>>(qk, qk, outer, k1);
-                } else if (k1 < 64) {
-                    FastllmSoftmaxKernelInner1<8> <<< outer, 8 >>>(qk, qk, outer, k1);
-                } else if (k1 < 512) {
-                    FastllmSoftmaxKernelInner1<64> <<< outer, 64 >>>(qk, qk, outer, k1);
-                } else {
-                    FastllmSoftmaxKernelInner1<256> <<< outer, 256 >>>(qk, qk, outer, k1);
-                }
-
-                status = cublasSgemmStridedBatched(fastllmCublasHandle,
-                                                   CUBLAS_OP_N, CUBLAS_OP_N,
-                                                   v2, q1, k1, &one,
-                                                   vd + (i / group) * v.Count(1), v.strides[1], v.Count(1),
-                                                   qk, k1, k1 * q1,
-                                                   &beta,
-                                                   od + i * v2 * q1, v2, v2 * q1, 1);
-                if (status != CUBLAS_STATUS_SUCCESS) {
-                    printf("status = %d\n", (int) status);
-                    printf("Error: cublas error during MatMul in Attention operator.\n");
-                    throw ("cublas error");
-                    exit(0);
-                }
+                for (int q1Start = 0; q1Start < q1; q1Start += perQ1) {
+                    int curQ1 = std::min(perQ1, q1 - q1Start);
+                    status = cublasSgemmStridedBatched(fastllmCublasHandle,
+                                                       CUBLAS_OP_T, CUBLAS_OP_N,
+                                                       k1, curQ1, q2, &scale,
+                                                       kd + (i / group) * k.Count(1), k.strides[1], k.Count(1),
+                                                       qd + i * q.Count(1) + q1Start * q.Count(2), q.strides[1], q.Count(1),
+                                                       &beta,
+                                                       qk, k1, k1 * curQ1, 1);
+                    if (status != CUBLAS_STATUS_SUCCESS) {
+                        printf("status = %d\n", (int) status);
+                        printf("Error: cublas error during MatMulTransB in Attention operator.\n");
+                        throw ("cublas error");
+                        exit(0);
+                    }
+
+                    if (maskd) {
+                        SimpleMask<256> <<< (curQ1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * maskStride + q1Start * k1, -10000, curQ1 * k1);
+                    }
+
+                    int outer = curQ1;
+                    if (k1 < 8) {
+                        FastllmSoftmaxKernelInner1<1> <<< outer, 1 >>>(qk, qk, outer, k1);
+                    } else if (k1 < 64) {
+                        FastllmSoftmaxKernelInner1<8> <<< outer, 8 >>>(qk, qk, outer, k1);
+                    } else if (k1 < 512) {
+                        FastllmSoftmaxKernelInner1<64> <<< outer, 64 >>>(qk, qk, outer, k1);
+                    } else {
+                        FastllmSoftmaxKernelInner1<256> <<< outer, 256 >>>(qk, qk, outer, k1);
+                    }
+
+                    status = cublasSgemmStridedBatched(fastllmCublasHandle,
+                                                       CUBLAS_OP_N, CUBLAS_OP_N,
+                                                       v2, curQ1, k1, &one,
+                                                       vd + (i / group) * v.Count(1), v.strides[1], v.Count(1),
+                                                       qk, k1, k1 * q1,
+                                                       &beta,
+                                                       od + i * v2 * q1 + v2 * q1Start, v2, v2 * q1, 1);
+                    if (status != CUBLAS_STATUS_SUCCESS) {
+                        printf("status = %d\n", (int) status);
+                        printf("Error: cublas error during MatMul in Attention operator.\n");
+                        throw ("cublas error");
+                        exit(0);
+                    }
+                }
             }
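
The effect of this hunk is to cap the scratch memory used for attention scores: rather than materializing the full q1 x k1 score matrix qk at once, query rows are processed in chunks of at most 1024 (perQ1), and a single perQ1 * k1 buffer is reused across chunks. For example, with a 32768-token prompt and k1 == q1, the full matrix would need 32768 * 32768 * 4 bytes ≈ 4 GiB per head, while the chunked buffer stays at 1024 * 32768 * 4 bytes ≈ 128 MiB. Below is a minimal CPU sketch of the same chunking scheme, for illustration only: the names q1, k1, perQ1, q1Start, and curQ1 mirror the diff, but the function itself, the row-major layout, and the head dimension dim are assumptions, not fastllm's actual API; the mask step (SimpleMask in the diff) is omitted for brevity.

// Minimal sketch of the chunked-attention scheme above (CPU, single head).
// Assumed layout: row-major Q (q1 x dim), K (k1 x dim), V (k1 x dim),
// output (q1 x dim).
#include <algorithm>
#include <cmath>
#include <vector>

void chunkedAttention(const float *q, const float *k, const float *v,
                      float *out, int q1, int k1, int dim, float scale) {
    const int perQ1 = std::min(1024, q1);        // chunk size, as in the diff
    std::vector<float> qk((size_t) perQ1 * k1);  // perQ1 * k1 scratch, not q1 * k1

    for (int q1Start = 0; q1Start < q1; q1Start += perQ1) {
        const int curQ1 = std::min(perQ1, q1 - q1Start);  // last chunk may be short

        // scores = scale * Q_chunk * K^T  (first cublasSgemmStridedBatched call)
        for (int i = 0; i < curQ1; i++) {
            for (int j = 0; j < k1; j++) {
                float s = 0.0f;
                for (int d = 0; d < dim; d++) {
                    s += q[(size_t) (q1Start + i) * dim + d] * k[(size_t) j * dim + d];
                }
                qk[(size_t) i * k1 + j] = s * scale;
            }
        }

        // row-wise softmax over k1 (FastllmSoftmaxKernelInner1 in the diff)
        for (int i = 0; i < curQ1; i++) {
            float *row = qk.data() + (size_t) i * k1;
            float mx = *std::max_element(row, row + k1);
            float sum = 0.0f;
            for (int j = 0; j < k1; j++) { row[j] = std::exp(row[j] - mx); sum += row[j]; }
            for (int j = 0; j < k1; j++) { row[j] /= sum; }
        }

        // out_chunk = softmax(scores) * V, written at the q1Start row offset
        // (second cublasSgemmStridedBatched call: od + ... + v2 * q1Start)
        for (int i = 0; i < curQ1; i++) {
            for (int d = 0; d < dim; d++) {
                float s = 0.0f;
                for (int j = 0; j < k1; j++) {
                    s += qk[(size_t) i * k1 + j] * v[(size_t) j * dim + d];
                }
                out[(size_t) (q1Start + i) * dim + d] = s;
            }
        }
    }
}

Because the softmax is purely row-wise, each query chunk is independent of the others, so the chunked loop produces the same output as the unchunked version while bounding the scratch buffer.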

@@ -2909,7 +2911,7 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co
 
         half beta = __float2half_rn(0.0f), one = __float2half_rn(1.0f), hscale = __float2half_rn(scale);
 
-        if (q1 > 1024) {
+        if (q1 >= 1024) {
            half *qk = (half *) FastllmCudaMalloc(q1 * k1 * sizeof(half));
            auto fastllmCublasHandle = getFastllmCublasHandle();
            cublasStatus_t status;