Commit: Optimize some operators (优化一些算子)

ztxz16 committed May 24, 2024
1 parent 4aec44d commit bbc0c14
Showing 2 changed files with 85 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/devices/cuda/cudadevicebatch.cpp
@@ -304,7 +304,7 @@ namespace fastllm {
dpitchs.push_back(input0Stride * unitSize);
srcs.push_back(input1.cudaData);
spitchs.push_back(input1Stride * unitSize);
widths.push_back(input1.dims[axis] * inner * unitSize);
widths.push_back(inner * unitSize);
heights.push_back(outer);
}

89 changes: 84 additions & 5 deletions src/devices/cuda/fastllm-cuda.cu
@@ -952,6 +952,61 @@ __global__ void FastllmGemvFp32Fp16Kernel2(float *A, half *B, float *C, float *b
}
}

template <int THREAD_PER_BLOCK, int PART>
__global__ void FastllmGemvFp32Fp16Kernel2MultiRow(float *A, half *B, float *C, float *bias, int m, int k) {
    __shared__ float sdata[PART][THREAD_PER_BLOCK];
    unsigned int tid = threadIdx.x;
    const half zero = __float2half_rn(0.0);
    float4 regA;
    union_half4 regB;

    // 1. Compute: each block handles one weight row (p = blockIdx.x) and accumulates
    //    partial dot products against all PART input rows into shared memory.
    int st = blockIdx.x;
    int p = st;
#pragma unroll
    for (int x = 0; x < PART; x++) sdata[x][tid] = 0;

    const half *baseB = B + p * m;
#pragma unroll
    for (int i = tid * 4; i < m; i += THREAD_PER_BLOCK * 4) {
#pragma unroll
        for (int x = 0; x < PART; x++) {
            regA = FETCH_FLOAT4(A[i + x * m]);
            regB.in = *reinterpret_cast<const uint2 *>(baseB + i);
            float sum = 0.0f;
            if (i < m)
                sum += regA.x * __low2float(regB.out2[0]);
            if (i + 1 < m)
                sum += regA.y * __high2float(regB.out2[0]);
            if (i + 2 < m)
                sum += regA.z * __low2float(regB.out2[1]);
            if (i + 3 < m)
                sum += regA.w * __high2float(regB.out2[1]);
            sdata[x][tid] += sum;
        }
    }
    __syncthreads();

    // 2. Block-wide tree reduction with compensated (Kahan-style) summation.
    float diff = 0.0f;
    for (unsigned int s = THREAD_PER_BLOCK / 2; s > 0; s >>= 1) {
        if (tid < s) {
#pragma unroll
            for (int x = 0; x < PART; x++) {
                float other = sdata[x][tid + s] - diff;
                float sumTmp = sdata[x][tid] + other;
                diff = (sumTmp - sdata[x][tid]) - other;
                sdata[x][tid] = sumTmp;
            }
        }
        __syncthreads();
    }

    // 3. Thread 0 writes the PART results for this weight row, adding the bias.
    if (tid == 0) {
#pragma unroll
        for (int x = 0; x < PART; x++) C[p + k * x] = sdata[x][0] + __ldg(bias + p);
    }
    __syncthreads();
}

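For reference, the new kernel launches one block per weight row (p = blockIdx.x) and reuses that row of B across all PART rows of the fp32 input A, so C is laid out as PART rows of k outputs. Below is a minimal host-side sketch of the expected result, not part of the commit (the helper name cpuGemvFp32Fp16Reference is an assumption, and it assumes __half2float is host-callable, as in recent CUDA toolkits):

#include <cuda_fp16.h>

// Hypothetical CPU reference (illustration only): computes what the kernel is expected
// to write, i.e. C[x * k + p] = bias[p] + dot(A row x, B row p).
void cpuGemvFp32Fp16Reference(const float *A, const half *B, const float *bias,
                              float *C, int rows, int m, int k) {
    for (int x = 0; x < rows; x++) {      // rows plays the role of PART (the batch size n)
        for (int p = 0; p < k; p++) {     // one weight row per CUDA block in the kernel
            float sum = bias[p];          // the kernel always adds __ldg(bias + p)
            for (int i = 0; i < m; i++) {
                sum += A[x * m + i] * __half2float(B[p * m + i]);
            }
            C[x * k + p] = sum;           // matches C[p + k * x] in the kernel
        }
    }
}
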
template <int THREAD_PER_BLOCK, int PART>
__global__ void FastllmGemvInt8Kernel2(float *A, uint8_t *B, float *C,
float *bias, float *scales, uint8_t *zeros,
@@ -2012,7 +2067,21 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight,
float *cudaInput = (float*)FastllmCudaPrepareInput(input);
float *cudaOutput = (float*)FastllmCudaPrepareOutput(output);

if (n > 1) {
if (n == 1) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 1> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 2) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 2> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 3) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 3> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 4) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 4> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 5) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 5> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 6) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 6> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else if (n == 7) {
FastllmGemvFp32Fp16Kernel2MultiRow<256, 7> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
} else {
auto fastllmCublasHandle = getFastllmCublasHandle();
//cudaDeviceSynchronize();
half *cudaFp16Input, *cudaFp16Output;
@@ -2073,8 +2142,6 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight,
FastllmCudaFree(cudaFp16Input);
FastllmCudaFree(cudaFp16Output);
#endif
} else {
FastllmGemvFp32Fp16Kernel2<256, 1> <<< k, 256 >>>(cudaInput, (half *) weight.cudaData, cudaOutput, cudaBiasData, m, k);
}

FastllmCudaFinishInput(input, cudaInput);
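With this change, fp32-input × fp16-weight matmuls with n between 1 and 7 take the new multi-row GEMV path (the old single-row FastllmGemvFp32Fp16Kernel2 launch for n == 1 is removed), while larger n still converts the input to fp16 and goes through cuBLAS. Because PART is a compile-time template parameter, each small n needs its own instantiation, hence the branch-by-branch dispatch above; a compact equivalent is sketched below (the helper name launchGemvFp32Fp16MultiRow is an assumption, not part of the commit):

// Hypothetical dispatcher (illustration only): same instantiations and the same
// <<<k, 256>>> launch configuration as the if/else chain in FastllmCudaMatMulFloat16.
static void launchGemvFp32Fp16MultiRow(float *input, half *weight, float *output,
                                       float *bias, int n, int m, int k) {
    switch (n) {
        case 1: FastllmGemvFp32Fp16Kernel2MultiRow<256, 1> <<< k, 256 >>> (input, weight, output, bias, m, k); break;
        case 2: FastllmGemvFp32Fp16Kernel2MultiRow<256, 2> <<< k, 256 >>> (input, weight, output, bias, m, k); break;
        case 3: FastllmGemvFp32Fp16Kernel2MultiRow<256, 3> <<< k, 256 >>> (input, weight, output, bias, m, k); break;
        case 4: FastllmGemvFp32Fp16Kernel2MultiRow<256, 4> <<< k, 256 >>> (input, weight, output, bias, m, k); break;
        case 5: FastllmGemvFp32Fp16Kernel2MultiRow<256, 5> <<< k, 256 >>> (input, weight, output, bias, m, k); break;
        case 6: FastllmGemvFp32Fp16Kernel2MultiRow<256, 6> <<< k, 256 >>> (input, weight, output, bias, m, k); break;
        case 7: FastllmGemvFp32Fp16Kernel2MultiRow<256, 7> <<< k, 256 >>> (input, weight, output, bias, m, k); break;
        default: break; // n >= 8 stays on the cuBLAS fp16 path
    }
}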
@@ -2292,6 +2359,17 @@ void FastllmCudaMemcpy2DDeviceToDevice(void * dst, size_t dpitch, const void *
//cudaDeviceSynchronize();
}

template <int THREAD_PER_BLOCK>
__global__ void FastllmMemcpy2DKernel(uint8_t *dst, size_t dpitch, uint8_t *src,
                                      size_t spitch, size_t width, size_t height) {
    // Pitched 2D copy: one block per row; threads stride across the row's width bytes.
    int id = blockIdx.x;
    dst += id * dpitch;
    src += id * spitch;
    for (int i = threadIdx.x; i < width; i += THREAD_PER_BLOCK) {
        dst[i] = src[i];
    }
}

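FastllmMemcpy2DKernel is defined here, but its call site is outside the hunks shown. A minimal sketch of how it could be launched as a device-side alternative to cudaMemcpy2D for device-to-device pitched copies (the wrapper name is an assumption, not part of the commit; pitches and width are in bytes, as for cudaMemcpy2D):

// Hypothetical wrapper (illustration only): one block per row, 256 threads per block
// striding over each row's width bytes; dpitch/spitch are row strides in bytes.
void FastllmCudaMemcpy2DKernelLaunch(void *dst, size_t dpitch, void *src,
                                     size_t spitch, size_t width, size_t height) {
    FastllmMemcpy2DKernel <256> <<< (int) height, 256 >>> (
        (uint8_t *) dst, dpitch, (uint8_t *) src, spitch, width, height);
}
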
template <int THREAD_PER_BLOCK>
__global__ void FastllmMemcpyBatchKernel (uint8_t** pointer) {
int id = blockIdx.x;
@@ -2323,7 +2401,7 @@ void FastllmCudaMemcpy2DDeviceToDeviceBatch(void ** dsts, size_t * dpitchs, voi
}
}
cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * total * 3, cudaMemcpyHostToDevice);
FastllmMemcpyBatchKernel <128> <<<total, 128>>> (pointers);
FastllmMemcpyBatchKernel <256> <<<total, 256>>> (pointers);

FastllmCudaFree(pointers);
delete[] cpuPointers;
@@ -3100,7 +3178,7 @@ bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Da
qk[b] = mem + memSum;
memSum += s;
}

if (true) {
uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch * k0 * 8);
uint8_t ** cpuPointers = new uint8_t*[batch * k0 * 8];
@@ -3121,6 +3199,7 @@ bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Da
FastllmCudaFree(pointers);
delete[] cpuPointers;
}

if (true) {
int total = 0;
for (int b = 0; b < batch; b++) {