
Commit

Speed up graph llm so it runs at roughly the same speed as llama
黄宇扬 committed Dec 2, 2024
1 parent ddbd6db commit f711b32
Showing 7 changed files with 427 additions and 82 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -33,4 +33,5 @@ token
/localtest/
/third_party/tfacc/driver/tfacc2/result
/.chainlit
-/.files
+/.files
+*.o
11 changes: 7 additions & 4 deletions include/graph.h
@@ -79,10 +79,13 @@ namespace fastllm {

// Run the compute graph
void RunComputeGraph (const ComputeGraph &graph,
-const std::map <std::string, int> &deviceMap,
-std::map <std::string, Data*> inputs,
-std::map <std::string, Data*> weights,
-std::map <std::string, Data*> outputs);
+const std::map <std::string, int> &deviceMap,
+const std::map <std::string, Data*> &inputs,
+const std::map <std::string, Data*> &weights,
+const std::map <std::string, Data*> &outputs,
+std::vector <std::vector <Data*> > &pastKeys,
+std::vector <std::vector <Data*> > &pastValues,
+std::vector <Data*> &masks);
}

#endif //FASTLLM_GRAPH_H
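For reference, a hypothetical caller-side sketch (not part of this commit) of the widened signature: the per-layer past-key/past-value caches and the attention masks are now passed to the graph runner as explicit arguments. All names below are illustrative.

#include <map>
#include <string>
#include <vector>
#include "graph.h"

// Hypothetical wrapper: forwards pre-built KV caches and masks into the graph
// runner through the new RunComputeGraph signature.
void RunOneForward(const fastllm::ComputeGraph &graph,
                   const std::map <std::string, int> &deviceMap,
                   const std::map <std::string, fastllm::Data*> &inputs,
                   const std::map <std::string, fastllm::Data*> &weights,
                   const std::map <std::string, fastllm::Data*> &outputs,
                   std::vector <std::vector <fastllm::Data*> > &pastKeys,
                   std::vector <std::vector <fastllm::Data*> > &pastValues,
                   std::vector <fastllm::Data*> &masks) {
    fastllm::RunComputeGraph(graph, deviceMap, inputs, weights, outputs,
                             pastKeys, pastValues, masks);
}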
6 changes: 5 additions & 1 deletion src/devices/cpu/cpudevice.cpp
@@ -1030,7 +1030,8 @@ namespace fastllm {

AssertInFastLLM(weight.dims.size() == 2, "Embedding's weight's dim should be 2.\n");
AssertInFastLLM(weight.dataType == DataType::FLOAT32 ||
-weight.dataType == DataType::BFLOAT16, "Embedding's weight's type should be float32 or bfloat16.\n");
+weight.dataType == DataType::FLOAT16 ||
+weight.dataType == DataType::BFLOAT16, "Embedding's weight's type should be float32 or float16 or bfloat16.\n");
AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
input.dataType == DataType::FLOAT16,
"Embedding's input's type should be float32 or float16.\n");
@@ -1041,6 +1042,9 @@
dims.push_back(embSize);

output.dataType = input.dataType;
if (weight.dataType == DataType::FLOAT16) {
output.dataType = DataType::FLOAT16;
}
output.Resize(dims);
}
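A standalone sketch (simplified types, assumed semantics) of the dtype rule this hunk introduces: the Embedding output keeps the input's dtype unless the weight is float16, in which case the output is forced to float16 as well.

#include <cstdio>

enum class DT { FLOAT32, FLOAT16, BFLOAT16 };

// Mirrors the Reshape logic above with a plain enum instead of fastllm::DataType.
DT EmbeddingOutputType(DT inputType, DT weightType) {
    return (weightType == DT::FLOAT16) ? DT::FLOAT16 : inputType;
}

int main() {
    printf("%d\n", (int) EmbeddingOutputType(DT::FLOAT32, DT::FLOAT16)); // prints 1, i.e. FLOAT16
}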

130 changes: 120 additions & 10 deletions src/devices/cuda/fastllm-cuda.cu
@@ -286,8 +286,8 @@ void GpuQK(half *q, half *k, half *qk, int qlen, int klen, int dim, float scale,
HalfFC <BQ, DIM, BK> <<<gridDim, blockDim>>> (q, k, qk, qlen, dim, klen, (half)scale, base);
}

-template <int THREAD_PER_BLOCK>
-__global__ void FastllmCudaFloatEmbeddingKernel(float *input, float *weight, float *output, int embSize) {
+template <int THREAD_PER_BLOCK, typename T>
+__global__ void FastllmCudaFloatEmbeddingKernel(float *input, T *weight, T *output, int embSize) {
input += blockIdx.x;
output += blockIdx.x * embSize;
int token = (int)(input[0] + 1e-5);
@@ -467,8 +467,8 @@ __global__ void FastllmSiluKernel(float* a, float *b, int len) {
__global__ void FastllmSiluKernel(half* a, half *b, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
-float x = __half2float(a[idx]);
-b[idx] = __float2half(x / (1.0 + expf(-x)));
+half x = a[idx];
+b[idx] = __hdiv(x, __hadd(__float2half(1.0), hexp(-x)));
}
}
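For reference, a float sketch of the SiLU activation this kernel evaluates; the change above keeps the half-precision kernel entirely in half arithmetic (hexp, __hadd, __hdiv) instead of round-tripping through float.

#include <cmath>
#include <cstdio>

// silu(x) = x / (1 + exp(-x)), shown here in float for clarity.
float Silu(float x) {
    return x / (1.0f + std::exp(-x));
}

int main() {
    printf("%f %f\n", Silu(1.0f), Silu(-1.0f));   // ~0.731059 and ~-0.268941
}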

@@ -477,7 +477,7 @@ __global__ void FastllmSwigluKernel(float* a, float *b, int len, int spatial, in
if (idx < len) {
int id = idx / mid * spatial + idx % mid;
float x = a[id], y = a[id + mid];
-b[idx] = (x / (1.0 + expf(-x))) * y;
+b[idx] = (x / (1.0f + expf(-x))) * y;
}
}
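A standalone float sketch of the indexing used by the SwiGLU kernel: each group of spatial values is split into two halves of size mid, and the output is silu(first half) * second half. The sizes below are illustrative.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    int rows = 2, spatial = 4, mid = spatial / 2;
    std::vector<float> a = {1, 2, 3, 4,   5, 6, 7, 8};   // rows x spatial
    std::vector<float> b(rows * mid);
    int len = rows * mid;
    for (int idx = 0; idx < len; idx++) {
        int id = idx / mid * spatial + idx % mid;        // same mapping as the kernel
        float x = a[id], y = a[id + mid];
        b[idx] = (x / (1.0f + std::exp(-x))) * y;        // silu(x) * y
    }
    printf("%f %f %f %f\n", b[0], b[1], b[2], b[3]);
    return 0;
}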

@@ -3134,13 +3134,13 @@ void FastllmCudaMemcpyBetweenDevices(int dstId, void *dst, int srcId, void *src,
delete[] cpuData;
}
checkCudaErrors("Error: CUDA error when copy Between GPUs!", state);
-//cudaDeviceSynchronize();
+DeviceSync();
}

void FastllmCudaMemcpy2DDeviceToDevice(void * dst, size_t dpitch, const void * src,
size_t spitch, size_t width, size_t height) {
cudaMemcpy2D(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice);
-//cudaDeviceSynchronize();
+DeviceSync();
}
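A host-side sketch of the pitched 2D copy that cudaMemcpy2D performs (and that the batched key/value packing later in this file relies on): height rows of width bytes each, with independent row strides for source and destination. Buffer contents are illustrative.

#include <cstdio>
#include <cstring>

void Memcpy2D(void *dst, size_t dpitch, const void *src, size_t spitch,
              size_t width, size_t height) {
    // Copy width bytes per row; rows are dpitch/spitch bytes apart.
    for (size_t row = 0; row < height; row++) {
        std::memcpy((char*)dst + row * dpitch, (const char*)src + row * spitch, width);
    }
}

int main() {
    char src[2][8] = {"ABCDEFG", "abcdefg"};
    char dst[2][4] = {};
    Memcpy2D(dst, 4, src, 8, 3, 2);        // 2 rows, 3 bytes from each row
    printf("%.3s %.3s\n", dst[0], dst[1]); // ABC abc
}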

template <int THREAD_PER_BLOCK>
@@ -3223,7 +3223,7 @@ bool FastllmCudaSilu(const fastllm::Data &input, fastllm::Data &output) {
int len = input.Count(0);
float *cudaInput = (float *) FastllmCudaPrepareInput(input);
float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
-int threadPerBlock = std::min(256, len);
+int threadPerBlock = std::min(1024, len);
if (input.dataType == fastllm::DataType::FLOAT32) {
FastllmSiluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len);
} else if (input.dataType == fastllm::DataType::FLOAT16) {
@@ -3240,7 +3240,7 @@ bool FastllmCudaSwiglu(const fastllm::Data &input, fastllm::Data &output) {
float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
int spatial = input.Count(input.dims.size() - 1), mid = spatial / 2;

-int threadPerBlock = std::min(256, len);
+int threadPerBlock = std::min(1024, len);
if (input.dataType == fastllm::DataType::FLOAT32) {
FastllmSwigluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len, spatial, mid);
} else if (input.dataType == fastllm::DataType::FLOAT16) {
@@ -3274,7 +3274,7 @@ bool FastllmCudaAddTo(fastllm::Data &input0, const fastllm::Data &input1, float
float *cudaData = (float *) FastllmCudaPrepareInput(input0);
float *input1Data = (float *) FastllmCudaPrepareInput(input1);

-int threadPerBlock = std::min(256, len);
+int threadPerBlock = std::min(1024, len);
if (input0.dataType == fastllm::DataType::FLOAT32) {
FastllmAddToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaData, input1Data, alpha, len);
} else if (input0.dataType == fastllm::DataType::FLOAT16) {
@@ -3584,24 +3584,28 @@ bool FastllmCudaPermute(fastllm::Data &input, const std::vector<int> &axis) {
}

FastllmCudaFree(tempData);
DeviceSync();
return true;
}

bool FastllmFloatToHalf(void *a, void *b, int len) {
int threadPerBlock = std::min(256, len);
FastllmCudaFloat2HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((float*)a, (half*)b, len);
DeviceSync();
return true;
}

bool FastllmHalfToFloat(void *a, void *b, int len) {
int threadPerBlock = std::min(256, len);
FastllmCudaHalf2FloatKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((half*)a, (float*)b, len);
DeviceSync();
return true;
}

bool FastllmBF16ToFloat(void *a, void *b, int len) {
int threadPerBlock = std::min(256, len);
FastllmCudaBF162FloatKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint16_t*)a, (float*)b, len);
DeviceSync();
return true;
}

@@ -3616,6 +3620,10 @@ bool FastllmCudaEmbedding(const fastllm::Data &input, const fastllm::Data &weigh
float *outputData = (float *) dstOutputData;
float *weightData = (float *) weight.cudaData;
FastllmCudaFloatEmbeddingKernel <128> <<<inputLen, 128>>> (inputData, weightData, outputData, embSize);
} else if (weight.dataType == fastllm::DataType::FLOAT16) {
half *outputData = (half *) dstOutputData;
half *weightData = (half *) weight.cudaData;
FastllmCudaFloatEmbeddingKernel <128> <<<inputLen, 128>>> (inputData, weightData, outputData, embSize);
} else if (weight.dataType == fastllm::DataType::BFLOAT16) {
std::vector <float> cpuInputData = std::vector <float> (inputLen, 0.0f);
FastllmCudaCopyFromDeviceToHost(cpuInputData.data(), inputData, cpuInputData.size() * sizeof(float));
@@ -3627,8 +3635,11 @@
FastllmBF16ToFloat(outputData + i * embSize, weightData + token * embSize, embSize);
}
}
} else {

}

DeviceSync();
return true;
}
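For reference, a standalone float version of the lookup this function dispatches: each input value is a token id, and the matching row of the weight matrix is copied to the output. The float16 branch added above keeps that lookup on the GPU, while bfloat16 weights still go through the host-side conversion loop. Values are illustrative.

#include <cstdio>
#include <vector>

int main() {
    int embSize = 3;
    std::vector<float> weight = {0, 1, 2,  10, 11, 12,  20, 21, 22,  30, 31, 32}; // 4 x 3
    std::vector<float> input = {2, 0};                 // token ids stored as floats
    std::vector<float> output(input.size() * embSize);
    for (size_t i = 0; i < input.size(); i++) {
        int token = (int)(input[i] + 1e-5);            // same rounding as the kernel
        for (int j = 0; j < embSize; j++) {
            output[i * embSize + j] = weight[token * embSize + j];
        }
    }
    printf("%.0f %.0f %.0f\n", output[0], output[1], output[2]); // 20 21 22
}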

@@ -4137,6 +4148,105 @@ bool FastllmCudaRepeatPenalty (fastllm::Data &input, fastllm::Data &penalty, fas
template <typename T>
bool DoFastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Data **v,
fastllm::Data **mask, fastllm::Data **output, int group, float scale, int batch) {
if (false) {
half beta = __float2half_rn(0.0f), one = __float2half_rn(1.0f), hscale = __float2half_rn(scale);
int q0 = q[0]->dims[0], q1 = q[0]->dims[1], q2 = q[0]->dims[2], k0 = k[0]->dims[0], k1 = k[0]->dims[1], v2 = v[0]->dims[2];
for (int i = 0; i < batch; i++) {
q1 = std::max(q1, q[i]->dims[1]);
k1 = std::max(k1, k[i]->dims[1]);
}

half *allKeys = (half*) FastllmCudaMalloc(batch * k0 * k1 * q2 * sizeof(half));
half *allValues = (half*) FastllmCudaMalloc(batch * k0 * k1 * v2 * sizeof(half));

std::vector <void*> dsts, srcs;
std::vector <size_t> dpitchs, spitchs, widths, heights;
for (int i = 0; i < batch; i++) {
dsts.push_back((uint8_t *) (allKeys + i * k0 * k1 * q2));
dpitchs.push_back(k1 * q2 * sizeof(half));
srcs.push_back(k[i]->cudaData);
spitchs.push_back(k[i]->strides[0] * sizeof(half));
widths.push_back(k[i]->dims[1] * q2 * sizeof(half));
heights.push_back(k0);

dsts.push_back((uint8_t *) (allValues + i * k0 * k1 * v2));
dpitchs.push_back(k1 * v2 * sizeof(half));
srcs.push_back(v[i]->cudaData);
spitchs.push_back(v[i]->strides[0] * sizeof(half));
widths.push_back(v[i]->dims[1] * v2 * sizeof(half));
heights.push_back(k0);
}
FastllmCudaMemcpy2DDeviceToDeviceBatch(dsts.data(), dpitchs.data(), srcs.data(), spitchs.data(), widths.data(), heights.data(), dsts.size());
/*
for (int i = 0; i < batch; i++) {
cudaMemcpy2D(
allKeys + i * k0 * k1 * q2, k1 * q2 * sizeof(half),
k[i]->cudaData, k[i]->strides[0] * sizeof(half),
k[i]->dims[1] * q2 * sizeof(half), k0,
cudaMemcpyDeviceToDevice
);
cudaMemcpy2D(
allValues + i * k0 * k1 * v2, k1 * v2 * sizeof(half),
v[i]->cudaData, v[i]->strides[0] * sizeof(half),
v[i]->dims[1] * v2 * sizeof(half), k0,
cudaMemcpyDeviceToDevice
);
}
*/
half *qd = (half*)q[0]->cudaData;
half *od = (half*)output[0]->cudaData;
half *qk = (half *) FastllmCudaMalloc(batch * q0 * q1 * k1 * sizeof(half));
half *temp = (half *) FastllmCudaMalloc(batch * q0 * q1 * k1 * sizeof(half));
auto fastllmCublasHandle = getFastllmCublasHandle();
cublasStatus_t status;

status = cublasHgemmStridedBatched(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k1, q1 * group, q2, &hscale,
allKeys, q2, k1 * q2,
qd, q2, group * q1 * q2,
&beta,
qk, k1, k1 * q1 * group, batch * q0 / group);
if (status != CUBLAS_STATUS_SUCCESS) {
printf("status = %d\n", (int) status);
printf("Error: cublas error during MatMulTransB in Attention operator.\n");
throw ("cublas error");
exit(0);
}

int outer = batch * q0 * q1;
if (k1 < 8) {
FastllmSoftmaxKernelInner1<1> <<< outer, 1 >>>(qk, temp, outer, k1);
} else if (k1 < 64) {
FastllmSoftmaxKernelInner1<8> <<< outer, 8 >>>(qk, temp, outer, k1);
} else if (k1 < 512) {
FastllmSoftmaxKernelInner1<64> <<< outer, 64 >>>(qk, temp, outer, k1);
} else {
FastllmSoftmaxKernelInner1<256> <<< outer, 256 >>>(qk, temp, outer, k1);
}

status = cublasHgemmStridedBatched(fastllmCublasHandle,
CUBLAS_OP_N, CUBLAS_OP_N,
v2, q1 * group, k1, &one,
allValues, v2, k1 * v2,
temp, k1, k1 * q1 * group,
&beta,
od, v2, v2 * q1 * group, batch * q0 / group);
if (status != CUBLAS_STATUS_SUCCESS) {
printf("status = %d\n", (int) status);
printf("Error: cublas error during MatMul in Attention operator.\n");
throw ("cublas error");
exit(0);
}

FastllmCudaFree(allKeys);
FastllmCudaFree(allValues);
FastllmCudaFree(qk);
FastllmCudaFree(temp);
DeviceSync();
return true;
}

int k0 = k[0]->dims[0];
size_t memSum = 0;
for (int b = 0; b < batch; b++) {
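For reference, a standalone float version of what the (currently disabled, if (false)) cublasHgemmStridedBatched path above computes for one query head against its key/value head: scores = scale * Q * K^T, a row-wise softmax (the FastllmSoftmaxKernelInner1 step), then scores * V. The grouping of several query heads per KV head, the padding of every request's keys/values to the longest sequence in the batch, and the half-precision storage are omitted here; sizes are illustrative.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    int qlen = 2, klen = 3, dim = 2, vdim = 2;
    float scale = 1.0f / std::sqrt((float) dim);
    std::vector<float> Q = {1, 0,  0, 1};              // qlen x dim
    std::vector<float> K = {1, 0,  0, 1,  1, 1};       // klen x dim
    std::vector<float> V = {1, 2,  3, 4,  5, 6};       // klen x vdim
    std::vector<float> O(qlen * vdim, 0.0f);

    for (int i = 0; i < qlen; i++) {
        std::vector<float> s(klen);
        float maxv = -1e30f, sum = 0.0f;
        for (int j = 0; j < klen; j++) {               // scores = scale * (q . k)
            float dot = 0.0f;
            for (int d = 0; d < dim; d++) dot += Q[i * dim + d] * K[j * dim + d];
            s[j] = dot * scale;
            maxv = std::max(maxv, s[j]);
        }
        for (int j = 0; j < klen; j++) { s[j] = std::exp(s[j] - maxv); sum += s[j]; }
        for (int j = 0; j < klen; j++) s[j] /= sum;    // softmax over the key axis
        for (int j = 0; j < klen; j++)                 // output row = softmax . V
            for (int d = 0; d < vdim; d++) O[i * vdim + d] += s[j] * V[j * vdim + d];
    }
    printf("%f %f  %f %f\n", O[0], O[1], O[2], O[3]);
    return 0;
}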
1 change: 1 addition & 0 deletions src/fastllm.cpp
@@ -334,6 +334,7 @@ namespace fastllm {
this->Resize(ori.dims);
this->Allocate();
} else {
this->expansionDims.clear();
this->Resize(ori.dims);
this->Allocate();
}
