silu, multo: add float16 operators
黄宇扬 committed Jul 10, 2024
1 parent c494755 commit 8501b1a
Showing 3 changed files with 89 additions and 29 deletions.
80 changes: 57 additions & 23 deletions src/devices/cpu/cpudevice.cpp
@@ -3368,28 +3368,51 @@ namespace fastllm {
}
}

// Precompute SiLU for every possible FP16 bit pattern; applying the op is then a single table lookup per element.
struct FP16SiluManager {
uint16_t dict[65536];

FP16SiluManager() {
for (int i = 0; i < 65536; i++) { // int index: a uint16_t counter cannot reach 65536, and "i < 65535" would skip the last pattern
float x = half_to_float((uint16_t)i);
float y = x / (1.0f + expf(-x));
dict[i] = float_to_half(y);
}
}
} fp16SiluManager;

void CpuSiluOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input = *(datas.find("input")->second);
Data &output = *(datas.find("output")->second);
output.Allocate();
AssertInFastLLM(input.dataType == DataType::FLOAT32, "Silu error: Data's type should be float32.\n");
float *inputData = (float*)input.cpuData;
float *outputData = (float*)output.cpuData;
AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
input.dataType == DataType::FLOAT16,
"Silu error: Data's type should be float32 or float16.\n");
int len = input.Count(0);
int i = 0;
#ifdef __aarch64__
float32x4_t c1 = vdupq_n_f32(1.0f);
for (; i + 3 < len; i += 4) {
float32x4_t vx = vld1q_f32(inputData + i);
float32x4_t vdiv = vaddq_f32(c1, exp_ps(vnegq_f32(vx)));
vx = vdivq_f32(vx, vdiv);
vst1q_f32(outputData + i, vx);
}
#endif
for (; i < len; i++) {
float x = inputData[i];
outputData[i] = x / (1.0 + expf(-x));

if (input.dataType == DataType::FLOAT16) {
uint16_t *inputData = (uint16_t*)input.cpuData;
uint16_t *outputData = (uint16_t*)output.cpuData;
for (int i = 0; i < len; i++) {
outputData[i] = fp16SiluManager.dict[inputData[i]];
}
} else {
float *inputData = (float*)input.cpuData;
float *outputData = (float*)output.cpuData;
int i = 0;
#ifdef __aarch64__
float32x4_t c1 = vdupq_n_f32(1.0f);
for (; i + 3 < len; i += 4) {
float32x4_t vx = vld1q_f32(inputData + i);
float32x4_t vdiv = vaddq_f32(c1, exp_ps(vnegq_f32(vx)));
vx = vdivq_f32(vx, vdiv);
vst1q_f32(outputData + i, vx);
}
#endif
for (; i < len; i++) {
float x = inputData[i];
outputData[i] = x / (1.0 + expf(-x));
}
}
}
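
A note on the FP16 path above: binary16 has only 65,536 bit patterns, so SiLU can be precomputed once for every pattern and applied as a single table lookup per element, which is what FP16SiluManager does. Below is a minimal standalone sketch of that idea, assuming nothing from fastllm; half_bits_to_float is a generic IEEE-754 decoder standing in for fastllm's half_to_float, and the table stores float results only to keep the sketch short.

// Standalone sketch (not fastllm code) of the FP16 lookup-table idea:
// precompute silu(x) = x / (1 + e^{-x}) for every half bit pattern once,
// then apply the op with one array lookup per element.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

static float half_bits_to_float(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000) << 16;
    uint32_t exp  = (h >> 10) & 0x1F;
    uint32_t mant = h & 0x3FF;
    uint32_t bits;
    if (exp == 0) {
        if (mant == 0) {
            bits = sign;                                   // signed zero
        } else {                                           // subnormal: renormalize
            int e = 113;                                   // 127 - 15 + 1
            while ((mant & 0x400) == 0) { mant <<= 1; --e; }
            bits = sign | ((uint32_t)e << 23) | ((mant & 0x3FF) << 13);
        }
    } else if (exp == 0x1F) {
        bits = sign | 0x7F800000u | (mant << 13);          // inf / NaN
    } else {
        bits = sign | ((exp + 112) << 23) | (mant << 13);  // normal (112 = 127 - 15)
    }
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main() {
    // Build the table over all 65,536 patterns (int index so pattern 65535 is included).
    std::vector<float> table(65536);
    for (int i = 0; i < 65536; ++i) {
        float x = half_bits_to_float((uint16_t)i);
        table[i] = x / (1.0f + std::exp(-x));
    }
    // Applying the op is then one lookup per element.
    uint16_t samples[] = {0x0000 /* 0.0 */, 0x3C00 /* 1.0 */, 0xC000 /* -2.0 */};
    for (uint16_t s : samples) {
        std::printf("silu(%g) = %g\n", half_bits_to_float(s), table[s]);
    }
    return 0;
}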

@@ -3643,18 +3666,29 @@ namespace fastllm {
Data &input1 = *(datas.find("input1")->second);
AssertInFastLLM(input0.dims == input1.dims, "MulTo error: input's shape should be same.\n");

float *input0Data = (float*)input0.cpuData;
float *input1Data = (float*)input1.cpuData;

int len = input0.Count(0);
int inner = input1.Count(0);
AssertInFastLLM(len % inner == 0, "MulTo error: Data`s shape can`t perform MulTo operation.\n");
int round = (len / inner);
for (int j = 0; j < round; j++) {
for (int i = 0; i < len; i++) {
input0Data[i] *= input1Data[i];

if (input0.dataType == DataType::FLOAT16) {
uint16_t *input0Data = (uint16_t*)input0.cpuData;
uint16_t *input1Data = (uint16_t*)input1.cpuData;
for (int j = 0; j < round; j++) {
for (int i = 0; i < len; i++) {
input0Data[i] = float_to_half(fp16tofp32.dict[input0Data[i]] * fp16tofp32.dict[input1Data[i]]);
}
input0Data += inner;
}
} else {
float *input0Data = (float*)input0.cpuData;
float *input1Data = (float*)input1.cpuData;
for (int j = 0; j < round; j++) {
for (int i = 0; i < len; i++) {
input0Data[i] *= input1Data[i];
}
input0Data += inner;
}
input0Data += inner;
}
}
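
For reference, both MulTo branches above implement the same contract: an in-place elementwise multiply of input0 by input1. The FP16 branch only differs in decoding the operands to float for the arithmetic and converting the product back to half on store. A minimal float32 sketch of that contract follows, with the optional alpha scale that the CUDA entry point below accepts; names here are illustrative, not fastllm API.

// Reference sketch (not fastllm code): in-place a[i] *= b[i] * alpha,
// with alpha defaulting to 1.0f as in the CUDA entry point.
#include <cassert>
#include <cstdio>
#include <vector>

static void mul_to(std::vector<float> &a, const std::vector<float> &b, float alpha = 1.0f) {
    assert(a.size() == b.size() && "MulTo: shapes must match");
    for (size_t i = 0; i < a.size(); ++i) {
        a[i] *= b[i] * alpha;
    }
}

int main() {
    std::vector<float> a = {1.0f, 2.0f, 3.0f};
    std::vector<float> b = {0.5f, 0.5f, 0.5f};
    mul_to(a, b);
    for (float v : a) std::printf("%g ", v);   // prints: 0.5 1 1.5
    std::printf("\n");
    return 0;
}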

11 changes: 7 additions & 4 deletions src/devices/cuda/cudadevice.cpp
@@ -588,7 +588,7 @@ namespace fastllm {
output.Allocate();
AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
input.dataType == DataType::FLOAT16,
"Swiglu error: Data's type should be float32.\n");
"Swiglu error: Data's type should be float32 or float16.\n");
FastllmCudaSwiglu(input, output);
}

@@ -597,7 +597,9 @@
Data &input = *(datas.find("input")->second);
Data &output = *(datas.find("output")->second);
output.Allocate();
AssertInFastLLM(input.dataType == DataType::FLOAT32, "Silu error: Data's type should be float32.\n");
AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
input.dataType == DataType::FLOAT16,
"Silu error: Data's type should be float32 or float16.\n");
FastllmCudaSilu(input, output);
}

@@ -633,8 +635,9 @@ namespace fastllm {
Data &input1 = *(datas.find("input1")->second);
float alpha = floatParams.find("alpha") != floatParams.end() ? floatParams.find("alpha")->second : 1.0;

AssertInFastLLM(input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32,
"MulTo error: Data's type should be float32.\n");
AssertInFastLLM((input0.dataType == DataType::FLOAT32 && input1.dataType == DataType::FLOAT32) ||
(input0.dataType == DataType::FLOAT16 && input1.dataType == DataType::FLOAT16),
"MulTo error: Data's type should be float32 or float16.\n");
AssertInFastLLM(input0.dims == input1.dims, "MulTo error: input's shape should be same.\n");
FastllmCudaMulTo(input0, input1, alpha);
}
27 changes: 25 additions & 2 deletions src/devices/cuda/fastllm-cuda.cu
@@ -413,6 +413,14 @@ __global__ void FastllmSiluKernel(float* a, float *b, int len) {
}
}

__global__ void FastllmSiluKernel(half* a, half *b, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
float x = (float)a[idx];
b[idx] = (half)(x / (1.0 + expf(-x)));
}
}

__global__ void FastllmSwigluKernel(float* a, float *b, int len, int spatial, int mid) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
@@ -489,6 +497,13 @@ __global__ void FastllmMulToKernel(float* a, float *b, float alpha, int len) {
}
}

__global__ void FastllmMulToKernel(half* a, half *b, float alpha, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
a[idx] *= (half)((float)b[idx] * alpha);
}
}

template <int THREAD_PER_BLOCK>
__global__ void FastllmAttentionMaskKernel(float* a, float *b, float maskValue, int n, int m, int spatial) {
int on = blockIdx.x / m;
@@ -2921,7 +2936,11 @@ bool FastllmCudaSilu(const fastllm::Data &input, fastllm::Data &output) {
float *cudaInput = (float *) FastllmCudaPrepareInput(input);
float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
int threadPerBlock = std::min(256, len);
FastllmSiluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len);
if (input.dataType == fastllm::DataType::FLOAT32) {
FastllmSiluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len);
} else if (input.dataType == fastllm::DataType::FLOAT16) {
FastllmSiluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((half*)cudaInput, (half*)cudaOutput, len);
}
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
return true;
@@ -2985,7 +3004,11 @@ bool FastllmCudaMulTo(fastllm::Data &input0, const fastllm::Data &input1, float
float *input1Data = (float *) FastllmCudaPrepareInput(input1);

int threadPerBlock = std::min(256, len);
FastllmMulToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaData, input1Data, alpha, len);
if (input0.dataType == fastllm::DataType::FLOAT32) {
FastllmMulToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaData, input1Data, alpha, len);
} else {
FastllmMulToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((half*)cudaData, (half*)input1Data, alpha, len);
}
FastllmCudaFinishInput(input1, input1Data);
FastllmCudaFinishOutput(input0, cudaData);
return true;
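
For context, here is a standalone CUDA sketch of the pattern the half kernels above use: convert to float for the math, store the result back as half, and launch with the (len - 1) / threadPerBlock + 1 grid shape. Kernel and variable names below are illustrative, not fastllm's.

// Standalone CUDA sketch (not the fastllm launcher): FP16 elementwise SiLU.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void SiluHalfKernel(const half *in, half *out, int len) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < len) {
        float x = __half2float(in[idx]);            // compute in float for accuracy
        out[idx] = __float2half(x / (1.0f + expf(-x)));
    }
}

int main() {
    const int len = 8;
    half hostIn[len], hostOut[len];
    for (int i = 0; i < len; i++) hostIn[i] = __float2half((float)i - 4.0f);

    half *devIn = nullptr, *devOut = nullptr;
    cudaMalloc(&devIn, len * sizeof(half));
    cudaMalloc(&devOut, len * sizeof(half));
    cudaMemcpy(devIn, hostIn, len * sizeof(half), cudaMemcpyHostToDevice);

    int threadPerBlock = 256;                       // same launch shape idea as above
    int blocks = (len - 1) / threadPerBlock + 1;
    SiluHalfKernel<<<blocks, threadPerBlock>>>(devIn, devOut, len);

    cudaMemcpy(hostOut, devOut, len * sizeof(half), cudaMemcpyDeviceToHost);
    for (int i = 0; i < len; i++) {
        printf("silu(%g) = %g\n", __half2float(hostIn[i]), __half2float(hostOut[i]));
    }
    cudaFree(devIn);
    cudaFree(devOut);
    return 0;
}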
