Skip to content

Commit

Permalink
Merge pull request #473 from jiewlmrh/master
Browse files Browse the repository at this point in the history
 对于int4g模型,增加对fp16输入的支持
  • Loading branch information
ztxz16 authored Jul 5, 2024
2 parents dead943 + f52960c commit 3ab78c7
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/devices/cuda/fastllm-cuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co
const fastllm::Data &mask, const fastllm::Data &output, int group, float scale);
bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);
bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);
bool FastllmCudaHalfMatMulFloatInt4Group(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);

void FastllmCudaSetDevice(int gpu_id);
#ifdef __cplusplus
Expand Down
4 changes: 3 additions & 1 deletion src/devices/cuda/cudadevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,9 @@ namespace fastllm {
FastllmCudaHalfMatMulFloat16(input, weight, bias, output, n, m, k);
} else if (weight.dataType == DataType::INT8){
FastllmCudaHalfMatMulFloatInt8(input, weight, bias, output, n, m, k);
} else {
} else if (weight.dataType == DataType::INT4_GROUP){
FastllmCudaHalfMatMulFloatInt4Group(input, weight, bias, output, n, m, k);
}else {
ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
}
} else if (input.dataType == DataType::FLOAT32) {
Expand Down
85 changes: 85 additions & 0 deletions src/devices/cuda/fastllm-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3886,6 +3886,91 @@ bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &w
return true;
}

bool FastllmCudaHalfMatMulFloatInt4Group(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
int group = weight.group, groupCnt = weight.groupCnt;
if (weight.cudaData == nullptr || weight.extraCudaHalfData.size() == 0) {
float *cudaScales;
cudaError_t state = cudaSuccess;
state = cudaMalloc(&cudaScales, k * group * sizeof(float));
state = cudaMemcpy(cudaScales, weight.scales.data(), k * group * sizeof(float), cudaMemcpyHostToDevice);
weight.extraCudaHalfData.push_back((void*)cudaScales);

float *cudaMins;
state = cudaMalloc(&cudaMins, k * group * sizeof(float));
float *mins = new float[k * group];
for (int i = 0; i < k * group; i++) {
mins[i] = weight.perChannelsConfigs[i].min;
}
state = cudaMemcpy(cudaMins, mins, k * group * sizeof(float), cudaMemcpyHostToDevice);
delete[] mins;
weight.extraCudaHalfData.push_back((void*)cudaMins);

half *cudaBiasData;
state = cudaMalloc(&cudaBiasData, k * sizeof(half));
if (bias.dims.size() > 0) {
float *tempBiasData;
state = cudaMalloc(&tempBiasData, k * sizeof(float));
state = cudaMemcpy(tempBiasData, (uint8_t*)bias.cudaData, k * sizeof(float), cudaMemcpyDeviceToDevice);
int threadPerBlock = std::min(256, k);
FastllmCudaFloat2HalfKernel <<< (k - 1) / threadPerBlock + 1, threadPerBlock>>>(tempBiasData, cudaBiasData, k);
state = cudaFree(tempBiasData);
} else {
state = cudaMemset(cudaBiasData, 0, k * sizeof(half));
}
checkCudaErrors("Error: CUDA error when moving bias to device!", state);
weight.extraCudaHalfData.push_back((void*)cudaBiasData);
}
float *cudaScales = (float*)weight.extraCudaHalfData[0];
float *cudaMins = (float*)weight.extraCudaHalfData[1];

half *cudaInput = (half*)FastllmCudaPrepareInput(input);
half *cudaOutput = (half*)FastllmCudaPrepareOutput(output);

auto fastllmCublasHandle = getFastllmCublasHandle();
half *cudaFp16Weight;

cudaFp16Weight = (half *) FastllmCudaMalloc(k * m * sizeof(half));

__half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
cudaDataType_t AType = CUDA_R_16F, BType = CUDA_R_16F, CType = CUDA_R_16F, ComputeType = CUDA_R_16F;
cublasStatus_t status;

int len = n * m;
int threadPerBlock = std::min(256, len);

len = k * m;

FastllmCudaInt4Group2HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight.cudaData,
cudaScales,
cudaMins,
cudaFp16Weight, len, m, group, groupCnt);

status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
&h_alpha, cudaFp16Weight, AType,
m, cudaInput, BType,
m, &h_beta,
cudaOutput, CType,
k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));

if (status != CUBLAS_STATUS_SUCCESS) {
printf("Error: cublas error.\n");
throw("cublas error");
exit(0);
}

if (bias.dims.size() > 0) {
half *cudaBiasData = (half*)weight.extraCudaHalfData[2];
FastllmCudaBiasKernel <<< n, 256 >>> (cudaOutput, cudaBiasData, k);
}

FastllmCudaFree(cudaFp16Weight);
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
return true;
}

void FastllmCudaSetDevice(int gpu_id) {
cudaSetDevice(gpu_id);
}

0 comments on commit 3ab78c7

Please sign in to comment.