diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp
index 0330a9ff..52b489a8 100644
--- a/src/devices/cpu/cpudevice.cpp
+++ b/src/devices/cpu/cpudevice.cpp
@@ -3368,28 +3368,51 @@ namespace fastllm {
         }
     }
 
+    struct FP16SiluManager {
+        uint16_t dict[65536];
+
+        FP16SiluManager() {
+            for (int i = 0; i < 65536; i++) {
+                float x = half_to_float(i);
+                float y = x / (1.0 + expf(-x));
+                dict[i] = float_to_half(y);
+            }
+        }
+    } fp16SiluManager;
+
     void CpuSiluOp::Run(const std::string &opType, const fastllm::DataDict &datas,
                         const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
         Data &input = *(datas.find("input")->second);
         Data &output = *(datas.find("output")->second);
         output.Allocate();
-        AssertInFastLLM(input.dataType == DataType::FLOAT32, "Silu error: Data's type should be float32.\n");
-        float *inputData = (float*)input.cpuData;
-        float *outputData = (float*)output.cpuData;
+        AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
+                        input.dataType == DataType::FLOAT16,
+                        "Silu error: Data's type should be float32 or float16.\n");
         int len = input.Count(0);
-        int i = 0;
-#ifdef __aarch64__
-        float32x4_t c1 = vdupq_n_f32(1.0f);
-        for (; i + 3 < len; i += 4) {
-            float32x4_t vx = vld1q_f32(inputData + i);
-            float32x4_t vdiv = vaddq_f32(c1, exp_ps(vnegq_f32(vx)));
-            vx = vdivq_f32(vx, vdiv);
-            vst1q_f32(outputData + i, vx);
-        }
-#endif
-        for (; i < len; i++) {
-            float x = inputData[i];
-            outputData[i] = x / (1.0 + expf(-x));
+
+        if (input.dataType == DataType::FLOAT16) {
+            uint16_t *inputData = (uint16_t*)input.cpuData;
+            uint16_t *outputData = (uint16_t*)output.cpuData;
+            for (int i = 0; i < len; i++) {
+                outputData[i] = fp16SiluManager.dict[inputData[i]];
+            }
+        } else {
+            float *inputData = (float*)input.cpuData;
+            float *outputData = (float*)output.cpuData;
+            int i = 0;
+            #ifdef __aarch64__
+            float32x4_t c1 = vdupq_n_f32(1.0f);
+            for (; i + 3 < len; i += 4) {
+                float32x4_t vx = vld1q_f32(inputData + i);
+                float32x4_t vdiv = vaddq_f32(c1, exp_ps(vnegq_f32(vx)));
+                vx = vdivq_f32(vx, vdiv);
+                vst1q_f32(outputData + i, vx);
+            }
+            #endif
+            for (; i < len; i++) {
+                float x = inputData[i];
+                outputData[i] = x / (1.0 + expf(-x));
+            }
         }
     }
 
@@ -3643,18 +3666,29 @@ namespace fastllm {
         Data &input1 = *(datas.find("input1")->second);
 
         AssertInFastLLM(input0.dims == input1.dims, "MulTo error: input's shape should be same.\n");
-        float *input0Data = (float*)input0.cpuData;
-        float *input1Data = (float*)input1.cpuData;
-
         int len = input0.Count(0);
         int inner = input1.Count(0);
         AssertInFastLLM(len % inner == 0, "MulTo error: Data`s shape can`t perform MulTo operation.\n");
         int round = (len / inner);
-        for (int j = 0; j < round; j++) {
-            for (int i = 0; i < len; i++) {
-                input0Data[i] *= input1Data[i];
+
+        if (input0.dataType == DataType::FLOAT16) {
+            uint16_t *input0Data = (uint16_t*)input0.cpuData;
+            uint16_t *input1Data = (uint16_t*)input1.cpuData;
+            for (int j = 0; j < round; j++) {
+                for (int i = 0; i < len; i++) {
+                    input0Data[i] = float_to_half(fp16tofp32.dict[input0Data[i]] * fp16tofp32.dict[input1Data[i]]);
+                }
+                input0Data += inner;
+            }
+        } else {
+            float *input0Data = (float*)input0.cpuData;
+            float *input1Data = (float*)input1.cpuData;
+            for (int j = 0; j < round; j++) {
+                for (int i = 0; i < len; i++) {
+                    input0Data[i] *= input1Data[i];
+                }
+                input0Data += inner;
             }
-            input0Data += inner;
         }
     }
 
diff --git a/src/devices/cuda/cudadevice.cpp b/src/devices/cuda/cudadevice.cpp
index 0fe002ae..7e73ed9a 100644
--- a/src/devices/cuda/cudadevice.cpp
+++ b/src/devices/cuda/cudadevice.cpp
@@ -588,7 +588,7 @@ namespace fastllm {
         output.Allocate();
         AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
                         input.dataType == DataType::FLOAT16,
-                        "Swiglu error: Data's type should be float32.\n");
+                        "Swiglu error: Data's type should be float32 or float16.\n");
         FastllmCudaSwiglu(input, output);
     }
 
@@ -597,7 +597,9 @@ namespace fastllm {
         Data &input = *(datas.find("input")->second);
         Data &output = *(datas.find("output")->second);
         output.Allocate();
-        AssertInFastLLM(input.dataType == DataType::FLOAT32, "Silu error: Data's type should be float32.\n");
+        AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
+                        input.dataType == DataType::FLOAT16,
+                        "Silu error: Data's type should be float32 or float16.\n");
         FastllmCudaSilu(input, output);
     }
 
@@ -633,8 +635,9 @@ namespace fastllm {
         Data &input1 = *(datas.find("input1")->second);
         float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0;
 
-        AssertInFastLLM(input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32,
-                        "MulTo error: Data's type should be float32.\n");
+        AssertInFastLLM((input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32) ||
+                        (input0.dataType == DataType::FLOAT16 && input1.dataType == DataType::FLOAT16),
+                        "MulTo error: Data's type should be float32 or float16.\n");
         AssertInFastLLM(input0.dims == input1.dims, "MulTo error: input's shape should be same.\n");
         FastllmCudaMulTo(input0, input1, alpha);
     }
diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu
index b0ad43e8..7bec310f 100644
--- a/src/devices/cuda/fastllm-cuda.cu
+++ b/src/devices/cuda/fastllm-cuda.cu
@@ -413,6 +413,14 @@ __global__ void FastllmSiluKernel(float* a, float *b, int len) {
     }
 }
 
+__global__ void FastllmSiluKernel(half* a, half *b, int len) {
+    int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx < len) {
+        float x = (float)a[idx];
+        b[idx] = (half)(x / (1.0 + expf(-x)));
+    }
+}
+
 __global__ void FastllmSwigluKernel(float* a, float *b, int len, int spatial, int mid) {
     int idx = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx < len) {
@@ -489,6 +497,13 @@ __global__ void FastllmMulToKernel(float* a, float *b, float alpha, int len) {
     }
 }
 
+__global__ void FastllmMulToKernel(half* a, half *b, float alpha, int len) {
+    int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx < len) {
+        a[idx] *= (half)((float)b[idx] * alpha);
+    }
+}
+
 template <int THREAD_PER_BLOCK>
 __global__ void FastllmAttentionMaskKernel(float* a, float *b, float maskValue, int n, int m, int spatial) {
     int on = blockIdx.x / m;
@@ -2921,7 +2936,11 @@ bool FastllmCudaSilu(const fastllm::Data &input, fastllm::Data &output) {
     float *cudaInput = (float *) FastllmCudaPrepareInput(input);
     float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
     int threadPerBlock = std::min(256, len);
-    FastllmSiluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len);
+    if (input.dataType == fastllm::DataType::FLOAT32) {
+        FastllmSiluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len);
+    } else if (input.dataType == fastllm::DataType::FLOAT16) {
+        FastllmSiluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((half*)cudaInput, (half*)cudaOutput, len);
+    }
     FastllmCudaFinishInput(input, cudaInput);
     FastllmCudaFinishOutput(output, cudaOutput);
     return true;
@@ -2985,7 +3004,11 @@ bool FastllmCudaMulTo(fastllm::Data &input0, const fastllm::Data &input1, float
     float *input1Data = (float *) FastllmCudaPrepareInput(input1);
 
     int threadPerBlock = std::min(256, len);
-    FastllmMulToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaData, input1Data, alpha, len);
+    if (input0.dataType == fastllm::DataType::FLOAT32) {
+        FastllmMulToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaData, input1Data, alpha, len);
+    } else {
+        FastllmMulToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((half*)cudaData, (half*)input1Data, alpha, len);
+    }
     FastllmCudaFinishInput(input1, input1Data);
     FastllmCudaFinishOutput(input0, cudaData);
     return true;
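
Note on the CPU FLOAT16 SiLU path: an FP16 value has only 65536 possible bit patterns, so silu(x) = x / (1 + exp(-x)) can be precomputed exhaustively once and then applied with a single table lookup per element, which is what FP16SiluManager does. The sketch below reproduces that idea outside of fastllm; the names (Fp16SiluTable, halfBitsToFloat, floatToHalfBits) are illustrative, and it uses cuda_fp16.h's host-side __half conversions in place of fastllm's half_to_float / float_to_half helpers, so rounding may differ in the last bit.

// build: nvcc -o fp16_silu_lut fp16_silu_lut.cu
#include <cuda_fp16.h>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <cstdio>

static float halfBitsToFloat(uint16_t bits) {
    __half h;
    std::memcpy(&h, &bits, sizeof(h));      // __half stores exactly these 16 raw bits
    return __half2float(h);
}

static uint16_t floatToHalfBits(float x) {
    __half h = __float2half(x);             // round-to-nearest-even conversion
    uint16_t bits;
    std::memcpy(&bits, &h, sizeof(bits));
    return bits;
}

struct Fp16SiluTable {
    uint16_t dict[65536];                   // one entry per FP16 bit pattern, 128 KB total
    Fp16SiluTable() {
        for (int i = 0; i < 65536; i++) {   // enumerate every bit pattern exactly once
            float x = halfBitsToFloat((uint16_t)i);
            dict[i] = floatToHalfBits(x / (1.0f + std::exp(-x)));
        }
    }
};

int main() {
    static Fp16SiluTable table;             // built once; SiLU is then one load per element
    uint16_t in = floatToHalfBits(1.5f);
    printf("silu(1.5) ~= %f\n", halfBitsToFloat(table.dict[in]));   // reference: ~1.2264
    return 0;
}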
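
Note on the new half kernels in fastllm-cuda.cu: both convert each element to float, evaluate the expression in fp32, and round back to half once, so they should match the float32 kernels up to the final fp16 rounding of the output. Below is a minimal standalone check of that property; the kernel body mirrors the diff, while the harness, buffer sizes, and naming (SiluHalfKernel) are illustrative only.

// build: nvcc -o silu_half_check silu_half_check.cu
#include <cuda_fp16.h>
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

__global__ void SiluHalfKernel(const half *a, half *b, int len) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < len) {
        float x = __half2float(a[idx]);
        b[idx] = __float2half(x / (1.0f + expf(-x)));   // compute in fp32, round once to fp16
    }
}

int main() {
    const int len = 1024;
    std::vector<half> h_in(len), h_out(len);
    for (int i = 0; i < len; i++) h_in[i] = __float2half(-8.0f + 16.0f * i / len);

    half *d_in = nullptr, *d_out = nullptr;
    cudaMalloc((void**)&d_in, len * sizeof(half));
    cudaMalloc((void**)&d_out, len * sizeof(half));
    cudaMemcpy(d_in, h_in.data(), len * sizeof(half), cudaMemcpyHostToDevice);

    int threadPerBlock = std::min(256, len);
    SiluHalfKernel<<<(len - 1) / threadPerBlock + 1, threadPerBlock>>>(d_in, d_out, len);
    cudaMemcpy(h_out.data(), d_out, len * sizeof(half), cudaMemcpyDeviceToHost);

    float maxErr = 0.0f;
    for (int i = 0; i < len; i++) {
        float x = __half2float(h_in[i]);
        float ref = x / (1.0f + std::exp(-x));          // fp32 reference on the rounded input
        maxErr = std::max(maxErr, std::fabs(__half2float(h_out[i]) - ref));
    }
    printf("max abs error vs fp32 reference: %g\n", maxErr);   // expect a few 1e-3 at most

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}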
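
Note on the launch-site dispatch in FastllmCudaSilu / FastllmCudaMulTo: the diff repeats the <<<grid, block>>> arithmetic in each dtype branch. One possible alternative, shown only as a sketch and not as how fastllm structures this, is a type-templated kernel plus a small launch helper so the float32 and float16 paths share a single call site (SiluKernel and LaunchSilu below are illustrative names):

// build: nvcc -o silu_dispatch silu_dispatch.cu
#include <cuda_fp16.h>
#include <algorithm>
#include <cstdio>

template <typename T>
__global__ void SiluKernel(T *a, T *b, int len) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < len) {
        float x = (float)a[idx];                 // always compute in fp32
        b[idx] = (T)(x / (1.0f + expf(-x)));
    }
}

template <typename T>
void LaunchSilu(T *in, T *out, int len) {
    int threadPerBlock = std::min(256, len);     // same grid arithmetic as the diff
    SiluKernel<T><<<(len - 1) / threadPerBlock + 1, threadPerBlock>>>(in, out, len);
}

int main() {
    const int len = 512;
    float *f_in = nullptr, *f_out = nullptr;
    half *h_in = nullptr, *h_out = nullptr;
    cudaMalloc((void**)&f_in, len * sizeof(float));
    cudaMalloc((void**)&f_out, len * sizeof(float));
    cudaMalloc((void**)&h_in, len * sizeof(half));
    cudaMalloc((void**)&h_out, len * sizeof(half));
    cudaMemset(f_in, 0, len * sizeof(float));    // zeroed inputs so the results are well defined
    cudaMemset(h_in, 0, len * sizeof(half));

    LaunchSilu(f_in, f_out, len);                // float32 path
    LaunchSilu(h_in, h_out, len);                // float16 path, same call site
    printf("launch status: %s\n", cudaGetErrorString(cudaDeviceSynchronize()));

    cudaFree(f_in); cudaFree(f_out); cudaFree(h_in); cudaFree(h_out);
    return 0;
}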