From 7fa20dd6d1a62e18c644d951605ffddb82360f1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Fri, 9 Aug 2024 15:23:14 +0800 Subject: [PATCH] =?UTF-8?q?tp=E5=8A=A0=E9=80=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/devices/cuda/fastllm-cuda.cu | 187 +++++++++-------- src/devices/multicuda/fastllm-multicuda.cu | 228 ++++++++++++--------- 2 files changed, 225 insertions(+), 190 deletions(-) diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu index a3af28e9..8ba6e019 100644 --- a/src/devices/cuda/fastllm-cuda.cu +++ b/src/devices/cuda/fastllm-cuda.cu @@ -2198,6 +2198,12 @@ void FastllmCudaFinishOutput(fastllm::Data &output, void *data) { DeviceSync(); } +void LaunchFastllmGemmFp32Int8(float *input, uint8_t *weight, float *output, float *bias, float *scales, uint8_t *zeros, int n, int m, int k) { + for (int i = 0; i < n; i++) { + FastllmGemvInt8Kernel2<256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, zeros, m, k); + } +} + bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) { if (weight.cudaData == nullptr || weight.extraCudaData.size() == 0) { float *cudaScales; @@ -2310,21 +2316,19 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh FastllmCudaFree(cudaFp16Weight); #endif } else { - for (int i = 0; i < n; i++) { - FastllmGemvInt8Kernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m, - (uint8_t *) weight.cudaData, - cudaOutput + i * k, - cudaBiasData, - cudaScales, - cudaZeropoints, - m, k); - } + LaunchFastllmGemmFp32Int8(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaZeropoints, n, m, k); } FastllmCudaFinishInput(input, cudaInput); FastllmCudaFinishOutput(output, cudaOutput); return true; } +void LaunchFastllmGemvInt4Kernel2(float *input, uint8_t *weight, float *output, float *bias, float *scales, uint8_t *zeros, int n, int m, int k) { + for (int i = 0; i < n; i++) { + FastllmGemvInt4Kernel2<256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, zeros, m, k); + } +} + bool FastllmCudaMatMulFloatInt4(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) { if (weight.cudaData == nullptr || weight.extraCudaData.size() == 0) { float *cudaScales; @@ -2360,21 +2364,19 @@ bool FastllmCudaMatMulFloatInt4(const fastllm::Data &input, fastllm::Data &weigh float *cudaInput = (float*)FastllmCudaPrepareInput(input); float *cudaOutput = (float*)FastllmCudaPrepareOutput(output); + LaunchFastllmGemvInt4Kernel2(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaZeropoints, n, m, k); - for (int i = 0; i < n; i++) { - FastllmGemvInt4Kernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m, - (uint8_t *) weight.cudaData, - cudaOutput + i * k, - cudaBiasData, - cudaScales, - cudaZeropoints, - m, k); - } FastllmCudaFinishInput(input, cudaInput); FastllmCudaFinishOutput(output, cudaOutput); return true; } +void LaunchFastllmGemmFp32Int4Group(float *input, uint8_t *weight, float *output, float *bias, float *scales, float *mins, int n, int m, int k, int group, int groupCnt) { + for (int i = 0; i < n; i++) { + FastllmGemvInt4GroupKernel2<256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, mins, m, k, group, groupCnt); + } +} + bool FastllmCudaMatMulFloatInt4Group(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) { int group = weight.group, groupCnt = weight.groupCnt; @@ -2460,21 +2462,19 @@ bool FastllmCudaMatMulFloatInt4Group(const fastllm::Data &input, fastllm::Data & FastllmCudaFree(cudaFp16Output); FastllmCudaFree(cudaFp16Weight); } else { - for (int i = 0; i < n; i++) { - FastllmGemvInt4GroupKernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m, - (uint8_t *) weight.cudaData, - cudaOutput + i * k, - cudaBiasData, - cudaScales, - cudaMins, - m, k, group, groupCnt); - } + LaunchFastllmGemmFp32Int4Group(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaMins, n, m, k, group, groupCnt); } FastllmCudaFinishInput(input, cudaInput); FastllmCudaFinishOutput(output, cudaOutput); return true; } +void LaunchFastllmGemmFp32Int4NoZero(float *input, uint8_t *weight, float *output, float *bias, float *scales, float *mins, int n, int m, int k) { + for (int i = 0; i < n; i++) { + FastllmGemvInt4NoZeroKernel1<256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, mins, m, k); + } +} + bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) { if (weight.cudaData == nullptr || weight.extraCudaData.size() == 0) { float *cudaScales; @@ -2590,16 +2590,9 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data FastllmCudaFree(cudaFp16Weight); #endif } else { - for (int i = 0; i < n; i++) { - FastllmGemvInt4NoZeroKernel1<256, 1> <<< k, 256 >>>(cudaInput + i * m, - (uint8_t *) weight.cudaData, - cudaOutput + i * k, - cudaBiasData, - cudaScales, - cudaMins, - m, k); - } + LaunchFastllmGemmFp32Int4NoZero(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaMins, n, m, k); } + FastllmCudaFinishInput(input, cudaInput); FastllmCudaFinishOutput(output, cudaOutput); return true; @@ -2658,6 +2651,27 @@ bool FastllmCudaMatMulFloat32(const fastllm::Data &input, fastllm::Data &weight, return true; } +void LaunchFastllmGemmFp32Fp16(float *input, half *weight, float *output, float *bias, int n, int m, int k) { + if (n == 1) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 1> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else if (n == 2) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 2> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else if (n == 3) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 3> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else if (n == 4) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 4> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else if (n == 5) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 5> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else if (n == 6) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 6> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else if (n == 7) { + FastllmGemvFp32Fp16Kernel2MultiRow<256, 7> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else { + printf("Error: LaunchFastllmGemmFp32Fp16: n > 7.\n"); + exit(0); + } +} + bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) { if (weight.cudaData == nullptr || weight.extraCudaData.size() == 0) { float *cudaBiasData; @@ -2675,20 +2689,8 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, float *cudaInput = (float*)FastllmCudaPrepareInput(input); float *cudaOutput = (float*)FastllmCudaPrepareOutput(output); - if (n == 1) { - FastllmGemvFp32Fp16Kernel2MultiRow<256, 1> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); - } else if (n == 2) { - FastllmGemvFp32Fp16Kernel2MultiRow<256, 2> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); - } else if (n == 3) { - FastllmGemvFp32Fp16Kernel2MultiRow<256, 3> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); - } else if (n == 4) { - FastllmGemvFp32Fp16Kernel2MultiRow<256, 4> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); - } else if (n == 5) { - FastllmGemvFp32Fp16Kernel2MultiRow<256, 5> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); - } else if (n == 6) { - FastllmGemvFp32Fp16Kernel2MultiRow<256, 6> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); - } else if (n == 7) { - FastllmGemvFp32Fp16Kernel2MultiRow<256, 7> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); + if (n < 8) { + LaunchFastllmGemmFp32Fp16(cudaInput, (half*)weight.cudaData, cudaOutput, cudaBiasData, n, m, k); } else { auto fastllmCublasHandle = getFastllmCublasHandle(); //cudaDeviceSynchronize(); @@ -4157,6 +4159,27 @@ bool FastllmCudaBatchMatMulBatch(void **i0s, void **i1s, void **os, return true; } +void LaunchFastllmGemmFp16Fp16(half *input, half *weight, half *output, half *bias, int n, int m, int k) { + if (n == 1) { + FastllmGemvFp16Fp16Kernel2MultiRow<256, 1> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else if (n == 2) { + FastllmGemvFp16Fp16Kernel2MultiRow<256, 2> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else if (n == 3) { + FastllmGemvFp16Fp16Kernel2MultiRow<256, 3> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else if (n == 4) { + FastllmGemvFp16Fp16Kernel2MultiRow<256, 4> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else if (n == 5) { + FastllmGemvFp16Fp16Kernel2MultiRow<256, 5> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else if (n == 6) { + FastllmGemvFp16Fp16Kernel2MultiRow<256, 6> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else if (n == 7) { + FastllmGemvFp16Fp16Kernel2MultiRow<256, 7> <<< k, 256 >>>(input, weight, output, bias, m, k); + } else { + printf("Error: LaunchFastllmGemmFp16Fp16: n > 7.\n"); + exit(0); + } +} + bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) { if (weight.cudaData == nullptr || (weight.extraCudaHalfData.size() == 0 && bias.dims.size() > 0)) { @@ -4181,20 +4204,8 @@ bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &wei half *cudaOutput = (half *) FastllmCudaPrepareOutput(output); half *cudaBiasData = bias.dims.size() == 0 ? nullptr : (half *) weight.extraCudaHalfData[0]; - if (n == 1) { - FastllmGemvFp16Fp16Kernel2MultiRow<256, 1> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); - } else if (n == 2) { - FastllmGemvFp16Fp16Kernel2MultiRow<256, 2> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); - } else if (n == 3) { - FastllmGemvFp16Fp16Kernel2MultiRow<256, 3> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); - } else if (n == 4) { - FastllmGemvFp16Fp16Kernel2MultiRow<256, 4> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); - } else if (n == 5) { - FastllmGemvFp16Fp16Kernel2MultiRow<256, 5> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); - } else if (n == 6) { - FastllmGemvFp16Fp16Kernel2MultiRow<256, 6> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); - } else if (n == 7) { - FastllmGemvFp16Fp16Kernel2MultiRow<256, 7> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k); + if (n < 8) { + LaunchFastllmGemmFp16Fp16(cudaInput, (half*)weight.cudaData, cudaOutput, cudaBiasData, n, m, k); } else { __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0); auto fastllmCublasHandle = getFastllmCublasHandle(); @@ -4224,6 +4235,12 @@ bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &wei return true; } +void LaunchFastllmGemmFp16Int8(half *input, uint8_t *weight, half *output, half *bias, float *scales, uint8_t *zeros, int n, int m, int k) { + for (int i = 0; i < n; i++) { + FastllmGemvFp16Int8Kernel2 <256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, zeros, m, k); + } +} + bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) { if (weight.cudaData == nullptr || weight.extraCudaHalfData.size() == 0) { weight.extraCudaHalfData.push_back((void*)weight.extraCudaData[0]); @@ -4293,15 +4310,7 @@ bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &w FastllmCudaFree(cudaFp16Weight); } else { half *cudaBiasData = bias.dims.size() > 0 ? (half*)weight.extraCudaHalfData[2] : nullptr; - for (int i = 0; i < n; i++) { - FastllmGemvFp16Int8Kernel2 <256, 1> <<< k, 256 >>>(cudaInput + i * m, - (uint8_t *) weight.cudaData, - cudaOutput + i * k, - cudaBiasData, - cudaScales, - cudaZeropoints, - m, k); - } + LaunchFastllmGemmFp16Int8(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaZeropoints, n, m, k); } FastllmCudaFinishInput(input, cudaInput); @@ -4309,6 +4318,12 @@ bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &w return true; } +void LaunchFastllmGemmFp16Int4Group(half *input, uint8_t *weight, half *output, half *bias, float *scales, float *mins, int n, int m, int k, int group, int groupCnt) { + for (int i = 0; i < n; i++) { + FastllmGemvHalfInt4GroupKernel<256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, mins, m, k, group, groupCnt); + } +} + bool FastllmCudaHalfMatMulFloatInt4Group(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) { int group = weight.group, groupCnt = weight.groupCnt; if (weight.cudaData == nullptr || weight.extraCudaHalfData.size() == 0) { @@ -4379,21 +4394,19 @@ bool FastllmCudaHalfMatMulFloatInt4Group(const fastllm::Data &input, fastllm::Da FastllmCudaDirectFree(cudaFp16Weight); } else { half *cudaBiasData = (half*)weight.extraCudaHalfData[2]; - for (int i = 0; i < n; i++) { - FastllmGemvHalfInt4GroupKernel<256, 1> <<< k, 256 >>>(cudaInput + i * m, - (uint8_t *) weight.cudaData, - cudaOutput + i * k, - cudaBiasData, - cudaScales, - cudaMins, - m, k, group, groupCnt); - } + LaunchFastllmGemmFp16Int4Group(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaMins, n, m, k, group, groupCnt); } FastllmCudaFinishInput(input, cudaInput); FastllmCudaFinishOutput(output, cudaOutput); return true; } +void LaunchFastllmGemmFp16Int4NoZero(half *input, uint8_t *weight, half *output, half *bias, float *scales, float *mins, int n, int m, int k) { + for (int i = 0; i < n; i++) { + FastllmGemvFp16Int4NoZeroKernel2<256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, mins, m, k); + } +} + bool FastllmCudaHalfMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) { if (weight.cudaData == nullptr || weight.extraCudaHalfData.size() == 0) { weight.extraCudaHalfData.push_back((void*)weight.extraCudaData[0]); @@ -4463,15 +4476,7 @@ bool FastllmCudaHalfMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::D FastllmCudaDirectFree(cudaFp16Weight); } else { half *cudaBiasData = (half*)weight.extraCudaHalfData[2]; - for (int i = 0; i < n; i++) { - FastllmGemvFp16Int4NoZeroKernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m, - (uint8_t *) weight.cudaData, - cudaOutput + i * k, - cudaBiasData, - cudaScales, - cudaMins, - m, k); - } + LaunchFastllmGemmFp16Int4NoZero(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaMins, n, m, k); } FastllmCudaFinishInput(input, cudaInput); FastllmCudaFinishOutput(output, cudaOutput); diff --git a/src/devices/multicuda/fastllm-multicuda.cu b/src/devices/multicuda/fastllm-multicuda.cu index 2c2320be..3e0a6eea 100644 --- a/src/devices/multicuda/fastllm-multicuda.cu +++ b/src/devices/multicuda/fastllm-multicuda.cu @@ -39,6 +39,16 @@ extern void *FastllmCudaPrepareOutput(fastllm::Data &output); extern void FastllmCudaFinishInput(const fastllm::Data &input, void *data); extern void FastllmCudaFinishOutput(fastllm::Data &output, void *data); +extern void LaunchFastllmGemmFp16Fp16(half *input, half *weight, half *output, half *bias, int n, int m, int k); +extern void LaunchFastllmGemmFp16Int8(half *input, uint8_t *weight, half *output, half *bias, float *scales, uint8_t *zeros, int n, int m, int k); +extern void LaunchFastllmGemmFp16Int4NoZero(half *input, uint8_t *weight, half *output, half *bias, float *scales, float *mins, int n, int m, int k); +extern void LaunchFastllmGemmFp16Int4Group(half *input, uint8_t *weight, half *output, half *bias, float *scales, float *mins, int n, int m, int k, int group, int groupCnt); + +extern void LaunchFastllmGemmFp32Fp16(float *input, half *weight, float *output, float *bias, int n, int m, int k); +extern void LaunchFastllmGemmFp32Int8(float *input, uint8_t *weight, float *output, float *bias, float *scales, uint8_t *zeros, int n, int m, int k); +extern void LaunchFastllmGemmFp32Int4NoZero(float *input, uint8_t *weight, float *output, float *bias, float *scales, float *mins, int n, int m, int k); +extern void LaunchFastllmGemmFp32Int4Group(float *input, uint8_t *weight, float *output, float *bias, float *scales, float *mins, int n, int m, int k, int group, int groupCnt); + std::vector multiCudaCurrentDevices; void FastllmMultiCudaSetDevice(std::vector ids) { @@ -74,45 +84,59 @@ namespace fastllm { cudaMemcpy(curInput, cudaInput, n * m * sizeof(half), cudaMemcpyDeviceToDevice); } - __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0); - auto fastllmCublasHandle = getFastllmCublasHandle(); - cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F; - cublasStatus_t status; - - half *fp16Weight = (half*)weight; - bool isQuant = false; - if (weightDataType == DataType::INT4_NOZERO) { - int threadPerBlock = std::min(256, len * m); - isQuant = true; - fp16Weight = (half*)FastllmCudaMalloc(len * m * sizeof(half)); - FastllmCudaInt42HalfKernel <<< (len * m - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight, scales, mins, fp16Weight, len * m, m); - } else if (weightDataType == DataType::INT4_GROUP) { - int threadPerBlock = std::min(256, len * m); - isQuant = true; - fp16Weight = (half*)FastllmCudaMalloc(len * m * sizeof(half)); - FastllmCudaInt4Group2HalfKernel <<< (len * m - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight, scales, mins, fp16Weight, len * m, m, group, groupCnt); - } else if (weightDataType == DataType::INT8) { - int threadPerBlock = std::min(256, len * m); - isQuant = true; - fp16Weight = (half*)FastllmCudaMalloc(len * m * sizeof(half)); - FastllmCudaInt82HalfKernel <<< (len * m - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight, scales, zeros, fp16Weight, len * m, m); - } - - status = cublasGemmEx(fastllmCublasHandle, - CUBLAS_OP_T, CUBLAS_OP_N, - len, n, m, &h_alpha, fp16Weight, AType, - m, curInput, BType, - m, &h_beta, - curOutput, CType, - len, ComputeType, static_cast(CUBLAS_GEMM_DEFAULT)); - if (status != CUBLAS_STATUS_SUCCESS) { - printf("Error: cublas error.\n"); - throw ("cublas error"); - exit(0); - } - - if (hasBias) { - FastllmCudaBiasKernel <<< n, 256 >>>(curOutput, bias, len); + if (weightDataType == DataType::FLOAT16 && n < 8) { + LaunchFastllmGemmFp16Fp16(curInput, (half*)weight, curOutput, bias, n, m, len); + } else if (weightDataType == DataType::INT8 && n < 8) { + LaunchFastllmGemmFp16Int8(curInput, (uint8_t*)weight, curOutput, bias, scales, zeros, n, m, len); + } else if (weightDataType == DataType::INT4_NOZERO && n < 8) { + LaunchFastllmGemmFp16Int4NoZero(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, len); + } else if (weightDataType == DataType::INT4_GROUP && n < 8) { + LaunchFastllmGemmFp16Int4Group(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, k, group, groupCnt); + } else { + __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0); + auto fastllmCublasHandle = getFastllmCublasHandle(); + cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F; + cublasStatus_t status; + + half *fp16Weight = (half*)weight; + bool isQuant = false; + if (weightDataType == DataType::INT4_NOZERO) { + int threadPerBlock = std::min(256, len * m); + isQuant = true; + fp16Weight = (half*)FastllmCudaMalloc(len * m * sizeof(half)); + FastllmCudaInt42HalfKernel <<< (len * m - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight, scales, mins, fp16Weight, len * m, m); + } else if (weightDataType == DataType::INT4_GROUP) { + int threadPerBlock = std::min(256, len * m); + isQuant = true; + fp16Weight = (half*)FastllmCudaMalloc(len * m * sizeof(half)); + FastllmCudaInt4Group2HalfKernel <<< (len * m - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight, scales, mins, fp16Weight, len * m, m, group, groupCnt); + } else if (weightDataType == DataType::INT8) { + int threadPerBlock = std::min(256, len * m); + isQuant = true; + fp16Weight = (half*)FastllmCudaMalloc(len * m * sizeof(half)); + FastllmCudaInt82HalfKernel <<< (len * m - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight, scales, zeros, fp16Weight, len * m, m); + } + + status = cublasGemmEx(fastllmCublasHandle, + CUBLAS_OP_T, CUBLAS_OP_N, + len, n, m, &h_alpha, fp16Weight, AType, + m, curInput, BType, + m, &h_beta, + curOutput, CType, + len, ComputeType, static_cast(CUBLAS_GEMM_DEFAULT)); + if (status != CUBLAS_STATUS_SUCCESS) { + printf("Error: cublas error.\n"); + throw ("cublas error"); + exit(0); + } + + if (hasBias) { + FastllmCudaBiasKernel <<< n, 256 >>>(curOutput, bias, len); + } + + if (isQuant) { + FastllmCudaFree(fp16Weight); + } } if (deviceId != 0 || n > 1) { @@ -120,10 +144,6 @@ namespace fastllm { FastllmCudaFree(curInput); FastllmCudaFree(curOutput); } - - if (isQuant) { - FastllmCudaFree(fp16Weight); - } } }; @@ -154,68 +174,78 @@ namespace fastllm { cudaMemcpy(curInput, cudaInput, n * m * sizeof(float), cudaMemcpyDeviceToDevice); } - auto fastllmCublasHandle = getFastllmCublasHandle(); - half *cudaFp16Input, *cudaFp16Output; - cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half)); - cudaFp16Output = (half *) FastllmCudaMalloc(n * len * sizeof(half)); - - __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0); - cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F; - cublasStatus_t status; - - int threadPerBlock = std::min(256, n * m); - FastllmCudaFloat2HalfKernel <<< (n * m - 1) / threadPerBlock + 1, threadPerBlock>>>(curInput, cudaFp16Input, n * m); - - half *fp16Weight = (half*)weight; - bool isQuant = false; - if (weightDataType == DataType::INT4_NOZERO) { - int threadPerBlock = std::min(256, len * m); - isQuant = true; - fp16Weight = (half*)FastllmCudaMalloc(len * m * sizeof(half)); - FastllmCudaInt42HalfKernel <<< (len * m - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight, scales, mins, fp16Weight, len * m, m); - } else if (weightDataType == DataType::INT4_GROUP) { - int threadPerBlock = std::min(256, len * m); - isQuant = true; - fp16Weight = (half*)FastllmCudaMalloc(len * m * sizeof(half)); - FastllmCudaInt4Group2HalfKernel <<< (len * m - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight, scales, mins, fp16Weight, len * m, m, group, groupCnt); - } else if (weightDataType == DataType::INT8) { - int threadPerBlock = std::min(256, len * m); - isQuant = true; - fp16Weight = (half*)FastllmCudaMalloc(len * m * sizeof(half)); - FastllmCudaInt82HalfKernel <<< (len * m - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight, scales, zeros, fp16Weight, len * m, m); - } - - status = cublasGemmEx(fastllmCublasHandle, - CUBLAS_OP_T, CUBLAS_OP_N, - len, n, m, - &h_alpha, fp16Weight, AType, - m, cudaFp16Input, BType, - m, &h_beta, - cudaFp16Output, CType, - len, ComputeType, static_cast(CUBLAS_GEMM_DEFAULT)); - if (status != CUBLAS_STATUS_SUCCESS) { - printf("Error: cublas error.\n"); - throw("cublas error"); - exit(0); - } - - FastllmCudaHalf2FloatKernel <<< (n * len - 1) / threadPerBlock + 1, threadPerBlock >>>(cudaFp16Output, curOutput, n * len); - if (hasBias) { - FastllmCudaBiasKernel <<< n, 256 >>> (curOutput, bias, len); + + if (weightDataType == DataType::FLOAT16 && n < 8) { + LaunchFastllmGemmFp32Fp16(curInput, (half*)weight, curOutput, bias, n, m, len); + } else if (weightDataType == DataType::INT8 && n < 8) { + LaunchFastllmGemmFp32Int8(curInput, (uint8_t*)weight, curOutput, bias, scales, zeros, n, m, len); + } else if (weightDataType == DataType::INT4_NOZERO && n < 8) { + LaunchFastllmGemmFp32Int4NoZero(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, len); + } else if (weightDataType == DataType::INT4_GROUP && n < 8) { + LaunchFastllmGemmFp32Int4Group(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, k, group, groupCnt); + } else { + auto fastllmCublasHandle = getFastllmCublasHandle(); + half *cudaFp16Input, *cudaFp16Output; + cudaFp16Input = (half *) FastllmCudaMalloc(n * m * sizeof(half)); + cudaFp16Output = (half *) FastllmCudaMalloc(n * len * sizeof(half)); + + __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0); + cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F; + cublasStatus_t status; + + int threadPerBlock = std::min(256, n * m); + FastllmCudaFloat2HalfKernel <<< (n * m - 1) / threadPerBlock + 1, threadPerBlock>>>(curInput, cudaFp16Input, n * m); + + half *fp16Weight = (half*)weight; + bool isQuant = false; + if (weightDataType == DataType::INT4_NOZERO) { + int threadPerBlock = std::min(256, len * m); + isQuant = true; + fp16Weight = (half*)FastllmCudaMalloc(len * m * sizeof(half)); + FastllmCudaInt42HalfKernel <<< (len * m - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight, scales, mins, fp16Weight, len * m, m); + } else if (weightDataType == DataType::INT4_GROUP) { + int threadPerBlock = std::min(256, len * m); + isQuant = true; + fp16Weight = (half*)FastllmCudaMalloc(len * m * sizeof(half)); + FastllmCudaInt4Group2HalfKernel <<< (len * m - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight, scales, mins, fp16Weight, len * m, m, group, groupCnt); + } else if (weightDataType == DataType::INT8) { + int threadPerBlock = std::min(256, len * m); + isQuant = true; + fp16Weight = (half*)FastllmCudaMalloc(len * m * sizeof(half)); + FastllmCudaInt82HalfKernel <<< (len * m - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight, scales, zeros, fp16Weight, len * m, m); + } + + status = cublasGemmEx(fastllmCublasHandle, + CUBLAS_OP_T, CUBLAS_OP_N, + len, n, m, + &h_alpha, fp16Weight, AType, + m, cudaFp16Input, BType, + m, &h_beta, + cudaFp16Output, CType, + len, ComputeType, static_cast(CUBLAS_GEMM_DEFAULT)); + if (status != CUBLAS_STATUS_SUCCESS) { + printf("Error: cublas error.\n"); + throw("cublas error"); + exit(0); + } + + FastllmCudaHalf2FloatKernel <<< (n * len - 1) / threadPerBlock + 1, threadPerBlock >>>(cudaFp16Output, curOutput, n * len); + if (hasBias) { + FastllmCudaBiasKernel <<< n, 256 >>> (curOutput, bias, len); + } + + FastllmCudaFree(cudaFp16Input); + FastllmCudaFree(cudaFp16Output); + + if (isQuant) { + FastllmCudaFree(fp16Weight); + } } - - FastllmCudaFree(cudaFp16Input); - FastllmCudaFree(cudaFp16Output); - if (deviceId != 0 || n > 1) { cudaMemcpy2D(cudaOutput + start, k * sizeof(float), curOutput, len * sizeof(float), len * sizeof(float), n, cudaMemcpyDeviceToDevice); FastllmCudaFree(curInput); FastllmCudaFree(curOutput); } - - if (isQuant) { - FastllmCudaFree(fp16Weight); - } } }; }