Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed Jul 8, 2024
1 parent 2619759 commit f195cb4
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions src/devices/cuda/fastllm-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -856,9 +856,10 @@ __global__ void FastllmSoftmaxKernelBatchInner1(uint8_t** pointer) {
}

// NOTE(review): this is a rendered diff hunk, so the pre-change and post-change lines
// appear together without +/- markers. The three-argument signature and the call that
// indexes pointer[o / outer] are presumably the removed version; the two-argument
// signature, the `channels` load and the call that indexes pointer[o / outer * 2] are
// presumably the added version — confirm against the repository before reusing this text.
//
// Batched wrapper around FastllmSoftmaxKernelInner1Func (a softmax, going by its name).
// Each block takes its work item from blockIdx.x: `o / outer` selects the batch entry
// and `o % outer` selects the span of `channels` elements inside that entry.
// The launches visible elsewhere on this page use a 1-D grid of batch * outer blocks
// with THREAD_PER_BLOCK threads per block.
template <typename T, int THREAD_PER_BLOCK>
// Variant taking one `channels` value shared by every batch entry; `pointer` holds one
// data pointer per batch entry.
__global__ void FastllmSoftmaxKernelBatchInner1(uint8_t** pointer, int outer, int channels) {
// Variant without a `channels` parameter; the value is read from `pointer` instead.
__global__ void FastllmSoftmaxKernelBatchInner1(uint8_t** pointer, int outer) {
int o = blockIdx.x;
// Call line of the three-argument variant: the same address is passed as the first and
// second argument (presumably input and output, i.e. in place — confirm against the helper).
FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> ((T*)pointer[o / outer] + (o % outer) * channels, (T*)pointer[o / outer] + (o % outer) * channels,
// Two-argument variant: `pointer` is read in pairs per batch entry. Slot 2*b is the data
// pointer and slot 2*b + 1 carries the channel count stored as a pointer-sized integer,
// which is narrowed to int here.
int channels = (int)((size_t)pointer[o / outer * 2 + 1]);
// Call line of the two-argument variant, again passing the same address twice.
FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> ((T*)pointer[o / outer * 2] + (o % outer) * channels, (T*)pointer[o / outer * 2] + (o % outer) * channels,
// Shared tail of both call lines; the meaning of the two nullptr arguments is defined by
// FastllmSoftmaxKernelInner1Func, which is not visible here.
channels, nullptr, nullptr);
}

Expand Down Expand Up @@ -3796,17 +3797,27 @@ bool DoFastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::
}

if (true) {
int outer = q[0]->dims[0] * q[0]->dims[1], channels = k[0]->dims[1];
uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch);
cudaMemcpy(pointers, qk, sizeof(uint8_t*) * batch, cudaMemcpyHostToDevice);
if (channels < 128) {
FastllmSoftmaxKernelBatchInner1 <T, 32> <<<batch * outer, 32>>> (pointers, outer, channels);
} else if (channels < 512) {
FastllmSoftmaxKernelBatchInner1 <T, 64> <<<batch * outer, 64>>> (pointers, outer, channels);
int outer = q[0]->dims[0] * q[0]->dims[1];
uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch * 2);
uint8_t ** cpuPointers = new uint8_t*[batch * 2];
int maxChannels = 0;
for (int b = 0; b < batch; b++) {
int outer = q[b]->dims[0] * q[b]->dims[1];
int channels = k[b]->dims[1];
cpuPointers[b * 2 + 0] = (uint8_t*)(qk[b]);
cpuPointers[b * 2 + 1] = (uint8_t*)((size_t)channels);
maxChannels = max(maxChannels, channels);
}
cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * batch * 2, cudaMemcpyHostToDevice);
if (maxChannels < 128) {
FastllmSoftmaxKernelBatchInner1 <T, 32> <<<batch * outer, 32>>> (pointers, outer);
} else if (maxChannels < 512) {
FastllmSoftmaxKernelBatchInner1 <T, 64> <<<batch * outer, 64>>> (pointers, outer);
} else {
FastllmSoftmaxKernelBatchInner1 <T, 128> <<<batch * outer, 128>>> (pointers, outer, channels);
FastllmSoftmaxKernelBatchInner1 <T, 128> <<<batch * outer, 128>>> (pointers, outer);
}
FastllmCudaFree(pointers);
delete[] cpuPointers;
}

if (true) {
Expand Down

0 comments on commit f195cb4

Please sign in to comment.