diff --git a/include/models/xlmroberta.h b/include/models/xlmroberta.h
index c69cf74..a98846f 100644
--- a/include/models/xlmroberta.h
+++ b/include/models/xlmroberta.h
@@ -28,8 +28,6 @@ namespace fastllm {
                 const Data &positionIds,
                 bool normalize);
 
-        void WarmUp(); // warm-up
-
         std::string model_type;
 
         float layer_norm_eps = 1e-12;
diff --git a/src/models/bert.cpp b/src/models/bert.cpp
index 66fa054..2d5a4ef 100644
--- a/src/models/bert.cpp
+++ b/src/models/bert.cpp
@@ -95,7 +95,10 @@ namespace fastllm {
             PermuteSelf(k, {0, 2, 1, 3});
             PermuteSelf(v, {0, 2, 1, 3});
             MatMulTransB(q, k, qk, 1.0 / sqrt(this->head_dim), 1);
-            AttentionExtendedMask(qk, attentionMask);
+            std::vector <int> dims = qk.dims;
+            qk.Resize({dims[0], -1, dims[3]});
+            AttentionMask(qk, attentionMask, -1e9);
+            qk.Resize(dims);
 
             Softmax(qk, qk, -1);
             MatMul(qk, v, qkv, 1.0, 1);
@@ -150,7 +153,7 @@ namespace fastllm {
         std::vector <float> ids = std::vector <float> (batch * len, 0.0f);
         std::vector <float> seqLens = std::vector <float> (batch, 0.0f);
         std::vector <float> token_type_ids = std::vector <float> (batch * len, 0.0f);
-        std::vector <float> attention_mask = std::vector <float> (batch * len, -1e10f);
+        std::vector <float> attention_mask = std::vector <float> (batch * len, 1);
         std::vector <float> position_ids = std::vector <float> (batch * len, 0.0f);
         for (int i = 0; i < batch; i++) {
             seqLens[i] = tokens[i].size();
diff --git a/src/models/xlmroberta.cpp b/src/models/xlmroberta.cpp
index 0c03bc7..09cec7b 100644
--- a/src/models/xlmroberta.cpp
+++ b/src/models/xlmroberta.cpp
@@ -50,7 +50,7 @@ namespace fastllm {
         std::vector <float> ids = std::vector <float> (batch * len, 0.0f);
         std::vector <float> seqLens = std::vector <float> (batch, 0.0f);
         std::vector <float> token_type_ids = std::vector <float> (batch * len, 0.0f);
-        std::vector <float> attention_mask = std::vector <float> (batch * len, -1e10f);
+        std::vector <float> attention_mask = std::vector <float> (batch * len, 1);
         std::vector <float> position_ids = std::vector <float> (batch * len, 0.0f);
         for (int i = 0; i < batch; i++) {
             seqLens[i] = tokens[i].size();
@@ -66,19 +66,6 @@ namespace fastllm {
         positionIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids));
     }
 
-    void Normalize__(float *data, int dataLen)
-    {
-        float sum = 0.0;
-        for(int i = 0; i < dataLen; i++)
-            sum += data[i] * data[i];
-
-        if (sum < 1e-6) sum = 1e-6;
-        else sum = sqrt(sum);
-
-        for(int i = 0; i < dataLen; i++)
-            data[i] = data[i] / sum;
-    }
-
     std::vector <std::vector <float> > XlmRobertaModel::ForwardAll(
             const Data &inputIds,
             const Data &attentionMask,
@@ -93,6 +80,7 @@ namespace fastllm {
         AddTo(inputEmbeddings, positionIdEmbeddings);
         Data hiddenStates, firstStates;
         LayerNorm(inputEmbeddings, this->weight["roberta.embeddings.LayerNorm.weight"], this->weight["roberta.embeddings.LayerNorm.bias"], -1, hiddenStates);
+        int bsz = hiddenStates.dims[0], seqlen = hiddenStates.dims[1];
         Data q, k, v, qk, qkv, attnOutput, inter, pooler, logits;
         for (int i = 0; i < this->block_cnt; i++) {
             std::string queryWeightName = "roberta.encoder.layer." + std::to_string(i) + ".attention.self.query.weight";
@@ -123,15 +111,27 @@ namespace fastllm {
             PermuteSelf(q, {0, 2, 1, 3});
             PermuteSelf(k, {0, 2, 1, 3});
             PermuteSelf(v, {0, 2, 1, 3});
-            MatMulTransB(q, k, qk, 1.0 / sqrt(this->head_dim), 1);
-            AttentionExtendedMask(qk, attentionMask);
-
-            Softmax(qk, qk, -1);
-            MatMul(qk, v, qkv, 1.0, 1);
-
-            PermuteSelf(qkv, {0, 2, 1, 3});
-            qkv.Reshape({qkv.dims[0], qkv.dims[1], -1});
+            if (bsz == 1) {
+                q.Reshape({-1, q.dims[2], q.dims[3]});
+                k.Reshape({-1, k.dims[2], k.dims[3]});
+                v.Reshape({-1, v.dims[2], v.dims[3]});
+                Attention(q, k, v, Data(), qkv, q.dims[0] / k.dims[0], 1.0 / sqrt(this->head_dim), 1);
+                PermuteSelf(qkv, {1, 0, 2});
+                qkv.Reshape({seqlen, bsz, -1});
+                PermuteSelf(qkv, {1, 0, 2});
+            } else {
+                MatMulTransB(q, k, qk, 1.0 / sqrt(this->head_dim), 1);
+                std::vector <int> dims = qk.dims;
+                qk.Reshape({dims[0], -1, dims[3]});
+                AttentionMask(qk, attentionMask, -1e9);
+                qk.Reshape(dims);
+                Softmax(qk, qk, -1);
+                MatMul(qk, v, qkv, 1.0, 1);
+                PermuteSelf(qkv, {0, 2, 1, 3});
+                qkv.Reshape({qkv.dims[0], qkv.dims[1], -1});
+            }
+
             Linear(qkv, this->weight[attnOutputWeightName], this->weight[attnOutputbiasName], attnOutput);
             AddTo(hiddenStates, attnOutput);
             LayerNorm(hiddenStates, this->weight[attnLNWeightName], this->weight[attnLNbiasName], -1, hiddenStates);
 
@@ -171,10 +171,4 @@ namespace fastllm {
         }
         return ret;
     }
-
-    void XlmRobertaModel::WarmUp() {
-        // printf("Warmup...\n");
-        // EmbeddingSentence({"1"});
-        // printf("finish.\n");
-    }
 }
\ No newline at end of file
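
Note on the mask convention this diff moves to: the old code built an additive mask (0 for real tokens, -1e10 for padding) and applied it with AttentionExtendedMask, while the new code fills attention_mask with 0/1 values (1 appears to mark padded positions) and lets AttentionMask write a -1e9 fill into the scores before Softmax. The standalone sketch below is not fastllm code and uses hypothetical helper names; it only illustrates that assumed 0/1 semantics on a single query row.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical helper: mask[j] == 1 marks a padded key position, which is
// overwritten with a large negative value so softmax gives it ~zero weight.
static void ApplyBinaryMask(std::vector<float> &scores,
                            const std::vector<float> &mask,
                            float maskValue) {
    for (size_t j = 0; j < scores.size(); j++) {
        if (mask[j] > 0.99f) {         // 1 -> padding, 0 -> real token
            scores[j] = maskValue;     // e.g. -1e9f, as in the diff above
        }
    }
}

static void SoftmaxInPlace(std::vector<float> &scores) {
    float maxVal = scores[0];
    for (float s : scores) maxVal = std::max(maxVal, s);
    float sum = 0.0f;
    for (float &s : scores) { s = std::exp(s - maxVal); sum += s; }
    for (float &s : scores) s /= sum;
}

int main() {
    // One query row over four key positions; the last key is padding.
    std::vector<float> scores = {0.3f, 1.2f, -0.5f, 0.8f};
    std::vector<float> mask   = {0.0f, 0.0f, 0.0f, 1.0f};

    ApplyBinaryMask(scores, mask, -1e9f);
    SoftmaxInPlace(scores);

    for (float p : scores) printf("%.4f ", p);  // last weight is ~0
    printf("\n");
    return 0;
}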