diff --git a/src/devices/cuda/fastllm-cuda.cu b/src/devices/cuda/fastllm-cuda.cu index b273074d..a62b1f7e 100644 --- a/src/devices/cuda/fastllm-cuda.cu +++ b/src/devices/cuda/fastllm-cuda.cu @@ -847,10 +847,10 @@ __global__ void FastllmSoftmaxKernelInner1WithCausalMask(T* input, T *output, in FastllmSoftmaxKernelInner1Func (input + o * channels, output + o * channels, min(channels, o + base + 1), maxp + o, sump + o); } -template +template __global__ void FastllmSoftmaxKernelBatchInner1(uint8_t** pointer) { int o = blockIdx.x; - FastllmSoftmaxKernelInner1Func ((float*)pointer[o * 3], (float*)pointer[o * 3 + 1], + FastllmSoftmaxKernelInner1Func ((T*)pointer[o * 3], (T*)pointer[o * 3 + 1], (int)((size_t)pointer[o * 3 + 2]), nullptr, nullptr); } @@ -1630,6 +1630,115 @@ __global__ void FastllmCatBatchKernel(uint8_t **inputs, uint8_t *output, int out } } +template +__global__ void FastllmHalfMatMulTransBBatchKernel(uint8_t** pointer, float alpha) { + int id = blockIdx.x; + half *input0 = (half*)pointer[id * 8 + 0]; + half *input1 = (half*)pointer[id * 8 + 1]; + half *output = (half*)pointer[id * 8 + 2]; + int n = (int)((size_t)pointer[id * 8 + 3]); + int m = (int)((size_t)pointer[id * 8 + 4]); + int k = (int)((size_t)pointer[id * 8 + 5]); + int input0Stride = (int)((size_t)pointer[id * 8 + 6]); + int input1Stride = (int)((size_t)pointer[id * 8 + 7]); + + int tid = threadIdx.x; +/* + const int pera = 8, perb = 8; + __shared__ float sa[pera][128], sb[perb][128], sc[pera][perb]; + for (int sta = 0; sta < n; sta += pera) { + for (int stb = 0; stb < k; stb += perb) { + for (int i = 0; i < pera; i++) { + if (sta + i < n) { + sa[i][tid] = (float)input0[(sta + i) * input0Stride + tid]; + } else { + sa[i][tid] = 0; + } + } + for (int i = 0; i < perb; i++) { + if (stb + i < k) { + sb[i][tid] = (float)input1[(stb + i) * input1Stride + tid]; + } else { + sb[i][tid] = 0; + } + } + __syncthreads(); + + __syncthreads(); + } + } +*/ + + int pera = 4, perb = 4; + float cura[4][4], curb[4][4], curc[4][4]; + int cnta = (n - 1) / pera + 1, cntb = (k - 1) / perb + 1; + for (int taskId = tid; taskId < cnta * cntb; taskId += THREAD_PER_BLOCK) { + int taska = taskId / cntb, taskb = taskId % cntb; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + cura[i][j] = 0; + curb[i][j] = 0; + curc[i][j] = 0; + } + } + + for (int l = 0; l < m; l += 4) { + for (int a = taska * pera; a < (taska + 1) * pera && a < n; a++) { +#pragma unroll + for (int x = 0; x < 4; x++) { + cura[a - taska * pera][x] = (float)input0[a * input0Stride + l + x]; + } + } + for (int b = taskb * perb; b < (taskb + 1) * perb && b < k; b++) { +#pragma unroll + for (int x = 0; x < 4; x++) { + curb[b - taskb * perb][x] = (float)input1[b * input1Stride + l + x]; + } + } +#pragma unroll + for (int i = 0; i < 4; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { +#pragma unroll + for (int k = 0; k < 4; k++) { + curc[i][j] += cura[i][k] * curb[j][k]; + } + } + } + } + + if ((taska + 1) * pera <= n && (taskb + 1) * perb <= k) { +#pragma unroll + for (int i = 0; i < 4; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + output[(taska * pera + i) * k + (taskb * perb + j)] = (half)(curc[i][j] * alpha); + } + } + } else { + for (int i = 0; i < pera && taska * pera + i < n; i++) { + for (int j = 0; j < perb && taskb * perb + j < k; j++) { + output[(taska * pera + i) * k + (taskb * perb + j)] = (half)(curc[i][j] * alpha); + } + } + } + } +/* + int tid = threadIdx.x; + for (int i = 0; i < n; i++) { + half *curInput0 = input0 + i * input0Stride; + for (int j = tid; j < k; j += THREAD_PER_BLOCK) { + half *curInput1 = input1 + j * input1Stride; + float sum = 0.0; + for (int l = 0; l < m; l++) { + sum += (float)curInput0[l] * (float)curInput1[l]; + } + output[i * k + j] = (half)(sum * alpha); + } + } +*/ +} + template __global__ void FastllmMatMulTransBBatchKernel(uint8_t** pointer, float alpha) { int id = blockIdx.x; @@ -1714,6 +1823,90 @@ __global__ void FastllmMatMulTransBBatchKernel(uint8_t** pointer, float alpha) { */ } +template +__global__ void FastllmHalfMatMulKernel(uint8_t** pointer, float alpha) { + int id = blockIdx.x; + half *input0 = (half*)pointer[id * 8 + 0]; + half *input1 = (half*)pointer[id * 8 + 1]; + half *output = (half*)pointer[id * 8 + 2]; + int n = (int)((size_t)pointer[id * 8 + 3]); + int m = (int)((size_t)pointer[id * 8 + 4]); + int k = (int)((size_t)pointer[id * 8 + 5]); + int input0Stride = (int)((size_t)pointer[id * 8 + 6]); + int input1Stride = (int)((size_t)pointer[id * 8 + 7]); + + int tid = threadIdx.x; + int pera = 4, perb = 4; + float cura[4][4], curb[4][4], curc[4][4]; + int cnta = (n - 1) / pera + 1, cntb = (k - 1) / perb + 1; + for (int taskId = tid; taskId < cnta * cntb; taskId += THREAD_PER_BLOCK) { + int taska = taskId / cntb, taskb = taskId % cntb; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + cura[i][j] = 0; + curb[i][j] = 0; + curc[i][j] = 0; + } + } + + for (int l = 0; l < m; l += 4) { + for (int a = taska * pera; a < (taska + 1) * pera && a < n; a++) { +#pragma unroll + for (int x = 0; x < 4; x++) { + cura[a - taska * pera][x] = (l + x < m ? (float)input0[a * input0Stride + l + x] : 0.f); + } + } + for (int b = taskb * perb; b < (taskb + 1) * perb && b < k; b++) { +#pragma unroll + for (int x = 0; x < 4; x++) { + curb[b - taskb * perb][x] = (l + x < m ? (float)input1[(l + x) * input1Stride + b] : 0.f); + } + } + +#pragma unroll + for (int i = 0; i < 4; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { +#pragma unroll + for (int k = 0; k < 4; k++) { + curc[i][j] += cura[i][k] * curb[j][k]; + } + } + } + } + + if ((taska + 1) * pera <= n && (taskb + 1) * perb <= k) { +#pragma unroll + for (int i = 0; i < 4; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + output[(taska * pera + i) * k + (taskb * perb + j)] = (half)(curc[i][j] * alpha); + } + } + } else { + for (int i = 0; i < pera && taska * pera + i < n; i++) { + for (int j = 0; j < perb && taskb * perb + j < k; j++) { + output[(taska * pera + i) * k + (taskb * perb + j)] = (half)(curc[i][j] * alpha); + } + } + } + } +/* + int tid = threadIdx.x; + for (int i = 0; i < n; i++) { + half *curInput0 = input0 + i * input0Stride; + for (int j = tid; j < k; j += THREAD_PER_BLOCK) { + half *curInput1 = input1 + j; + float sum = 0.0; + for (int l = 0; l < m; l++) { + sum += (float)curInput0[l] * (float)curInput1[l * input1Stride]; + } + output[i * k + j] = (half)(sum * alpha); + } + } +*/ +} + template __global__ void FastllmMatMulKernel(uint8_t** pointer, float alpha) { int id = blockIdx.x; @@ -1799,71 +1992,6 @@ __global__ void FastllmMatMulKernel(uint8_t** pointer, float alpha) { */ } -template -__global__ void FastllmAttentionBatchKernel(float** pointer, float scale, int group) { - const int params = 16; - int id = blockIdx.x; - float *qd = (float*) pointer[id * params + 0]; - float *kd = (float*) pointer[id * params + 1]; - float *vd = (float*) pointer[id * params + 2]; - float *maskd = (float*) pointer[id * params + 3]; - float *od = (float*) pointer[id * params + 4]; - int q1 = (int)(unsigned long long)pointer[id * params + 5]; - int q2 = (int)(unsigned long long)pointer[id * params + 6]; - int k1 = (int)(unsigned long long)pointer[id * params + 7]; - int v2 = (int)(unsigned long long)pointer[id * params + 8]; - int qstride = (int)(unsigned long long)pointer[id * params + 9]; - int kstride = (int)(unsigned long long)pointer[id * params + 10]; - int vstride = (int)(unsigned long long)pointer[id * params + 11]; - int ostride = (int)(unsigned long long)pointer[id * params + 12]; - float *qk = (float*)pointer[id * params + 13]; - float *temp = (float*)pointer[id * params + 14]; - int q0 = (int)(unsigned long long)pointer[id * params + 15]; - - for (int o = 0; o < q0; o++) { - qd += o * qstride; - kd += (o / group) * kstride; - vd += (o / group) * vstride; - od += o * ostride; - qk += o * k1; - temp += o * k1; - - for (int i = 0; i < q1; i++) { - for (int j = threadIdx.x; j < k1; j += THREAD_PER_BLOCK) { - if (maskd && maskd[i * k1 + j] > 0.99) { - qk[j] = -10000; - continue; - } - float sum = 0.0f; - float *tempQd = qd + i * q2, *tempKd = kd + j * q2; - for (int l = 0; l < q2; l++) { - sum += tempQd[l] * tempKd[l]; - } - qk[j] = sum * scale; - } - __syncthreads(); - FastllmSoftmaxKernelInner1Func(qk, temp, k1); - __syncthreads(); - for (int j = threadIdx.x; j < v2; j += THREAD_PER_BLOCK) { - float *curInput1 = vd + j; - float sum = 0.0; - for (int l = 0; l < k1; l++) { - sum += temp[l] * curInput1[l * v2]; - } - od[i * v2 + j] = sum; - } - __syncthreads(); - } - - qd -= o * qstride; - kd -= (o / group) * kstride; - vd -= (o / group) * vstride; - od -= o * ostride; - qk -= o * k1; - temp -= o * k1; - } -} - void *FastllmCudaPrepareInput(const fastllm::Data &input) { void *ret; if (input.dataDevice == fastllm::DataDevice::CUDA) { @@ -2928,7 +3056,7 @@ bool FastllmCudaSoftmaxBatch(fastllm::Data **inputs, fastllm::Data **outputs, in } cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * total * 3, cudaMemcpyHostToDevice); - FastllmSoftmaxKernelBatchInner1 <256> <<>> (pointers); + FastllmSoftmaxKernelBatchInner1 <<>> (pointers); FastllmCudaFree(pointers); delete[] cpuPointers; @@ -3625,15 +3753,16 @@ bool FastllmCudaApplyLognAttn (fastllm::Data &input, fastllm::Data &lognAttn, fa return true; } -bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Data **v, +template +bool DoFastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Data **v, fastllm::Data **mask, fastllm::Data **output, int group, float scale, int batch) { int k0 = k[0]->dims[0]; size_t memSum = 0; for (int b = 0; b < batch; b++) { memSum += q[b]->dims[0] * q[b]->dims[1] * k[b]->dims[1]; } - float *mem = (float*) FastllmCudaMalloc(memSum * sizeof(float)); - float **qk = new float*[batch]; + T *mem = (T*) FastllmCudaMalloc(memSum * sizeof(T)); + T **qk = new T*[batch]; memSum = 0; for (int b = 0; b < batch; b++) { int s = q[b]->dims[0] * q[b]->dims[1] * k[b]->dims[1]; @@ -3646,9 +3775,9 @@ bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Da uint8_t ** cpuPointers = new uint8_t*[batch * k0 * 8]; for (int b = 0; b < batch; b++) { for (int i = 0; i < k0; i++) { - cpuPointers[(b * k0 + i) * 8 + 0] = (uint8_t *) q[b]->cudaData + i * group * q[b]->dims[1] * q[b]->dims[2] * sizeof(float); - cpuPointers[(b * k0 + i) * 8 + 1] = (uint8_t *) k[b]->cudaData + i * k[b]->strides[0] * sizeof(float); - cpuPointers[(b * k0 + i) * 8 + 2] = (uint8_t *) qk[b] + i * group * q[b]->dims[1] * k[b]->dims[1] * sizeof(float); + cpuPointers[(b * k0 + i) * 8 + 0] = (uint8_t *) q[b]->cudaData + i * group * q[b]->dims[1] * q[b]->dims[2] * sizeof(T); + cpuPointers[(b * k0 + i) * 8 + 1] = (uint8_t *) k[b]->cudaData + i * k[b]->strides[0] * sizeof(T); + cpuPointers[(b * k0 + i) * 8 + 2] = (uint8_t *) qk[b] + i * group * q[b]->dims[1] * k[b]->dims[1] * sizeof(T); cpuPointers[(b * k0 + i) * 8 + 3] = (uint8_t *) (size_t) (group * q[b]->dims[1]); cpuPointers[(b * k0 + i) * 8 + 4] = (uint8_t *) (size_t) q[b]->dims[2]; cpuPointers[(b * k0 + i) * 8 + 5] = (uint8_t *) (size_t) k[b]->dims[1]; @@ -3657,7 +3786,11 @@ bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Da } } cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * batch * k0 * 8, cudaMemcpyHostToDevice); - FastllmMatMulTransBBatchKernel <128> <<>> (pointers, scale); + if (typeid(T) == typeid(half)) { + FastllmHalfMatMulTransBBatchKernel <128> <<>> (pointers, scale); + } else { + FastllmMatMulTransBBatchKernel <128> <<>> (pointers, scale); + } FastllmCudaFree(pointers); delete[] cpuPointers; } @@ -3682,19 +3815,20 @@ bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Da } } cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * total * 3, cudaMemcpyHostToDevice); - FastllmSoftmaxKernelBatchInner1 <256> <<>> (pointers); + FastllmSoftmaxKernelBatchInner1 <<>> (pointers); FastllmCudaFree(pointers); delete[] cpuPointers; } + if (true) { uint8_t ** pointers = (uint8_t**)FastllmCudaMalloc(sizeof(uint8_t*) * batch * k0 * 8); uint8_t ** cpuPointers = new uint8_t*[batch * k0 * 8]; for (int b = 0; b < batch; b++) { for (int i = 0; i < k0; i++) { - cpuPointers[(b * k0 + i) * 8 + 0] = (uint8_t *) qk[b] + i * group * q[b]->dims[1] * k[b]->dims[1] * sizeof(float); - cpuPointers[(b * k0 + i) * 8 + 1] = (uint8_t *) v[b]->cudaData + i * v[b]->strides[0] * sizeof(float); - cpuPointers[(b * k0 + i) * 8 + 2] = (uint8_t *) output[b]->cudaData + i * group * q[b]->dims[1] * v[b]->dims[2] * sizeof(float); + cpuPointers[(b * k0 + i) * 8 + 0] = (uint8_t *) qk[b] + i * group * q[b]->dims[1] * k[b]->dims[1] * sizeof(T); + cpuPointers[(b * k0 + i) * 8 + 1] = (uint8_t *) v[b]->cudaData + i * v[b]->strides[0] * sizeof(T); + cpuPointers[(b * k0 + i) * 8 + 2] = (uint8_t *) output[b]->cudaData + i * group * q[b]->dims[1] * v[b]->dims[2] * sizeof(T); cpuPointers[(b * k0 + i) * 8 + 3] = (uint8_t *) (size_t) (group * q[b]->dims[1]); cpuPointers[(b * k0 + i) * 8 + 4] = (uint8_t *) (size_t) k[b]->dims[1]; cpuPointers[(b * k0 + i) * 8 + 5] = (uint8_t *) (size_t) v[b]->dims[2]; @@ -3703,80 +3837,36 @@ bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Da } } cudaMemcpy(pointers, cpuPointers, sizeof(uint8_t*) * batch * k0 * 8, cudaMemcpyHostToDevice); - FastllmMatMulKernel <128> <<>> (pointers, 1.0f); + + if (typeid(T) == typeid(half)) { + FastllmHalfMatMulKernel <128> <<>> (pointers, 1.0f); + } else { + FastllmMatMulKernel <128> <<>> (pointers, 1.0f); + } FastllmCudaFree(pointers); delete[] cpuPointers; } FastllmCudaFree(mem); delete[] qk; -/* - { - const int params = 16; - float **pointers = (float **) FastllmCudaMalloc(sizeof(float *) * batch * params); - float **cpuPointers = new float *[batch * params]; - - float **qk = new float *[batch]; - float **temp = new float *[batch]; - for (int b = 0; b < batch; b++) { - qk[b] = (float *) FastllmCudaMalloc(q[b]->dims[0] * k[b]->dims[1] * sizeof(float)); - temp[b] = (float *) FastllmCudaMalloc(q[b]->dims[0] * k[b]->dims[1] * sizeof(float)); - - cpuPointers[b * params + 0] = (float *) q[b]->cudaData; - cpuPointers[b * params + 1] = (float *) k[b]->cudaData; - cpuPointers[b * params + 2] = (float *) v[b]->cudaData; - cpuPointers[b * params + 3] = (mask[b] && mask[b]->dims.size() > 0) ? (float *) mask[b]->cudaData : nullptr; - cpuPointers[b * params + 4] = (float *) output[b]->cudaData; - cpuPointers[b * params + 5] = (float *) (unsigned long long) q[b]->dims[1]; - cpuPointers[b * params + 6] = (float *) (unsigned long long) q[b]->dims[2]; - cpuPointers[b * params + 7] = (float *) (unsigned long long) k[b]->dims[1]; - cpuPointers[b * params + 8] = (float *) (unsigned long long) v[b]->dims[2]; - cpuPointers[b * params + 9] = (float *) (unsigned long long) q[b]->strides[0]; - cpuPointers[b * params + 10] = (float *) (unsigned long long) k[b]->strides[0]; - cpuPointers[b * params + 11] = (float *) (unsigned long long) v[b]->strides[0]; - cpuPointers[b * params + 12] = (float *) (unsigned long long) output[b]->strides[0]; - cpuPointers[b * params + 13] = (float *) (unsigned long long) qk[b]; - cpuPointers[b * params + 14] = (float *) (unsigned long long) temp[b]; - cpuPointers[b * params + 15] = (float *) (unsigned long long) q[b]->dims[0]; - } - - cudaMemcpy(pointers, cpuPointers, sizeof(float *) * batch * params, cudaMemcpyHostToDevice); - FastllmAttentionBatchKernel<256> <<< batch, 256 >>>(pointers, scale, group); - - for (int i = 0; i < batch; i++) { - FastllmCudaFree(qk[i]); - FastllmCudaFree(temp[i]); - } - delete[] qk; - delete[] temp; - - FastllmCudaFree(pointers); - delete[] cpuPointers; - } -*/ -/* - for (int b = 0; b < batch; b++) { - int q0 = q[b]->dims[0], q1 = q[b]->dims[1], q2 = q[b]->dims[2], k0 = k[b]->dims[0], k1 = k[b]->dims[1], v2 = v[b]->dims[2]; - float *qd = (float *) q[b]->cudaData; - float *kd = (float *) k[b]->cudaData; - float *vd = (float *) v[b]->cudaData; - float *maskd = (mask[b] && mask[b]->dims.size() > 0) ? (float *) mask[b]->cudaData : nullptr; - float *od = (float *) output[b]->cudaData; - int maskBatch = (mask[b] && mask[b]->dims.size() > 0) ? mask[b]->dims[0] : 1; - - float *qk = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float)); - float *temp = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float)); - FastllmAttentionKernel<256> <<>>(qd, kd, vd, maskd, od, - scale, q1, q2, k1, v2, - group, q[b]->strides[0], k[b]->strides[0], v[b]->strides[0], - output[b]->strides[0], - qk, temp); - } -*/ + DeviceSync(); return true; } +bool FastllmCudaAttentionBatch(fastllm::Data **q, fastllm::Data **k, fastllm::Data **v, + fastllm::Data **mask, fastllm::Data **output, int group, float scale, int batch) { + if (q[0]->dataType == fastllm::DataType::FLOAT32) { + return DoFastllmCudaAttentionBatch (q, k, v, mask, output, group, scale, batch); + } else if (q[0]->dataType == fastllm::DataType::FLOAT16) { + return DoFastllmCudaAttentionBatch (q, k, v, mask, output, group, scale, batch); + } else { + printf("Error: attention datatype error.\n"); + throw ("Error: attention datatype error."); + exit(0); + } +} + bool FastllmCudaSplitBatch(fastllm::Data &input, fastllm::Data **outputs, int axis) { int part = input.dims[axis]; int outer = input.Count(0) / input.Count(axis);