diff --git a/src/devices/cuda/cudadevicebatch.cpp b/src/devices/cuda/cudadevicebatch.cpp
index 0d4e30e4..61a34377 100644
--- a/src/devices/cuda/cudadevicebatch.cpp
+++ b/src/devices/cuda/cudadevicebatch.cpp
@@ -378,7 +378,12 @@ namespace fastllm {
         Data **masks = (Data**)(datas.find("mask")->second);
         Data **outputs = (Data**)(datas.find("output")->second);
 
-        if (qs[0]->dataType == DataType::FLOAT32) {
+        long long aveLen = 0;
+        for (int i = 0; i < batch; i++) {
+            aveLen += ks[i]->dims[1];
+        }
+        aveLen /= batch;
+        if (qs[0]->dataType == DataType::FLOAT32 || aveLen < 512) {
             for (int i = 0; i < batch; i++) {
                 outputs[i]->Allocate();
             }
diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu
index a62b1f7e..1fae13f7 100644
--- a/src/devices/cuda/fastllm-cuda.cu
+++ b/src/devices/cuda/fastllm-cuda.cu
@@ -27,6 +27,7 @@ void showError(cudaError_t result, char const* const message, const char* const
 #define FETCH_FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])
+#define FETCH_FLOAT2(pointer) (reinterpret_cast<float2*>(&(pointer))[0])
 
 typedef union __align__(16) {
     uint2 in;
@@ -854,6 +855,14 @@ __global__ void FastllmSoftmaxKernelBatchInner1(uint8_t** pointer) {
                                                           (int)((size_t)pointer[o * 3 + 2]), nullptr, nullptr);
 }
 
+template <typename T, int THREAD_PER_BLOCK>
+__global__ void FastllmSoftmaxKernelBatchInner1(uint8_t** pointer, int outer, int channels) {
+    int o = blockIdx.x;
+    FastllmSoftmaxKernelInner1Func <T, THREAD_PER_BLOCK> ((T*)pointer[o / outer] + (o % outer) * channels, (T*)pointer[o / outer] + (o % outer) * channels,
+                                                          channels, nullptr, nullptr);
+}
+
+
 template <int THREAD_PER_BLOCK>
 __global__ void FastllmRMSNormKernelInner1(float *input, float *weight, float *output, int outer, int channels, float eps) {
     int o = blockIdx.x;
@@ -1670,38 +1679,29 @@ __global__ void FastllmHalfMatMulTransBBatchKernel(uint8_t** pointer, float alpha) {
     */
     int pera = 4, perb = 4;
-    float cura[4][4], curb[4][4], curc[4][4];
+    half cura[4][4], curb[4][4];
+    float curc[4][4];
     int cnta = (n - 1) / pera + 1, cntb = (k - 1) / perb + 1;
     for (int taskId = tid; taskId < cnta * cntb; taskId += THREAD_PER_BLOCK) {
         int taska = taskId / cntb, taskb = taskId % cntb;
         for (int i = 0; i < 4; i++) {
             for (int j = 0; j < 4; j++) {
-                cura[i][j] = 0;
-                curb[i][j] = 0;
-                curc[i][j] = 0;
+                curc[i][j] = 0.0f;
             }
         }
-
         for (int l = 0; l < m; l += 4) {
             for (int a = taska * pera; a < (taska + 1) * pera && a < n; a++) {
-#pragma unroll
-                for (int x = 0; x < 4; x++) {
-                    cura[a - taska * pera][x] = (float)input0[a * input0Stride + l + x];
-                }
+                FETCH_FLOAT2(cura[a - taska * pera]) = FETCH_FLOAT2(input0[a * input0Stride + l]);
             }
             for (int b = taskb * perb; b < (taskb + 1) * perb && b < k; b++) {
-#pragma unroll
-                for (int x = 0; x < 4; x++) {
-                    curb[b - taskb * perb][x] = (float)input1[b * input1Stride + l + x];
-                }
+                FETCH_FLOAT2(curb[b - taskb * perb]) = FETCH_FLOAT2(input1[b * input1Stride + l]);
             }
-#pragma unroll
+
             for (int i = 0; i < 4; i++) {
-#pragma unroll
                 for (int j = 0; j < 4; j++) {
 #pragma unroll
                     for (int k = 0; k < 4; k++) {
-                        curc[i][j] += cura[i][k] * curb[j][k];
+                        curc[i][j] += (float)cura[i][k] * (float)curb[j][k];
                     }
                 }
             }
@@ -3796,29 +3796,17 @@ bool DoFastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Data **v,
     }
 
     if (true) {
-        int total = 0;
-        for (int b = 0; b < batch; b++) {
-            int outer = q[b]->dims[0] * q[b]->dims[1];
-            total += outer;
-        }
-        uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * total * 3);
-        uint8_t ** cpuPointers = new uint8_t*[total * 3];
-        int cur = 0;
-        for (int b = 0; b < batch; b++) {
-            int outer = q[b]->dims[0] * q[b]->dims[1];
-            int channels = k[b]->dims[1];
-            for (int o = 0; o < outer; o++) {
-                cpuPointers[cur * 3 + 0] = (uint8_t*)(qk[b] + o * channels);
-                cpuPointers[cur * 3 + 1] = (uint8_t*)(qk[b] + o * channels);
-                cpuPointers[cur * 3 + 2] = (uint8_t*)((size_t)channels);
-                cur++;
-            }
-        }
-        cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * total * 3, cudaMemcpyHostToDevice);
-        FastllmSoftmaxKernelBatchInner1 <T, 256> <<<total, 256>>> (pointers);
-        FastllmCudaFree(pointers);
-        delete[] cpuPointers;
+        int outer = q[0]->dims[0] * q[0]->dims[1], channels = k[0]->dims[1];
+        uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch);
+        cudaMemcpy(pointers, qk, sizeof(uint8_t*) * batch, cudaMemcpyHostToDevice);
+        if (channels < 128) {
+            FastllmSoftmaxKernelBatchInner1 <T, 32> <<<batch * outer, 32>>> (pointers, outer, channels);
+        } else if (channels < 512) {
+            FastllmSoftmaxKernelBatchInner1 <T, 128> <<<batch * outer, 128>>> (pointers, outer, channels);
+        } else {
+            FastllmSoftmaxKernelBatchInner1 <T, 256> <<<batch * outer, 256>>> (pointers, outer, channels);
+        }
+        FastllmCudaFree(pointers);
     }
 
     if (true) {
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index d2986045..77109c2e 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -582,7 +582,7 @@ namespace fastllm {
 
             if (isPrompt) {
                 cnt += it.second->currentTokens.size();
-                if (cnt > 300) {
+                if (cnt > 1024) {
                     break;
                 }
                 // break;
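
Note on the new softmax dispatch (not part of the patch): the rewritten block in DoFastllmCudaAttentionBatch assumes every request in the batch shares the same outer (q[0]->dims[0] * q[0]->dims[1]) and channels (k[0]->dims[1]), so a single cudaMemcpy of `batch` device pointers and one flat grid of batch * outer blocks replace the old per-row host table of total * 3 entries. The standalone sketch below illustrates only that addressing scheme; the kernel and helper names (ToySoftmaxBatchInner1) are hypothetical, and it uses a naive one-thread softmax per block where the real kernel reduces with THREAD_PER_BLOCK threads via FastllmSoftmaxKernelInner1Func.

// Hypothetical standalone sketch of the flat-grid addressing used by the
// new FastllmSoftmaxKernelBatchInner1 overload; not part of the patch.
#include <cstdio>
#include <cuda_runtime.h>

// Block o handles row (o % outer) of buffer (o / outer): one device pointer
// per batch element replaces the old per-row {src, dst, channels} table.
__global__ void ToySoftmaxBatchInner1(float **pointers, int outer, int channels) {
    int o = blockIdx.x;                                   // o in [0, batch * outer)
    float *row = pointers[o / outer] + (o % outer) * channels;
    if (threadIdx.x == 0) {                               // naive softmax, addressing demo only
        float maxV = row[0];
        for (int i = 1; i < channels; i++) maxV = fmaxf(maxV, row[i]);
        float sum = 0.0f;
        for (int i = 0; i < channels; i++) { row[i] = expf(row[i] - maxV); sum += row[i]; }
        for (int i = 0; i < channels; i++) row[i] /= sum;
    }
}

int main() {
    const int batch = 2, outer = 3, channels = 4;
    float host[batch][outer * channels];
    float *devBufs[batch];
    for (int b = 0; b < batch; b++) {
        for (int i = 0; i < outer * channels; i++) host[b][i] = (float)(i % channels);
        cudaMalloc(&devBufs[b], sizeof(float) * outer * channels);
        cudaMemcpy(devBufs[b], host[b], sizeof(float) * outer * channels, cudaMemcpyHostToDevice);
    }
    // Mirrors the patch: copy just `batch` pointers, launch batch * outer blocks.
    float **pointers;
    cudaMalloc(&pointers, sizeof(float*) * batch);
    cudaMemcpy(pointers, devBufs, sizeof(float*) * batch, cudaMemcpyHostToDevice);
    ToySoftmaxBatchInner1 <<<batch * outer, 1>>> (pointers, outer, channels);
    cudaMemcpy(host[0], devBufs[0], sizeof(float) * outer * channels, cudaMemcpyDeviceToHost);
    printf("row0 after softmax: %f %f %f %f\n", host[0][0], host[0][1], host[0][2], host[0][3]);
    for (int b = 0; b < batch; b++) cudaFree(devBufs[b]);
    cudaFree(pointers);
    return 0;
}

The single-pointer-per-batch layout is only valid because the dispatcher in cudadevicebatch.cpp already routes mixed-shape or short-sequence batches (aveLen < 512) to the FLOAT32 fallback path, so the fast path sees uniform outer and channels.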