From 55902fd2d2330c9893dada3bfbcfbcefdb06d5c1 Mon Sep 17 00:00:00 2001
From: 黄宇扬
Date: Wed, 10 Jul 2024 11:14:58 +0800
Subject: [PATCH] Optimize chatglm
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/models/chatglm.cpp | 109 ++++++++---------------------------------
 1 file changed, 20 insertions(+), 89 deletions(-)

diff --git a/src/models/chatglm.cpp b/src/models/chatglm.cpp
index a6e0e8b1..dccffa56 100644
--- a/src/models/chatglm.cpp
+++ b/src/models/chatglm.cpp
@@ -580,110 +580,41 @@ namespace fastllm {
             CatDirectBatch(keys, pointersK, 1);
             CatDirectBatch(values, pointersV, 1);
             if (all1 && batch > 1) {
+                contextLayer.ToDevice(q.dataDevice);
+                contextLayer.Resize({batch, 1, embed_dim});
+                contextLayer.Allocate();
                 for (int b = 0; b < batch; b++) {
                     qs[b] = (&curQs[b]);
                     keys[b] = (pastKeyValues[b * block_cnt + i].first);
                     values[b] = (pastKeyValues[b * block_cnt + i].second);
                     masks[b] = attentionMask[b];
+                    curContextLayer[b].FakeFrom(contextLayer, b * embed_dim * contextLayer.unitSize);
                     contexts[b] = (&curContextLayer[b]);
                     outputSizes[b] = {1, qs[b]->dims[0], qs[b]->dims[1], keys[b]->dims[1]};
                 }
                 AttentionBatch(qs, keys, values, masks, contexts, qs[0]->dims[0] / values[0]->dims[0], 1.0 / scale_attn, 1);
             } else {
+                contextLayer.ToDevice(curQs[0].dataDevice);
+                contextLayer.Resize({total, 1, embed_dim});
+                contextLayer.Allocate();
+                int curLen = 0;
                 for (int b = 0; b < batch; b++) {
-                    auto &q = curQs[b];
-                    Data &pastKey = *pastKeyValues[b * block_cnt + i].first;
-                    outputSizes[b] = {1, q.dims[0], q.dims[1], pastKey.dims[1]};
-                    q.Reshape({pastKey.dims[0], -1, q.dims[2]});
-                }
-
-                // 1.2 Attention
-                // 1.2.0 q * k^T
-                if (all1 && batch > 1) {
-                    for (int b = 0; b < batch; b++) {
-                        qs[b] = (&curQs[b]);
-                        keys[b] = (pastKeyValues[b * block_cnt + i].first);
-                        attns[b] = (&attnProbs[b]);
-                    }
-                    MatMulTransBBatch(qs, keys, attns, 1.0 / (scale_attn * (i + 1)));
-                } else {
-                    for (int b = 0; b < batch; b++) {
-                        auto &q = curQs[b];
-                        Data &pastKey = *pastKeyValues[b * block_cnt + i].first;
-                        MatMulTransB(q, pastKey, attnProbs[b], 1.0 / (scale_attn * (i + 1)));
-                    }
-                }
-
-                for (int b = 0; b < batch; b++) {
-                    attnProbs[b].Reshape(outputSizes[b]);
-                    // 1.2.1 Mask
-                    if (attentionMask[b] != nullptr) {
-                        AttentionMask(attnProbs[b], *attentionMask[b], -10000);
-                    }
-                }
-
-                // 1.2.2 softmax
-                for (int i = 0; i < attnProbs.size(); i++) {
-                    attns[i] = (&attnProbs[i]);
-                }
-                MulBatch(attns, i + 1, attns);
-                SoftmaxBatch(attns, attns, -1);
-
-                for (int b = 0; b < batch; b++) {
-                    Data &pastValue = *pastKeyValues[b * block_cnt + i].second;
-                    outputSizes[b] = {1, num_attention_heads, -1, pastValue.dims[2]};
-                    attnProbs[b].Reshape({pastValue.dims[0], -1, attnProbs[b].dims[3]});
-                }
-
-                // 1.2.3 prob * v
-                if (all1 && batch > 1) {
-                    for (int b = 0; b < batch; b++) {
-                        attns[b] = (&attnProbs[b]);
-                        values[b] = (pastKeyValues[b * block_cnt + i].second);
-                        contexts[b] = (&curContextLayer[b]);
-                    }
-                    MatMulBatch(attns, values, contexts);
-                } else {
-                    for (int b = 0; b < batch; b++) {
-                        Data &pastValue = *pastKeyValues[b * block_cnt + i].second;
-                        MatMul(attnProbs[b], pastValue, curContextLayer[b]);
-                    }
-                }
-            }
-            if (all1) {
-                for (int b = 0; b < batch; b++) {
-                    curContextLayer[b].dims[0] = outputSizes[b][2];
-                    curContextLayer[b].dims[1] = outputSizes[b][0];
-                    curContextLayer[b].dims[2] = embed_dim;
-                    curContextLayer[b].strides[0] = curContextLayer[b].dims[1] * curContextLayer[b].dims[2];
-                    curContextLayer[b].strides[1] = curContextLayer[b].dims[2];
-                    curContextLayer[b].strides[2] = 1;
-                }
-            } else {
-                for (int b = 0; b < batch; b++) {
-                    curContextLayer[b].Reshape(outputSizes[b]);
-                    PermuteSelf(curContextLayer[b], {2, 0, 1, 3});
-                    curContextLayer[b].Reshape({curContextLayer[b].dims[0], curContextLayer[b].dims[1], embed_dim});
-                }
-            }
-
-            if (all1 && batch > 1) {
-                for (int b = 0; b < batch; b++) {
-                    contexts[b] = (&curContextLayer[b]);
-                }
-                CatBatch(contexts, 0, contextLayer);
-            } else {
-                for (int b = 0; b < batch; b++) {
-                    if (contextLayer.dims.size() == 0) {
-                        std::vector <int> dims = curContextLayer[b].dims;
-                        dims[0] = total;
-                        contextLayer.Expansion(dims);
+                    auto &q = curQs[b], &k = curKs[b], &v = curVs[b];
+                    Data &pastKey = *pastKeyValues[b * block_cnt + i].first, &pastValue = *pastKeyValues[b * block_cnt + i].second;
+                    curContextLayer[0].FakeFrom(contextLayer, curLen * embed_dim * contextLayer.unitSize);
+                    curLen += seqLens[b];
+
+                    // 1.2 Attention
+                    if (attentionMask[b] == nullptr) {
+                        Attention(q, pastKey, pastValue, Data(), curContextLayer[0], q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1);
+                    } else {
+                        Attention(q, pastKey, pastValue, *attentionMask[b], curContextLayer[0], q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1);
                     }
-                    contextLayer.ToDevice(DataDevice::CUDA);
-                    CatDirect(contextLayer, curContextLayer[b], 0);
+                    PermuteSelf(curContextLayer[0], {1, 0, 2});
                 }
             }
+
             // 1.2.4 dense
             std::string denseWeightName = weightPre + std::to_string(i) + weightMiddle + ".dense.weight";
             std::string denseBiasName = weightPre + std::to_string(i) + weightMiddle + ".dense.bias";
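
Note: the heart of this change is allocation strategy. contextLayer is now sized
once up front ({batch, 1, embed_dim} in the all-decode path, {total, 1, embed_dim}
otherwise), and each request's attention output becomes a FakeFrom view aliasing a
slice of it at byte offset curLen * embed_dim * contextLayer.unitSize. Attention /
AttentionBatch then write their results in place, so the old CatBatch / CatDirect
concatenation pass and the per-request attnProbs temporaries disappear. A minimal
standalone sketch of that aliasing idea follows; View, fakeFrom and attendInto are
hypothetical stand-ins for fastllm's Data::FakeFrom and Attention, only the
offset scheme mirrors the patch.

// Sketch: one preallocated output buffer, per-request views written in place.
#include <cstdio>
#include <vector>

struct View {
    float *data = nullptr;   // aliases someone else's storage, owns nothing
    int rows = 0, cols = 0;
};

// Rough equivalent of curContextLayer[0].FakeFrom(contextLayer, byteOffset):
// point a view at contextLayer's storage instead of allocating a new tensor.
static View fakeFrom(std::vector<float> &buffer, size_t elemOffset, int rows, int cols) {
    return View{buffer.data() + elemOffset, rows, cols};
}

// Stand-in for the fused attention call: writes one request's result
// directly through the view, with no temporary output tensor.
static void attendInto(const View &out, float fill) {
    for (int i = 0; i < out.rows * out.cols; i++) out.data[i] = fill;
}

int main() {
    const int embedDim = 4;
    std::vector<int> seqLens = {2, 1, 3};          // per-request lengths
    int total = 0;
    for (int len : seqLens) total += len;

    // contextLayer.Resize({total, 1, embed_dim}); contextLayer.Allocate();
    std::vector<float> contextLayer(static_cast<size_t>(total) * embedDim);

    int curLen = 0;
    for (size_t b = 0; b < seqLens.size(); b++) {
        // The offset advances by seqLens[b] rows per request, as in the patch.
        View out = fakeFrom(contextLayer, static_cast<size_t>(curLen) * embedDim,
                            seqLens[b], embedDim);
        curLen += seqLens[b];
        attendInto(out, static_cast<float>(b + 1));  // output lands in place
    }

    // contextLayer now holds all requests' outputs contiguously: no concat step.
    for (int r = 0; r < total; r++) {
        for (int c = 0; c < embedDim; c++) printf("%.0f ", contextLayer[r * embedDim + c]);
        printf("\n");
    }
    return 0;
}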
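For reference, each fused Attention(q, pastKey, pastValue, mask, output, group,
1.0 / scale_attn, 1) call computes what the deleted steps 1.2.0 through 1.2.3
spelled out: scaled q * k^T, mask, softmax, then the weighted sum against the
values. (The deleted path also divided the scores by (i + 1) and multiplied that
back with MulBatch before the softmax, presumably to keep fp16 scores in range;
the fused kernel just takes the single 1.0 / scale_attn factor.) Below is a naive
single-head reference of that computation, assuming the mask convention implied
by the removed AttentionMask call (1 = blocked, replaced by a -10000 bias);
attentionRef is an illustrative helper, not the fastllm operator, which is
batched, multi-head and device-aware.

#include <algorithm>
#include <cmath>
#include <vector>

// q: [m x d], k: [n x d], v: [n x d], mask: optional [m x n] (1 = blocked).
// Returns the [m x d] attention output.
std::vector<float> attentionRef(const std::vector<float> &q,
                                const std::vector<float> &k,
                                const std::vector<float> &v,
                                const float *mask,
                                int m, int n, int d, float invScale) {
    std::vector<float> out(static_cast<size_t>(m) * d, 0.0f);
    std::vector<float> scores(n);
    for (int i = 0; i < m; i++) {
        // 1.2.0 q * k^T, scaled by invScale (the patch passes 1.0 / scale_attn)
        for (int j = 0; j < n; j++) {
            float s = 0.0f;
            for (int t = 0; t < d; t++) s += q[i * d + t] * k[j * d + t];
            scores[j] = s * invScale;
            // 1.2.1 mask: blocked positions get a large negative bias
            if (mask && mask[i * n + j] > 0.5f) scores[j] = -10000.0f;
        }
        // 1.2.2 softmax over the key axis (max-subtracted for stability)
        float mx = scores[0];
        for (int j = 1; j < n; j++) mx = std::max(mx, scores[j]);
        float sum = 0.0f;
        for (int j = 0; j < n; j++) { scores[j] = std::exp(scores[j] - mx); sum += scores[j]; }
        // 1.2.3 prob * v
        for (int j = 0; j < n; j++) {
            float p = scores[j] / sum;
            for (int t = 0; t < d; t++) out[i * d + t] += p * v[j * d + t];
        }
    }
    return out;
}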