Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
ztxz16 committed Aug 9, 2024
1 parent 7fa20dd commit d3778c7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/devices/multicuda/fastllm-multicuda.cu
Original file line number Diff line number Diff line change
@@ -84,14 +84,14 @@ namespace fastllm {
cudaMemcpy(curInput, cudaInput, n * m * sizeof(half), cudaMemcpyDeviceToDevice);
}

- if (weightDataType == DataType::FLOAT16 && n < 8) {
+ if (weightDataType == DataType::FLOAT16 && n < 8 && false) {
LaunchFastllmGemmFp16Fp16(curInput, (half*)weight, curOutput, bias, n, m, len);
} else if (weightDataType == DataType::INT8 && n < 8) {
LaunchFastllmGemmFp16Int8(curInput, (uint8_t*)weight, curOutput, bias, scales, zeros, n, m, len);
} else if (weightDataType == DataType::INT4_NOZERO && n < 8) {
LaunchFastllmGemmFp16Int4NoZero(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, len);
} else if (weightDataType == DataType::INT4_GROUP && n < 8) {
- LaunchFastllmGemmFp16Int4Group(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, k, group, groupCnt);
+ LaunchFastllmGemmFp16Int4Group(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, len, group, groupCnt);
} else {
__half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
auto fastllmCublasHandle = getFastllmCublasHandle();
@@ -175,14 +175,14 @@ namespace fastllm {
cudaMemcpy(curInput, cudaInput, n * m * sizeof(float), cudaMemcpyDeviceToDevice);
}

- if (weightDataType == DataType::FLOAT16 && n < 8) {
+ if (weightDataType == DataType::FLOAT16 && n < 8 && false) {
LaunchFastllmGemmFp32Fp16(curInput, (half*)weight, curOutput, bias, n, m, len);
} else if (weightDataType == DataType::INT8 && n < 8) {
LaunchFastllmGemmFp32Int8(curInput, (uint8_t*)weight, curOutput, bias, scales, zeros, n, m, len);
} else if (weightDataType == DataType::INT4_NOZERO && n < 8) {
LaunchFastllmGemmFp32Int4NoZero(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, len);
} else if (weightDataType == DataType::INT4_GROUP && n < 8) {
- LaunchFastllmGemmFp32Int4Group(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, k, group, groupCnt);
+ LaunchFastllmGemmFp32Int4Group(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, len, group, groupCnt);
} else {
auto fastllmCublasHandle = getFastllmCublasHandle();
half *cudaFp16Input, *cudaFp16Output;
Expand Down

0 comments on commit d3778c7

Please sign in to comment.