
Commit

Slightly optimize halfAttention
黄宇扬 committed Jul 8, 2024
1 parent 175ce89 commit 2619759
Showing 3 changed files with 32 additions and 39 deletions.
7 changes: 6 additions & 1 deletion src/devices/cuda/cudadevicebatch.cpp
@@ -378,7 +378,12 @@ namespace fastllm {
Data **masks = (Data**)(datas.find("mask")->second);
Data **outputs = (Data**)(datas.find("output")->second);

- if (qs[0]->dataType == DataType::FLOAT32) {
+ long long aveLen = 0;
+ for (int i = 0; i < batch; i++) {
+ aveLen += ks[i]->dims[1];
+ }
+ aveLen /= batch;
+ if (qs[0]->dataType == DataType::FLOAT32 || aveLen < 512) {
for (int i = 0; i < batch; i++) {
outputs[i]->Allocate();
}
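The change above adds a cheap dispatch heuristic: before picking a kernel, the batch attention op averages the key sequence length (ks[i]->dims[1]) over the batch and keeps the original FLOAT32 path whenever the inputs are FP32 or the average context is short (under 512), so the half-precision path is only taken where its overhead pays off. A minimal sketch of that dispatch, with RunFloat32AttentionBatch and RunHalfAttentionBatch as hypothetical stand-ins for the two code paths (they are not fastllm APIs):

// Sketch only; the two Run*AttentionBatch functions are hypothetical placeholders
// for the existing FP32 path and the half-precision path.
void RunFloat32AttentionBatch(fastllm::Data **qs, fastllm::Data **ks, fastllm::Data **outputs, int batch);
void RunHalfAttentionBatch(fastllm::Data **qs, fastllm::Data **ks, fastllm::Data **outputs, int batch);

void DispatchAttentionBatch(fastllm::Data **qs, fastllm::Data **ks, fastllm::Data **outputs, int batch) {
    long long aveLen = 0;
    for (int i = 0; i < batch; i++) {
        aveLen += ks[i]->dims[1];          // key (context) length of sample i
    }
    aveLen /= batch;                       // average context length over the batch

    if (qs[0]->dataType == fastllm::DataType::FLOAT32 || aveLen < 512) {
        RunFloat32AttentionBatch(qs, ks, outputs, batch);   // FP32 data or short contexts
    } else {
        RunHalfAttentionBatch(qs, ks, outputs, batch);      // long half-precision contexts
    }
}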
62 changes: 25 additions & 37 deletions src/devices/cuda/fastllm-cuda.cu
@@ -27,6 +27,7 @@ void showError(cudaError_t result, char const* const message, const char* const


#define FETCH_FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])
+ #define FETCH_FLOAT2(pointer) (reinterpret_cast<float2*>(&(pointer))[0])

typedef union __align__(16) {
uint2 in;
@@ -854,6 +855,14 @@ __global__ void FastllmSoftmaxKernelBatchInner1(uint8_t** pointer) {
(int)((size_t)pointer[o * 3 + 2]), nullptr, nullptr);
}

+ template <typename T, int THREAD_PER_BLOCK>
+ __global__ void FastllmSoftmaxKernelBatchInner1(uint8_t** pointer, int outer, int channels) {
+ int o = blockIdx.x;
+ FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> ((T*)pointer[o / outer] + (o % outer) * channels, (T*)pointer[o / outer] + (o % outer) * channels,
+ channels, nullptr, nullptr);
+ }


template <int THREAD_PER_BLOCK>
__global__ void FastllmRMSNormKernelInner1(float *input, float *weight, float *output, int outer, int channels, float eps) {
int o = blockIdx.x;
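The new FastllmSoftmaxKernelBatchInner1 overload added above drops the per-row pointer staging: instead of reading a host-prepared (input, output, channels) triple for every row, each block derives its row from blockIdx.x using a uniform outer and channels, so the host only has to upload one base pointer per batch entry (see the launch change further down). The row softmax itself is delegated to FastllmSoftmaxKernelInner1Func, whose body is not part of this diff; the sketch below shows what a block-per-row, in-place softmax of that shape can look like, and is an assumption rather than the library's implementation:

#include <cuda_fp16.h>

// Assumed block-per-row in-place softmax over `channels` elements.
// Block o handles row o: pointer[o / outer] is the base of batch entry o / outer,
// (o % outer) * channels is the row offset, matching the new overload's indexing.
template <typename T, int THREAD_PER_BLOCK>
__global__ void RowSoftmaxSketch(uint8_t **pointer, int outer, int channels) {
    int o = blockIdx.x;
    T *row = (T*)pointer[o / outer] + (o % outer) * channels;

    __shared__ float red[THREAD_PER_BLOCK];
    int tid = threadIdx.x;

    // 1. Row maximum, for numerical stability.
    float maxv = -1e30f;
    for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
        maxv = fmaxf(maxv, (float)row[i]);
    }
    red[tid] = maxv;
    __syncthreads();
    for (int s = THREAD_PER_BLOCK / 2; s > 0; s >>= 1) {
        if (tid < s) {
            red[tid] = fmaxf(red[tid], red[tid + s]);
        }
        __syncthreads();
    }
    maxv = red[0];
    __syncthreads();

    // 2. Exponentiate in place and accumulate the row sum.
    float sum = 0.0f;
    for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
        float e = __expf((float)row[i] - maxv);
        row[i] = (T)e;
        sum += e;
    }
    red[tid] = sum;
    __syncthreads();
    for (int s = THREAD_PER_BLOCK / 2; s > 0; s >>= 1) {
        if (tid < s) {
            red[tid] += red[tid + s];
        }
        __syncthreads();
    }
    sum = red[0];
    __syncthreads();

    // 3. Normalize.
    for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
        row[i] = (T)((float)row[i] / sum);
    }
}

A launch mirroring the new dispatch below would look like RowSoftmaxSketch<half, 64><<<batch * outer, 64>>>(pointers, outer, channels).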
@@ -1670,38 +1679,29 @@ __global__ void FastllmHalfMatMulTransBBatchKernel(uint8_t** pointer, float alph
*/

int pera = 4, perb = 4;
- float cura[4][4], curb[4][4], curc[4][4];
+ half cura[4][4], curb[4][4];
+ float curc[4][4];
int cnta = (n - 1) / pera + 1, cntb = (k - 1) / perb + 1;
for (int taskId = tid; taskId < cnta * cntb; taskId += THREAD_PER_BLOCK) {
int taska = taskId / cntb, taskb = taskId % cntb;
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
- cura[i][j] = 0;
- curb[i][j] = 0;
- curc[i][j] = 0;
+ curc[i][j] = 0.0f;
}
}

for (int l = 0; l < m; l += 4) {
for (int a = taska * pera; a < (taska + 1) * pera && a < n; a++) {
- #pragma unroll
- for (int x = 0; x < 4; x++) {
- cura[a - taska * pera][x] = (float)input0[a * input0Stride + l + x];
- }
+ FETCH_FLOAT2(cura[a - taska * pera]) = FETCH_FLOAT2(input0[a * input0Stride + l]);
}
for (int b = taskb * perb; b < (taskb + 1) * perb && b < k; b++) {
- #pragma unroll
- for (int x = 0; x < 4; x++) {
- curb[b - taskb * perb][x] = (float)input1[b * input1Stride + l + x];
- }
+ FETCH_FLOAT2(curb[b - taskb * perb]) = FETCH_FLOAT2(input1[b * input1Stride + l]);
}
- #pragma unroll

for (int i = 0; i < 4; i++) {
- #pragma unroll
for (int j = 0; j < 4; j++) {
- #pragma unroll
for (int k = 0; k < 4; k++) {
- curc[i][j] += cura[i][k] * curb[j][k];
+ curc[i][j] += (float)cura[i][k] * (float)curb[j][k];
}
}
}
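The kernel rewrite above changes the hot loop in two ways: the 4x4 tiles of A and B are now kept in half instead of float, and each tile row is filled by a single FETCH_FLOAT2 access (one 8-byte load moving four consecutive halves) instead of four scalar loads with per-element conversion; the conversion to float is deferred to the multiply-accumulate so the accumulator curc stays in float and long dot products don't lose precision. A self-contained sketch of that load-and-accumulate pattern (not the library's kernel):

#include <cuda_fp16.h>

// Minimal sketch of the FETCH_FLOAT2 pattern: reinterpret a half* as float2*
// so four consecutive halves move in one load, but accumulate in float.
#define FETCH_FLOAT2(pointer) (reinterpret_cast<float2*>(&(pointer))[0])

// Dot product of two half rows of length m; m is assumed to be a multiple of 4
// and the rows 8-byte aligned, the same preconditions the vectorized loads above rely on.
__device__ float HalfDotSketch(half *a, half *b, int m) {
    half ta[4], tb[4];
    float acc = 0.0f;
    for (int l = 0; l < m; l += 4) {
        FETCH_FLOAT2(ta) = FETCH_FLOAT2(a[l]);   // 4 halves in one 8-byte load
        FETCH_FLOAT2(tb) = FETCH_FLOAT2(b[l]);
        for (int x = 0; x < 4; x++) {
            acc += (float)ta[x] * (float)tb[x];  // convert at the FMA, accumulate in float
        }
    }
    return acc;
}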
@@ -3796,29 +3796,17 @@ bool DoFastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::
}

if (true) {
- int total = 0;
- for (int b = 0; b < batch; b++) {
- int outer = q[b]->dims[0] * q[b]->dims[1];
- total += outer;
- }
- uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * total * 3);
- uint8_t ** cpuPointers = new uint8_t*[total * 3];
- int cur = 0;
- for (int b = 0; b < batch; b++) {
- int outer = q[b]->dims[0] * q[b]->dims[1];
- int channels = k[b]->dims[1];
- for (int o = 0; o < outer; o++) {
- cpuPointers[cur * 3 + 0] = (uint8_t*)(qk[b] + o * channels);
- cpuPointers[cur * 3 + 1] = (uint8_t*)(qk[b] + o * channels);
- cpuPointers[cur * 3 + 2] = (uint8_t*)((size_t)channels);
- cur++;
- }
+ int outer = q[0]->dims[0] * q[0]->dims[1], channels = k[0]->dims[1];
+ uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch);
+ cudaMemcpy(pointers, qk, sizeof(uint8_t*) * batch, cudaMemcpyHostToDevice);
+ if (channels < 128) {
+ FastllmSoftmaxKernelBatchInner1 <T, 32> <<<batch * outer, 32>>> (pointers, outer, channels);
+ } else if (channels < 512) {
+ FastllmSoftmaxKernelBatchInner1 <T, 64> <<<batch * outer, 64>>> (pointers, outer, channels);
+ } else {
+ FastllmSoftmaxKernelBatchInner1 <T, 128> <<<batch * outer, 128>>> (pointers, outer, channels);
+ }
- cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * total * 3, cudaMemcpyHostToDevice);
- FastllmSoftmaxKernelBatchInner1 <T, 256> <<<total, 256>>> (pointers);

FastllmCudaFree(pointers);
- delete[] cpuPointers;
}

if (true) {
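The launch-site change above is the host half of the new softmax overload: the old code built three staged pointers per (batch, row) pair on the CPU and copied 3 * total pointers to the device before a fixed 256-thread launch, while the new code copies only batch base pointers and picks the block size from channels so short attention rows don't occupy mostly idle threads. Because THREAD_PER_BLOCK is a template parameter it must be a compile-time constant, which is why the launch branches over instantiations instead of passing the size at run time. A hypothetical helper expressing the same choice:

// Hypothetical helper mirroring the block-size choice in the launch above;
// it is not part of the fastllm code.
static int PickSoftmaxBlockSize(int channels) {
    if (channels < 128) return 32;     // short rows: one warp is enough
    if (channels < 512) return 64;
    return 128;                        // long rows: more threads per row
}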
2 changes: 1 addition & 1 deletion src/models/basellm.cpp
@@ -582,7 +582,7 @@ namespace fastllm {
if (isPrompt) {
cnt += it.second->currentTokens.size();

- if (cnt > 300) {
+ if (cnt > 1024) {
break;
}
// break;
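The basellm.cpp change raises the prompt-batching budget: the scheduler keeps pulling queued prompt requests into one forward pass until their accumulated token count exceeds the cap, and the cap goes from 300 to 1024 tokens so prefill batches get larger. A heavily simplified sketch of that budget loop:

// Simplified sketch of the loop around the changed line; pendingRequests and
// AttachToForwardPass are hypothetical names, isPrompt comes from the surrounding
// (not shown) code, and only the accumulation and the 1024-token cap mirror the diff.
int cnt = 0;
for (auto &it : pendingRequests) {
    AttachToForwardPass(it.second);        // pull this request into the current pass
    if (isPrompt) {                        // prefill phase: count prompt tokens
        cnt += it.second->currentTokens.size();
        if (cnt > 1024) {                  // budget reached; defer remaining requests
            break;
        }
    }
}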
