
Commit

chatglm: use fused attention
huangyuyang committed Oct 9, 2023
1 parent baf2404 commit 7d6ebe8
Showing 3 changed files with 11 additions and 8 deletions.
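
At a glance: src/models/chatglm.cpp now calls the fused Attention operator (previously commented out) instead of the step-by-step q*k^T / mask / matmul pipeline, which is kept behind a block comment. To support that, the CPU and CUDA attention backends gain a maskStride: the CPU path previously offset the mask pointer by the batch index alone (no stride multiply), and the CUDA path assumed a contiguous q1 * k1 mask slice; both now index a 3-D [batch, q1, k1] mask by its actual stride.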
5 changes: 3 additions & 2 deletions src/devices/cpu/cpudevice.cpp
@@ -275,21 +275,22 @@ namespace fastllm {
         Data &output = *(datas.find("output")->second);
         int group = intParams.find("group") != intParams.end() ? intParams.find("group")->second : 1;
         float scale = floatParams.find("scale") != floatParams.end() ? floatParams.find("scale")->second : 1.0;
 
         output.Allocate();
         int q0 = q.dims[0], q1 = q.dims[1], q2 = q.dims[2], k0 = k.dims[0], k1 = k.dims[1], v2 = v.dims[2];
         float *qd = (float*)q.cpuData;
         float *kd = (float*)k.cpuData;
         float *vd = (float*)v.cpuData;
         float *maskd = (datas.find("mask")->second && mask.dims.size() > 0) ? (float*)mask.cpuData : nullptr;
         float *od = (float*)output.cpuData;
+        int batch = (mask.dims.size() == 3 ? mask.dims[0] : 1);
+        int maskStride = (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0));
         std::fill(od, od + output.Count(0), 0.0f);
         auto pool = GetPool();
         std::vector<std::future<void> > futures;
         for (int o = 0; o < q0; o++) {
             futures.push_back(pool->Submit(SingleAttention,
                 qd + o * q.strides[0], kd + (o / group) * k.strides[0], vd + (o / group) * v.strides[0],
-                maskd ? (maskd + o / (q0 / mask.dims[0])) : maskd, od + o * output.strides[0], scale,
+                maskd + (o / (q0 / batch)) * maskStride, od + o * output.strides[0], scale,
                 q1, q2, k1, v2));
         }
         for (int o = 0; o < futures.size(); o++) {
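
Why the new batch / maskStride pair matters on the CPU path: with a 3-D mask of shape [batch, q1, k1], head o belongs to batch element o / (q0 / batch), and its mask slice starts at that batch index times the mask's row stride. The old code offset maskd by the batch index alone, without the stride multiply. A minimal standalone sketch of the corrected offset arithmetic (toy dimensions are hypothetical; fastllm's Data/strides machinery is not reproduced):

    #include <cassert>
    #include <cstdio>

    int main() {
        int q0 = 8;               // total attention heads across the batch
        int batch = 2;            // mask.dims[0] when the mask is 3-D
        int q1 = 4, k1 = 16;      // query length, key length
        int maskStride = q1 * k1; // mask.strides[0] for a contiguous mask

        for (int o = 0; o < q0; o++) {
            // Heads are grouped per batch element: heads 0..3 -> mask slice 0,
            // heads 4..7 -> mask slice 1.
            int offset = (o / (q0 / batch)) * maskStride;
            assert(offset == (o < q0 / batch ? 0 : maskStride));
            printf("head %d reads mask at float offset %d\n", o, offset);
        }
        return 0;
    }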
7 changes: 4 additions & 3 deletions src/devices/cuda/fastllm-cuda.cu
@@ -1948,8 +1948,8 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
         float *vd = (float*)v.cudaData;
         float *maskd = mask.dims.size() > 0 ? (float*)mask.cudaData : nullptr;
         float *od = (float*)output.cudaData;
-        int batch = mask.dims.size() > 0 ? mask.dims[0] : 1;
-
+        int batch = (mask.dims.size() == 3) ? mask.dims[0] : 1;
+        int maskStride = (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0));
         if (false) {
             float *qk = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float));
             float *temp = (float *) FastllmCudaMalloc(q0 * k1 * sizeof(float));
@@ -1968,6 +1968,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
         auto fastllmCublasHandle = getFastllmCublasHandle();
         cublasStatus_t status;
 
+
         for (int i = 0; i < q0; i++) {
             status = cublasSgemmStridedBatched(fastllmCublasHandle,
                                                CUBLAS_OP_T, CUBLAS_OP_N,
@@ -1984,7 +1985,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
         }
 
         if (maskd) {
-            SimpleMask<256> <<< (q1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * q1 * k1, -10000, q1 * k1);
+            SimpleMask<256> <<< (q1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * maskStride, -10000, q1 * k1);
         }
 
         int outer = q1;
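
The CUDA path gets the same fix: the per-batch mask slice now starts at maskd + (i / (q0 / batch)) * maskStride instead of assuming a contiguous q1 * k1 slice. For readers unfamiliar with SimpleMask, here is a CPU sketch of the semantics the kernel launch above is presumably implementing; the kernel body is not shown in this diff, so treat the exact behavior as an assumption:

    // CPU reference for what the SimpleMask<256> launch appears to compute
    // (assumption: the kernel overwrites each score whose mask entry is
    // nonzero with `value`, e.g. -10000, so the subsequent softmax drives
    // those positions to ~0).
    void simpleMaskReference(float *qk, const float *mask, float value, int n) {
        for (int i = 0; i < n; i++) {
            if (mask[i] != 0.0f) {
                qk[i] = value;
            }
        }
    }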
7 changes: 4 additions & 3 deletions src/models/chatglm.cpp
@@ -233,15 +233,14 @@ namespace fastllm {
         std::vector<int> outputSize = {q.dims[1], q.dims[2], q.dims[0], pastKey.dims[1]};
         q.Reshape({q.dims[0], q.dims[1] * q.dims[2], q.dims[3]});
         PermuteSelf(q, {1, 0, 2});
-        //Attention(q, pastKey, pastValue, attentionMask, contextLayer, q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1);
-
+        Attention(q, pastKey, pastValue, attentionMask, contextLayer, q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1);
 
+        /*
         // 1.2 Attention
         // 1.2.0 q * k^T
         q.Reshape({pastKey.dims[0], -1, q.dims[2]});
         MatMulTransB(q, pastKey, attnProbs, 1.0 / (scale_attn * (i + 1)));
         attnProbs.Reshape(outputSize);
-
         // 1.2.1 Mask
         if (attentionMask.dims.size() != 0) {
             AttentionMask(attnProbs, attentionMask, -10000);
@@ -255,6 +254,8 @@ namespace fastllm {
         attnProbs.Reshape({pastValue.dims[0], -1, attnProbs.dims[2]});
         MatMul(attnProbs, pastValue, contextLayer);
+        */
+
         contextLayer.Reshape({batch, num_attention_heads, maxLen, -1});
         PermuteSelf(contextLayer, {2, 0, 1, 3});
         contextLayer.Reshape({contextLayer.dims[0], contextLayer.dims[1], embed_dim});
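
The model-side change: the fused Attention call replaces the step-by-step pipeline, which is wrapped in /* ... */ rather than deleted. Its arguments line up with the CPU backend above: group = q.dims[0] / pastKey.dims[0] (query heads per KV head, for ChatGLM2's multi-query attention) and scale = 1.0 / scale_attn. As a reading aid, a naive single-head reference for what the fused operator computes, out = softmax(scale * q * k^T, masked) * v; this is a sketch of the math under assumed mask semantics, not fastllm's implementation:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Shapes follow the CPU backend above: q is [q1 x q2], k is [k1 x q2],
    // v is [k1 x v2], mask is [q1 x k1] with nonzero meaning "blocked"
    // (an assumption, consistent with the -10000 fill used elsewhere).
    void attentionReference(const float *q, const float *k, const float *v,
                            const float *mask, float *out, float scale,
                            int q1, int q2, int k1, int v2) {
        std::vector<float> row(k1);
        for (int i = 0; i < q1; i++) {
            float maxv = -1e30f, sum = 0.0f;
            // Scores for query row i against every key, with masking.
            for (int j = 0; j < k1; j++) {
                float s = 0.0f;
                for (int d = 0; d < q2; d++) s += q[i * q2 + d] * k[j * q2 + d];
                s *= scale;
                if (mask && mask[i * k1 + j] != 0.0f) s = -10000.0f; // masked position
                row[j] = s;
                maxv = std::max(maxv, s);
            }
            // Numerically stable softmax over the row.
            for (int j = 0; j < k1; j++) { row[j] = std::exp(row[j] - maxv); sum += row[j]; }
            // Weighted sum of value rows.
            for (int d = 0; d < v2; d++) {
                float acc = 0.0f;
                for (int j = 0; j < k1; j++) acc += (row[j] / sum) * v[j * v2 + d];
                out[i * v2 + d] = acc;
            }
        }
    }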
