diff --git a/include/devices/cuda/fastllm-cuda.cuh b/include/devices/cuda/fastllm-cuda.cuh
index bf197b66..b871fbd4 100644
--- a/include/devices/cuda/fastllm-cuda.cuh
+++ b/include/devices/cuda/fastllm-cuda.cuh
@@ -81,6 +81,7 @@ bool FastllmCudaBatchMatMulBatch(void **i0s, void **i1s, void **os,
 bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, const fastllm::Data &v,
                               const fastllm::Data &mask, const fastllm::Data &output, int group, float scale);
 bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);
+bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);
 void FastllmCudaSetDevice(int gpu_id);
 
 #ifdef __cplusplus
diff --git a/src/devices/cuda/cudadevice.cpp b/src/devices/cuda/cudadevice.cpp
index a1ff6d31..5b51f00c 100644
--- a/src/devices/cuda/cudadevice.cpp
+++ b/src/devices/cuda/cudadevice.cpp
@@ -294,6 +294,8 @@ namespace fastllm {
         if (input.dataType == DataType::FLOAT16) {
             if (weight.dataType == DataType::FLOAT16) {
                 FastllmCudaHalfMatMulFloat16(input, weight, bias, output, n, m, k);
+            } else if (weight.dataType == DataType::INT8) {
+                FastllmCudaHalfMatMulFloatInt8(input, weight, bias, output, n, m, k);
             } else {
                 ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
             }
diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu
index 3e377767..fcd5a87f 100644
--- a/src/devices/cuda/fastllm-cuda.cu
+++ b/src/devices/cuda/fastllm-cuda.cu
@@ -1794,7 +1794,6 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
         checkCudaErrors("Error: CUDA error when moving bias to device!", state);
         weight.extraCudaData.push_back((void*)cudaBiasData);
     }
-
     float *cudaScales = (float*)weight.extraCudaData[0];
     uint8_t *cudaZeropoints = (uint8_t*)weight.extraCudaData[1];
     float *cudaBiasData = (float*)weight.extraCudaData[2];
@@ -3768,6 +3767,90 @@ bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &wei
     return true;
 }
 
+bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
+    if (weight.cudaData == nullptr || weight.extraCudaHalfData.size() == 0) {
+        float *cudaScales;
+        cudaError_t state = cudaSuccess;
+        state = cudaMalloc(&cudaScales, k * sizeof(float));
+        state = cudaMemcpy(cudaScales, weight.scales.data(), k * sizeof(float), cudaMemcpyHostToDevice);
+        weight.extraCudaHalfData.push_back((void*)cudaScales);
+
+        uint8_t *cudaZeropoints;
+        state = cudaMalloc(&cudaZeropoints, k);
+        uint8_t *zeropoints = new uint8_t[k];
+        for (int i = 0; i < k; i++) {
+            zeropoints[i] = weight.perChannelsConfigs[i].zeroPoint;
+        }
+        state = cudaMemcpy(cudaZeropoints, zeropoints, k, cudaMemcpyHostToDevice);
+        delete[] zeropoints;
+        weight.extraCudaHalfData.push_back((void*)cudaZeropoints);
+
+        half *cudaBiasData;
+        state = cudaMalloc(&cudaBiasData, k * sizeof(half));
+        if (bias.dims.size() > 0) {
+            float *tempBiasData;
+            state = cudaMalloc(&tempBiasData, k * sizeof(float));
+            state = cudaMemcpy(tempBiasData, (uint8_t*)bias.cudaData, k * sizeof(float), cudaMemcpyDeviceToDevice);
+            int threadPerBlock = std::min(256, k);
+            FastllmCudaFloat2HalfKernel <<< (k - 1) / threadPerBlock + 1, threadPerBlock>>>(tempBiasData, cudaBiasData, k);
+            state = cudaFree(tempBiasData);
+        } else {
+            state = cudaMemset(cudaBiasData, 0, k * sizeof(half));
+        }
+        checkCudaErrors("Error: CUDA error when moving bias to device!", state);
+        weight.extraCudaHalfData.push_back((void*)cudaBiasData);
+    }
+    float *cudaScales = (float*)weight.extraCudaHalfData[0];
+    uint8_t *cudaZeropoints = (uint8_t*)weight.extraCudaHalfData[1];
+
+    half *cudaInput = (half*)FastllmCudaPrepareInput(input);
+    half *cudaOutput = (half*)FastllmCudaPrepareOutput(output);
+
+    auto fastllmCublasHandle = getFastllmCublasHandle();
+    half *cudaFp16Weight;
+
+    cudaFp16Weight = (half *) FastllmCudaMalloc(k * m * sizeof(half));
+
+    __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
+    cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F;
+    cublasStatus_t status;
+
+    int len = n * m;
+    int threadPerBlock = std::min(256, len);
+
+    len = k * m;
+
+    FastllmCudaInt82HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight.cudaData,
+                                                                                     cudaScales,
+                                                                                     cudaZeropoints,
+                                                                                     cudaFp16Weight, len, m);
+
+    status = cublasGemmEx(fastllmCublasHandle,
+                          CUBLAS_OP_T, CUBLAS_OP_N,
+                          k, n, m,
+                          &h_alpha, cudaFp16Weight, AType,
+                          m, cudaInput, BType,
+                          m, &h_beta,
+                          cudaOutput, CType,
+                          k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
+
+    if (status != CUBLAS_STATUS_SUCCESS) {
+        printf("Error: cublas error.\n");
+        throw("cublas error");
+        exit(0);
+    }
+
+    if (bias.dims.size() > 0) {
+        half *cudaBiasData = (half*)weight.extraCudaHalfData[2];
+        FastllmCudaBiasKernel <<< n, 256 >>> (cudaOutput, cudaBiasData, k);
+    }
+
+    FastllmCudaFree(cudaFp16Weight);
+    FastllmCudaFinishInput(input, cudaInput);
+    FastllmCudaFinishOutput(output, cudaOutput);
+    return true;
+}
+
 void FastllmCudaSetDevice(int gpu_id) {
     cudaSetDevice(gpu_id);
 }
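
Note on the new path: FastllmCudaHalfMatMulFloatInt8 dequantizes the whole int8 weight into a temporary fp16 buffer (k * m halves, via FastllmCudaInt82HalfKernel) and then reuses the plain fp16 cublasGemmEx call, rather than running an integer GEMM. FastllmCudaInt82HalfKernel itself is defined elsewhere in fastllm-cuda.cu and is not part of this patch; for review context, the sketch below shows what such a per-channel dequantization kernel could look like. The kernel name Int82HalfDequantSketch and the layout assumption (one scale and zero point per output channel, i.e. per row of the k x m weight, which matches the k-sized cudaScales/cudaZeropoints buffers above) are illustrative, not the project's actual implementation.

#include <cuda_fp16.h>
#include <cstdint>

// Hypothetical sketch only: per-channel int8 -> half dequantization with the
// same call shape as the FastllmCudaInt82HalfKernel launch in the patch.
// len = k * m total elements; element idx belongs to channel idx / m.
__global__ void Int82HalfDequantSketch(const uint8_t *weight, const float *scales,
                                       const uint8_t *zeroPoints, half *out,
                                       int len, int m) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < len) {
        int channel = idx / m;  // row of the weight matrix (output channel)
        float v = scales[channel] * ((int)weight[idx] - (int)zeroPoints[channel]);
        out[idx] = __float2half(v);
    }
}

The trade-off visible in the patch is extra temporary memory and one dequantization pass per call in exchange for reusing the existing fp16 tensor-core GEMM path for INT8 weights when the activations are FLOAT16.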