diff --git a/src/devices/multicuda/fastllm-multicuda.cu b/src/devices/multicuda/fastllm-multicuda.cu
index 3e0a6ee..9730c34 100644
--- a/src/devices/multicuda/fastllm-multicuda.cu
+++ b/src/devices/multicuda/fastllm-multicuda.cu
@@ -84,14 +84,14 @@ namespace fastllm {
             cudaMemcpy(curInput, cudaInput, n * m * sizeof(half), cudaMemcpyDeviceToDevice);
         }
 
-        if (weightDataType == DataType::FLOAT16 && n < 8) {
+        if (weightDataType == DataType::FLOAT16 && n < 8 && false) {
             LaunchFastllmGemmFp16Fp16(curInput, (half*)weight, curOutput, bias, n, m, len);
         } else if (weightDataType == DataType::INT8 && n < 8) {
             LaunchFastllmGemmFp16Int8(curInput, (uint8_t*)weight, curOutput, bias, scales, zeros, n, m, len);
         } else if (weightDataType == DataType::INT4_NOZERO && n < 8) {
            LaunchFastllmGemmFp16Int4NoZero(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, len);
         } else if (weightDataType == DataType::INT4_GROUP && n < 8) {
-            LaunchFastllmGemmFp16Int4Group(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, k, group, groupCnt);
+            LaunchFastllmGemmFp16Int4Group(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, len, group, groupCnt);
         } else {
             __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0);
             auto fastllmCublasHandle = getFastllmCublasHandle();
@@ -175,14 +175,14 @@ namespace fastllm {
             cudaMemcpy(curInput, cudaInput, n * m * sizeof(float), cudaMemcpyDeviceToDevice);
         }
 
-        if (weightDataType == DataType::FLOAT16 && n < 8) {
+        if (weightDataType == DataType::FLOAT16 && n < 8 && false) {
             LaunchFastllmGemmFp32Fp16(curInput, (half*)weight, curOutput, bias, n, m, len);
         } else if (weightDataType == DataType::INT8 && n < 8) {
             LaunchFastllmGemmFp32Int8(curInput, (uint8_t*)weight, curOutput, bias, scales, zeros, n, m, len);
         } else if (weightDataType == DataType::INT4_NOZERO && n < 8) {
             LaunchFastllmGemmFp32Int4NoZero(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, len);
         } else if (weightDataType == DataType::INT4_GROUP && n < 8) {
-            LaunchFastllmGemmFp32Int4Group(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, k, group, groupCnt);
+            LaunchFastllmGemmFp32Int4Group(curInput, (uint8_t*)weight, curOutput, bias, scales, mins, n, m, len, group, groupCnt);
         } else {
             auto fastllmCublasHandle = getFastllmCublasHandle();
             half *cudaFp16Input, *cudaFp16Output;