
Commit

TP acceleration
黄宇扬 committed Aug 9, 2024
1 parent 28c6fdc commit 7fa20dd
Showing 2 changed files with 225 additions and 190 deletions.
187 changes: 96 additions & 91 deletions src/devices/cuda/fastllm-cuda.cu
@@ -2198,6 +2198,12 @@ void FastllmCudaFinishOutput(fastllm::Data &output, void *data) {
DeviceSync();
}

void LaunchFastllmGemmFp32Int8(float *input, uint8_t *weight, float *output, float *bias, float *scales, uint8_t *zeros, int n, int m, int k) {
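    // One kernel launch per input row: a grid of k blocks (one per output channel) with
    // 256 threads per block reducing over the m-dimension, issued back-to-back on the
    // default stream.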
for (int i = 0; i < n; i++) {
FastllmGemvInt8Kernel2<256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, zeros, m, k);
}
}

bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
if (weight.cudaData == nullptr || weight.extraCudaData.size() == 0) {
float *cudaScales;
@@ -2310,21 +2316,19 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
FastllmCudaFree(cudaFp16Weight);
#endif
} else {
        for (int i = 0; i < n; i++) {
            FastllmGemvInt8Kernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m,
                    (uint8_t *) weight.cudaData,
                    cudaOutput + i * k,
                    cudaBiasData,
                    cudaScales,
                    cudaZeropoints,
                    m, k);
        }
LaunchFastllmGemmFp32Int8(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaZeropoints, n, m, k);
}
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
return true;
}

void LaunchFastllmGemvInt4Kernel2(float *input, uint8_t *weight, float *output, float *bias, float *scales, uint8_t *zeros, int n, int m, int k) {
for (int i = 0; i < n; i++) {
FastllmGemvInt4Kernel2<256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, zeros, m, k);
}
}

bool FastllmCudaMatMulFloatInt4(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
if (weight.cudaData == nullptr || weight.extraCudaData.size() == 0) {
float *cudaScales;
@@ -2360,21 +2364,19 @@ bool FastllmCudaMatMulFloatInt4(const fastllm::Data &input, fastllm::Data &weigh

float *cudaInput = (float*)FastllmCudaPrepareInput(input);
float *cudaOutput = (float*)FastllmCudaPrepareOutput(output);
LaunchFastllmGemvInt4Kernel2(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaZeropoints, n, m, k);

    for (int i = 0; i < n; i++) {
        FastllmGemvInt4Kernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m,
                (uint8_t *) weight.cudaData,
                cudaOutput + i * k,
                cudaBiasData,
                cudaScales,
                cudaZeropoints,
                m, k);
    }
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
return true;
}

void LaunchFastllmGemmFp32Int4Group(float *input, uint8_t *weight, float *output, float *bias, float *scales, float *mins, int n, int m, int k, int group, int groupCnt) {
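    // Group-wise int4 quantization: scales/mins carry one (scale, min) pair per group of
    // groupCnt weights (presumably group groups per output channel), forwarded unchanged
    // to the kernel.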
for (int i = 0; i < n; i++) {
FastllmGemvInt4GroupKernel2<256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, mins, m, k, group, groupCnt);
}
}

bool FastllmCudaMatMulFloatInt4Group(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output,
int n, int m, int k) {
int group = weight.group, groupCnt = weight.groupCnt;
@@ -2460,21 +2462,19 @@ bool FastllmCudaMatMulFloatInt4Group(const fastllm::Data &input, fastllm::Data &
FastllmCudaFree(cudaFp16Output);
FastllmCudaFree(cudaFp16Weight);
} else {
        for (int i = 0; i < n; i++) {
            FastllmGemvInt4GroupKernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m,
                    (uint8_t *) weight.cudaData,
                    cudaOutput + i * k,
                    cudaBiasData,
                    cudaScales,
                    cudaMins,
                    m, k, group, groupCnt);
        }
LaunchFastllmGemmFp32Int4Group(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaMins, n, m, k, group, groupCnt);
}
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
return true;
}

void LaunchFastllmGemmFp32Int4NoZero(float *input, uint8_t *weight, float *output, float *bias, float *scales, float *mins, int n, int m, int k) {
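    // "NoZero" int4 encoding: each quantization block is described by a scale and a minimum
    // (scales/mins) rather than the scale/zero-point pair used on the int8 path.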
for (int i = 0; i < n; i++) {
FastllmGemvInt4NoZeroKernel1<256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, mins, m, k);
}
}

bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
if (weight.cudaData == nullptr || weight.extraCudaData.size() == 0) {
float *cudaScales;
@@ -2590,16 +2590,9 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
FastllmCudaFree(cudaFp16Weight);
#endif
} else {
        for (int i = 0; i < n; i++) {
            FastllmGemvInt4NoZeroKernel1<256, 1> <<< k, 256 >>>(cudaInput + i * m,
                    (uint8_t *) weight.cudaData,
                    cudaOutput + i * k,
                    cudaBiasData,
                    cudaScales,
                    cudaMins,
                    m, k);
        }
LaunchFastllmGemmFp32Int4NoZero(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaMins, n, m, k);
}

FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
return true;
@@ -2658,6 +2651,27 @@ bool FastllmCudaMatMulFloat32(const fastllm::Data &input, fastllm::Data &weight,
return true;
}

void LaunchFastllmGemmFp32Fp16(float *input, half *weight, float *output, float *bias, int n, int m, int k) {
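    // n is a runtime value, but the kernel takes the row count as a compile-time template
    // argument, so each n in 1..7 gets its own instantiation (presumably letting the kernel
    // keep per-row accumulators in registers); callers route n >= 8 to a cuBLAS GEMM instead.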
if (n == 1) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 1> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else if (n == 2) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 2> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else if (n == 3) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 3> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else if (n == 4) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 4> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else if (n == 5) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 5> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else if (n == 6) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 6> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else if (n == 7) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 7> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else {
printf("Error: LaunchFastllmGemmFp32Fp16: n > 7.\n");
exit(0);
}
}

bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
if (weight.cudaData == nullptr || weight.extraCudaData.size() == 0) {
float *cudaBiasData;
@@ -2675,20 +2689,8 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight,
float *cudaInput = (float*)FastllmCudaPrepareInput(input);
float *cudaOutput = (float*)FastllmCudaPrepareOutput(output);

if (n == 1) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 1> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 2) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 2> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 3) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 3> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 4) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 4> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 5) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 5> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 6) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 6> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 7) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 7> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
if (n < 8) {
LaunchFastllmGemmFp32Fp16(cudaInput, (half*)weight.cudaData, cudaOutput, cudaBiasData, n, m, k);
} else {
auto fastllmCublasHandle = getFastllmCublasHandle();
//cudaDeviceSynchronize();
@@ -4157,6 +4159,27 @@ bool FastllmCudaBatchMatMulBatch(void **i0s, void **i1s, void **os,
return true;
}

void LaunchFastllmGemmFp16Fp16(half *input, half *weight, half *output, half *bias, int n, int m, int k) {
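    // Half-precision counterpart of LaunchFastllmGemmFp32Fp16: the same 1..7-row template
    // dispatch with half input/output/bias; callers again fall back to cuBLAS for n >= 8.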
if (n == 1) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 1> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else if (n == 2) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 2> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else if (n == 3) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 3> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else if (n == 4) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 4> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else if (n == 5) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 5> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else if (n == 6) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 6> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else if (n == 7) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 7> <<< k, 256 >>>(input, weight, output, bias, m, k);
} else {
printf("Error: LaunchFastllmGemmFp16Fp16: n > 7.\n");
exit(0);
}
}

bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
if (weight.cudaData == nullptr ||
(weight.extraCudaHalfData.size() == 0 && bias.dims.size() > 0)) {
@@ -4181,20 +4204,8 @@ bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &wei
half *cudaOutput = (half *) FastllmCudaPrepareOutput(output);
half *cudaBiasData = bias.dims.size() == 0 ? nullptr : (half *) weight.extraCudaHalfData[0];

if (n == 1) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 1> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 2) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 2> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 3) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 3> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 4) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 4> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 5) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 5> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 6) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 6> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 7) {
FastllmGemvFp16Fp16Kernel2MultiRow<256, 7> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
if (n < 8) {
LaunchFastllmGemmFp16Fp16(cudaInput, (half*)weight.cudaData, cudaOutput, cudaBiasData, n, m, k);
} else {
__half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
auto fastllmCublasHandle = getFastllmCublasHandle();
@@ -4224,6 +4235,12 @@ bool FastllmCudaHalfMatMulFloat16(const fastllm::Data &input, fastllm::Data &wei
return true;
}

void LaunchFastllmGemmFp16Int8(half *input, uint8_t *weight, half *output, half *bias, float *scales, uint8_t *zeros, int n, int m, int k) {
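    // Half activations against int8 weights: input/output/bias are half while the
    // dequantization scales remain fp32; launch shape matches the fp32 int8 helper.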
for (int i = 0; i < n; i++) {
FastllmGemvFp16Int8Kernel2 <256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, zeros, m, k);
}
}

bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
if (weight.cudaData == nullptr || weight.extraCudaHalfData.size() == 0) {
weight.extraCudaHalfData.push_back((void*)weight.extraCudaData[0]);
@@ -4293,22 +4310,20 @@ bool FastllmCudaHalfMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &w
FastllmCudaFree(cudaFp16Weight);
} else {
half *cudaBiasData = bias.dims.size() > 0 ? (half*)weight.extraCudaHalfData[2] : nullptr;
        for (int i = 0; i < n; i++) {
            FastllmGemvFp16Int8Kernel2 <256, 1> <<< k, 256 >>>(cudaInput + i * m,
                    (uint8_t *) weight.cudaData,
                    cudaOutput + i * k,
                    cudaBiasData,
                    cudaScales,
                    cudaZeropoints,
                    m, k);
        }
LaunchFastllmGemmFp16Int8(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaZeropoints, n, m, k);
}

FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
return true;
}

void LaunchFastllmGemmFp16Int4Group(half *input, uint8_t *weight, half *output, half *bias, float *scales, float *mins, int n, int m, int k, int group, int groupCnt) {
for (int i = 0; i < n; i++) {
FastllmGemvHalfInt4GroupKernel<256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, mins, m, k, group, groupCnt);
}
}

bool FastllmCudaHalfMatMulFloatInt4Group(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
int group = weight.group, groupCnt = weight.groupCnt;
if (weight.cudaData == nullptr || weight.extraCudaHalfData.size() == 0) {
@@ -4379,21 +4394,19 @@ bool FastllmCudaHalfMatMulFloatInt4Group(const fastllm::Data &input, fastllm::Da
FastllmCudaDirectFree(cudaFp16Weight);
} else {
half *cudaBiasData = (half*)weight.extraCudaHalfData[2];
        for (int i = 0; i < n; i++) {
            FastllmGemvHalfInt4GroupKernel<256, 1> <<< k, 256 >>>(cudaInput + i * m,
                    (uint8_t *) weight.cudaData,
                    cudaOutput + i * k,
                    cudaBiasData,
                    cudaScales,
                    cudaMins,
                    m, k, group, groupCnt);
        }
LaunchFastllmGemmFp16Int4Group(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaMins, n, m, k, group, groupCnt);
}
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
return true;
}

void LaunchFastllmGemmFp16Int4NoZero(half *input, uint8_t *weight, half *output, half *bias, float *scales, float *mins, int n, int m, int k) {
for (int i = 0; i < n; i++) {
FastllmGemvFp16Int4NoZeroKernel2<256, 1> <<< k, 256 >>>(input + i * m, weight, output + i * k, bias, scales, mins, m, k);
}
}

bool FastllmCudaHalfMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k) {
if (weight.cudaData == nullptr || weight.extraCudaHalfData.size() == 0) {
weight.extraCudaHalfData.push_back((void*)weight.extraCudaData[0]);
@@ -4463,15 +4476,7 @@ bool FastllmCudaHalfMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::D
FastllmCudaDirectFree(cudaFp16Weight);
} else {
half *cudaBiasData = (half*)weight.extraCudaHalfData[2];
        for (int i = 0; i < n; i++) {
            FastllmGemvFp16Int4NoZeroKernel2<256, 1> <<< k, 256 >>>(cudaInput + i * m,
                    (uint8_t *) weight.cudaData,
                    cudaOutput + i * k,
                    cudaBiasData,
                    cudaScales,
                    cudaMins,
                    m, k);
        }
LaunchFastllmGemmFp16Int4NoZero(cudaInput, (uint8_t*)weight.cudaData, cudaOutput, cudaBiasData, cudaScales, cudaMins, n, m, k);
}
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
[diff for the second changed file was not loaded]
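
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the runtime-to-compile-time dispatch used by the new LaunchFastllmGemmFp32Fp16 / LaunchFastllmGemmFp16Fp16 helpers. This is not fastllm code: the kernel body is omitted and every name (GemvMultiRowKernel, LaunchOne, LaunchGemvMultiRow) is hypothetical.

#include <cstdio>
#include <cstdlib>

// Stand-in for FastllmGemvFp32Fp16Kernel2MultiRow: THREADS and ROWS are template
// parameters, so the real kernel can unroll its per-row work at compile time.
template <int THREADS, int ROWS>
__global__ void GemvMultiRowKernel(const float *input, const float *weight,
                                   float *output, int m, int k) {
    // m-dimension reduction omitted in this sketch
}

template <int ROWS>
void LaunchOne(const float *input, const float *weight, float *output, int m, int k) {
    // Same launch shape as the committed helpers: one block per output channel.
    GemvMultiRowKernel<256, ROWS><<<k, 256>>>(input, weight, output, m, k);
}

// Map the runtime row count n onto a compile-time ROWS argument, mirroring the
// if/else ladder in LaunchFastllmGemmFp32Fp16.
void LaunchGemvMultiRow(const float *input, const float *weight, float *output,
                        int n, int m, int k) {
    switch (n) {
        case 1: LaunchOne<1>(input, weight, output, m, k); break;
        case 2: LaunchOne<2>(input, weight, output, m, k); break;
        case 3: LaunchOne<3>(input, weight, output, m, k); break;
        case 4: LaunchOne<4>(input, weight, output, m, k); break;
        case 5: LaunchOne<5>(input, weight, output, m, k); break;
        case 6: LaunchOne<6>(input, weight, output, m, k); break;
        case 7: LaunchOne<7>(input, weight, output, m, k); break;
        default:
            // Larger batches are better served by a dense GEMM (e.g. cuBLAS).
            fprintf(stderr, "LaunchGemvMultiRow: n > 7 is not supported here.\n");
            exit(EXIT_FAILURE);
    }
}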
