From 02de7431c862e757c7aba934c54925aa56aa6e18 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?=
Date: Mon, 23 Sep 2024 17:05:18 +0800
Subject: [PATCH] add gelu for cuda

---
 include/devices/cuda/cudadevice.h     |  4 ++++
 include/devices/cuda/fastllm-cuda.cuh |  3 ++-
 src/devices/cuda/cudadevice.cpp       | 10 ++++++++++
 src/devices/cuda/fastllm-cuda.cu      | 21 ++++++++++++++++++++-
 4 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/include/devices/cuda/cudadevice.h b/include/devices/cuda/cudadevice.h
index 5ba86527..ee0a405d 100644
--- a/include/devices/cuda/cudadevice.h
+++ b/include/devices/cuda/cudadevice.h
@@ -82,6 +82,10 @@ namespace fastllm {
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
     };
 
+    class CudaGeluOp : BaseOperator {
+        void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+    };
+
     class CudaGeluNewOp : BaseOperator {
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
     };
diff --git a/include/devices/cuda/fastllm-cuda.cuh b/include/devices/cuda/fastllm-cuda.cuh
index 936d8b04..7aabc777 100644
--- a/include/devices/cuda/fastllm-cuda.cuh
+++ b/include/devices/cuda/fastllm-cuda.cuh
@@ -32,7 +32,8 @@ bool FastllmBF16ToFloat(void *a, void *b, int len);
 bool FastllmCudaEmbedding(const fastllm::Data &input, const fastllm::Data &weight, fastllm::Data &output);
 bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const fastllm::Data &v,
                           const fastllm::Data &mask, const fastllm::Data &output, int group, float scale);
-bool FastllmCudaGeluNew(const fastllm::Data &input, fastllm::Data &output);
+bool FastllmCudaGeluNew(const fastllm::Data &input, fastllm::Data &output);
+bool FastllmCudaGelu(const fastllm::Data &input, fastllm::Data &output);
 bool FastllmCudaSilu(const fastllm::Data &input, fastllm::Data &output);
 bool FastllmCudaSwiglu(const fastllm::Data &input, fastllm::Data &output);
 bool FastllmCudaMul(const fastllm::Data &input, float v, fastllm::Data &output);
diff --git a/src/devices/cuda/cudadevice.cpp b/src/devices/cuda/cudadevice.cpp
index f524824a..f249fe1d 100644
--- a/src/devices/cuda/cudadevice.cpp
+++ b/src/devices/cuda/cudadevice.cpp
@@ -25,6 +25,7 @@ namespace fastllm {
         this->ops["MatMul"] = (BaseOperator*)(new CudaMatMulOp());
         this->ops["MatMulTransB"] = (BaseOperator*)(new CudaMatMulTransBOp());
         this->ops["SoftMax"] = (BaseOperator*)(new CudaSoftMaxOp());
+        this->ops["Gelu"] = (BaseOperator*)(new CudaGeluOp());
         this->ops["GeluNew"] = (BaseOperator*)(new CudaGeluNewOp());
         this->ops["Silu"] = (BaseOperator*)(new CudaSiluOp());
         this->ops["Swiglu"] = (BaseOperator*)(new CudaSwigluOp());
@@ -571,6 +572,15 @@ namespace fastllm {
         FastllmCudaGeluNew(input, output);
     }
 
+    void CudaGeluOp::Run(const std::string &opType, const fastllm::DataDict &datas,
+                         const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        Data &input = *(datas.find("input")->second);
+        Data &output = *(datas.find("output")->second);
+        output.Allocate();
+        AssertInFastLLM(input.dataType == DataType::FLOAT32, "Gelu error: Data's type should be float32.\n");
+        FastllmCudaGelu(input, output);
+    }
+
     void CudaSwigluOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
                                const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
         Data &input = *(datas.find("input")->second);
diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu
index 6b694230..7ef768d1 100644
--- a/src/devices/cuda/fastllm-cuda.cu
+++ b/src/devices/cuda/fastllm-cuda.cu
@@ -433,6 +433,14 @@ __global__ void FastllmCudaBiasKernel(half *a, half *bias, int k) {
 }
 
 __global__ void FastllmGeluKernel(float* a, float *b, int len) {
+    int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx < len) {
+        float x = a[idx];
+        b[idx] = x * 0.5f * (1.0f + erf(x / sqrt(2.0)));
+    }
+}
+
+__global__ void FastllmGeluNewKernel(float* a, float *b, int len) {
     int idx = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx < len) {
         float x = a[idx];
@@ -3127,7 +3135,7 @@ void FastllmCudaMemcpy2DDeviceToDeviceBatch(void ** dsts, size_t * dpitchs, voi
     DeviceSync();
 }
 
-bool FastllmCudaGeluNew(const fastllm::Data &input, fastllm::Data &output) {
+bool FastllmCudaGelu(const fastllm::Data &input, fastllm::Data &output) {
     int len = input.Count(0);
     float *cudaInput = (float *) FastllmCudaPrepareInput(input);
     float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
@@ -3138,6 +3146,17 @@ bool FastllmCudaGeluNew(const fastllm::Data &input, fastllm::Data &output) {
     return true;
 }
 
+bool FastllmCudaGeluNew(const fastllm::Data &input, fastllm::Data &output) {
+    int len = input.Count(0);
+    float *cudaInput = (float *) FastllmCudaPrepareInput(input);
+    float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
+    int threadPerBlock = std::min(256, len);
+    FastllmGeluNewKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len);
+    FastllmCudaFinishInput(input, cudaInput);
+    FastllmCudaFinishOutput(output, cudaOutput);
+    return true;
+}
+
 bool FastllmCudaSilu(const fastllm::Data &input, fastllm::Data &output) {
     int len = input.Count(0);
     float *cudaInput = (float *) FastllmCudaPrepareInput(input);
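
Note (not part of the patch): the new "Gelu" op computes the exact erf-based GELU, 0.5 * x * (1 + erf(x / sqrt(2))), while "GeluNew" keeps the previous kernel body, which the patch renames to FastllmGeluNewKernel without showing it in full. The self-contained sketch below contrasts the exact form with the tanh-based approximation that "gelu_new"-style activations conventionally use, so the numerical difference is easy to inspect. The kernel names, the test harness, and the tanh formula are assumptions made for illustration, not taken from the patch; the single-precision intrinsics erff/tanhf are used to keep the whole computation in float. It compiles standalone with nvcc.

#include <cstdio>
#include <cuda_runtime.h>

// Exact GELU, matching the formula added by the patch in FastllmGeluKernel,
// written here with the single-precision erff intrinsic.
__global__ void GeluExactKernel(const float *a, float *b, int len) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < len) {
        float x = a[idx];
        b[idx] = x * 0.5f * (1.0f + erff(x * 0.70710678f));   // 0.70710678 = 1/sqrt(2)
    }
}

// "gelu_new"-style tanh approximation (assumed, not shown in the patch):
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
__global__ void GeluTanhKernel(const float *a, float *b, int len) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < len) {
        float x = a[idx];
        b[idx] = 0.5f * x * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x)));
    }
}

int main() {
    const int len = 8;
    float h[len] = {-3.0f, -1.0f, -0.5f, 0.0f, 0.5f, 1.0f, 2.0f, 3.0f};
    float *dIn, *dExact, *dTanh;
    cudaMalloc(&dIn, len * sizeof(float));
    cudaMalloc(&dExact, len * sizeof(float));
    cudaMalloc(&dTanh, len * sizeof(float));
    cudaMemcpy(dIn, h, len * sizeof(float), cudaMemcpyHostToDevice);

    // Same grid-size pattern as the patch's host wrappers: one thread per element.
    int threadPerBlock = 256;
    GeluExactKernel<<<(len - 1) / threadPerBlock + 1, threadPerBlock>>>(dIn, dExact, len);
    GeluTanhKernel<<<(len - 1) / threadPerBlock + 1, threadPerBlock>>>(dIn, dTanh, len);

    float exact[len], approx[len];
    cudaMemcpy(exact, dExact, len * sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(approx, dTanh, len * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < len; i++) {
        printf("x=%6.2f  exact=%9.6f  tanh=%9.6f\n", h[i], exact[i], approx[i]);
    }

    cudaFree(dIn); cudaFree(dExact); cudaFree(dTanh);
    return 0;
}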