diff --git a/include/devices/cuda/fastllm-cuda.cuh b/include/devices/cuda/fastllm-cuda.cuh
index b871fbd4..57e9c734 100644
--- a/include/devices/cuda/fastllm-cuda.cuh
+++ b/include/devices/cuda/fastllm-cuda.cuh
@@ -82,6 +82,7 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, const fastllm::Data &v,
                               const fastllm::Data &mask, const fastllm::Data &output, int group, float scale);
 bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);
 bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);
+bool FastllmCudaHalfMatMulFloatInt4Group(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);
 void FastllmCudaSetDevice(int gpu_id);
 
 #ifdef __cplusplus
diff --git a/src/devices/cuda/cudadevice.cpp b/src/devices/cuda/cudadevice.cpp
index 5b51f00c..4be1f7df 100644
--- a/src/devices/cuda/cudadevice.cpp
+++ b/src/devices/cuda/cudadevice.cpp
@@ -296,7 +296,9 @@ namespace fastllm {
                 FastllmCudaHalfMatMulFloat16(input, weight, bias, output, n, m, k);
             } else if (weight.dataType == DataType::INT8){
                 FastllmCudaHalfMatMulFloatInt8(input, weight, bias, output, n, m, k);
-            } else {
+            } else if (weight.dataType == DataType::INT4_GROUP) {
+                FastllmCudaHalfMatMulFloatInt4Group(input, weight, bias, output, n, m, k);
+            } else {
                 ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
             }
         } else if (input.dataType == DataType::FLOAT32) {
diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu
index 914c05a7..9616605b 100644
--- a/src/devices/cuda/fastllm-cuda.cu
+++ b/src/devices/cuda/fastllm-cuda.cu
@@ -3886,6 +3886,89 @@ bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
     return true;
 }
 
+bool FastllmCudaHalfMatMulFloatInt4Group(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
+    int group = weight.group, groupCnt = weight.groupCnt;
+    if (weight.cudaData == nullptr || weight.extraCudaHalfData.size() == 0) {
+        // First call: upload the per-group scales, per-group mins and the bias once, then cache them.
+        cudaError_t state = cudaSuccess;
+        float *cudaScales;
+        state = cudaMalloc(&cudaScales, k * group * sizeof(float));
+        state = cudaMemcpy(cudaScales, weight.scales.data(), k * group * sizeof(float), cudaMemcpyHostToDevice);
+        weight.extraCudaHalfData.push_back((void*)cudaScales);
+
+        float *cudaMins;
+        state = cudaMalloc(&cudaMins, k * group * sizeof(float));
+        float *mins = new float[k * group];
+        for (int i = 0; i < k * group; i++) {
+            mins[i] = weight.perChannelsConfigs[i].min;
+        }
+        state = cudaMemcpy(cudaMins, mins, k * group * sizeof(float), cudaMemcpyHostToDevice);
+        delete[] mins;
+        weight.extraCudaHalfData.push_back((void*)cudaMins);
+
+        half *cudaBiasData;
+        state = cudaMalloc(&cudaBiasData, k * sizeof(half));
+        if (bias.dims.size() > 0) {
+            // The bias is stored as float32 on the device; convert it to half once.
+            float *tempBiasData;
+            state = cudaMalloc(&tempBiasData, k * sizeof(float));
+            state = cudaMemcpy(tempBiasData, (uint8_t*)bias.cudaData, k * sizeof(float), cudaMemcpyDeviceToDevice);
+            int threadPerBlock = std::min(256, k);
+            FastllmCudaFloat2HalfKernel <<< (k - 1) / threadPerBlock + 1, threadPerBlock >>> (tempBiasData, cudaBiasData, k);
+            state = cudaFree(tempBiasData);
+        } else {
+            state = cudaMemset(cudaBiasData, 0, k * sizeof(half));
+        }
+        checkCudaErrors("Error: CUDA error when moving bias to device!", state);
+        weight.extraCudaHalfData.push_back((void*)cudaBiasData);
+    }
+    float *cudaScales = (float*)weight.extraCudaHalfData[0];
+    float *cudaMins = (float*)weight.extraCudaHalfData[1];
+
+    half *cudaInput = (half*)FastllmCudaPrepareInput(input);
+    half *cudaOutput = (half*)FastllmCudaPrepareOutput(output);
+
+    auto fastllmCublasHandle = getFastllmCublasHandle();
+    half *cudaFp16Weight = (half*)FastllmCudaMalloc(k * m * sizeof(half));
+
+    __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
+    cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F;
+    cublasStatus_t status;
+
+    // Dequantize the packed int4-group weight into a temporary fp16 buffer, then run a half-precision GEMM.
+    int len = k * m;
+    int threadPerBlock = std::min(256, len);
+    FastllmCudaInt4Group2HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock >>> ((uint8_t*)weight.cudaData,
+                                                                                            cudaScales,
+                                                                                            cudaMins,
+                                                                                            cudaFp16Weight, len, m, group, groupCnt);
+
+    status = cublasGemmEx(fastllmCublasHandle,
+                          CUBLAS_OP_T, CUBLAS_OP_N,
+                          k, n, m,
+                          &h_alpha, cudaFp16Weight, AType,
+                          m, cudaInput, BType,
+                          m, &h_beta,
+                          cudaOutput, CType,
+                          k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
+
+    if (status != CUBLAS_STATUS_SUCCESS) {
+        printf("Error: cublas error.\n");
+        throw("cublas error");
+        exit(0);
+    }
+
+    if (bias.dims.size() > 0) {
+        half *cudaBiasData = (half*)weight.extraCudaHalfData[2];
+        FastllmCudaBiasKernel <<< n, 256 >>> (cudaOutput, cudaBiasData, k);
+    }
+
+    FastllmCudaFree(cudaFp16Weight);
+    FastllmCudaFinishInput(input, cudaInput);
+    FastllmCudaFinishOutput(output, cudaOutput);
+    return true;
+}
+
 void FastllmCudaSetDevice(int gpu_id) {
     cudaSetDevice(gpu_id);
 }
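Note on the dequantization step: the diff calls FastllmCudaInt4Group2HalfKernel but does not include its definition, which lives elsewhere in fastllm-cuda.cu. For INT4_GROUP weights, each of the k output channels of the k x m matrix is split into `group` groups of `groupCnt` values, and each group carries its own (scale, min) pair (hence the k * group device arrays above), so a 4-bit code q dequantizes as w = min + scale * q. The sketch below shows what a kernel with this call signature could look like; the nibble packing order and the exact indexing are assumptions for illustration, not the project's implementation.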
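// Hypothetical sketch (not part of this diff): a group-wise INT4 -> half
// dequantization kernel matching the call site's argument list. One thread
// produces one element of the k x m fp16 weight buffer.
__global__ void Int4Group2HalfKernelSketch(const uint8_t *a, const float *scales, const float *mins,
                                           half *b, int len, int m, int group, int groupCnt) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= len) {
        return;
    }
    int row = idx / m;                        // output channel, 0 .. k-1
    int col = idx % m;                        // position inside the channel
    int g = row * group + col / groupCnt;     // flat (channel, group) index into scales/mins
    uint8_t packed = a[idx / 2];              // two 4-bit codes packed per byte (order assumed)
    uint8_t q = (idx & 1) ? (packed & 0x0F) : (packed >> 4);
    b[idx] = __float2half(mins[g] + scales[g] * (float)q);   // w = min + scale * q
}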
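The surrounding structure appears to mirror the INT8 path directly above it: the packed int4 weights stay resident on the GPU, an fp16 copy is materialized only for the duration of the call (FastllmCudaMalloc / FastllmCudaFree around the GEMM), and the matrix product is delegated to cublasGemmEx with CUBLAS_OP_T on the weight so the row-major k x m buffer multiplies the n x m half input directly. This trades one extra pass over the dequantized weight per call for cuBLAS's fp16 GEMM throughput instead of fusing dequantization into a custom matmul kernel.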