Pass a per-sequence 2-D attention mask to Attention() in the XLM-RoBERTa model.

Hunk 1: attention_mask grows from batch*len to batch*len*len entries, still
initialised to 1. For every token row j of sequence i, the first
tokens[i].size() columns of that row are set to 0, and the tensor is copied
out with shape {batch, len, len} instead of {batch, len}.

Hunk 2: the branch that was disabled with `if (false)` is switched on, and
attentionMask is passed to Attention() in place of an empty Data(). The
removed TODO (written in Chinese) said that the mask used here was not a
causal mask, so Attention() could not be called directly, and that the mask
had to be reworked before it could be.

Review note: the row fill uses iterator arithmetic
(attention_mask.begin() + offset) rather than &attention_mask[offset]. For the
last sequence in the batch, when tokens[i].size() == len, the end offset of
its final row equals attention_mask.size(), and std::vector::operator[] must
not be called with that index; a one-past-the-end iterator is valid.

diff --git a/src/models/xlmroberta.cpp b/src/models/xlmroberta.cpp
index 77d2b71..5e37524 100644
--- a/src/models/xlmroberta.cpp
+++ b/src/models/xlmroberta.cpp
@@ -50,18 +50,18 @@ namespace fastllm {
         std::vector ids = std::vector (batch * len, 0.0f);
         std::vector seqLens = std::vector (batch, 0.0f);
         std::vector token_type_ids = std::vector (batch * len, 0.0f);
-        std::vector attention_mask = std::vector (batch * len, 1);
+        std::vector attention_mask = std::vector (batch * len * len, 1);
         std::vector position_ids = std::vector (batch * len, 0.0f);
         for (int i = 0; i < batch; i++) {
             seqLens[i] = tokens[i].size();
             for (int j = 0; j < tokens[i].size(); j++) {
                 ids[i * len + j] = tokens[i][j];
-                attention_mask[i * len + j] = 0;
                 position_ids[i * len + j] = 2 + j;
-            }
+                std::fill(attention_mask.begin() + (i * len * len + j * len), attention_mask.begin() + (i * len * len + j * len + tokens[i].size()), 0.0f);
+            }
         }
         inputIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids));
-        attentionMask.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask));
+        attentionMask.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len, len}, attention_mask));
         tokenTypeIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids));
         positionIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids));
     }
@@ -112,13 +112,11 @@ namespace fastllm {
             PermuteSelf(k, {0, 2, 1, 3});
             PermuteSelf(v, {0, 2, 1, 3});
 
-            if (false) {
-                // TODO: 这里使用的AttentionMask不是因果Mask,无法直接调用Attention函数
-                // 后续需要修改AttentionMask使得可以直接调用Attention函数
+            if (true) {
                 q.Reshape({-1, q.dims[2], q.dims[3]});
                 k.Reshape({-1, k.dims[2], k.dims[3]});
                 v.Reshape({-1, v.dims[2], v.dims[3]});
-                Attention(q, k, v, Data(), qkv, q.dims[0] / k.dims[0], 1.0 / sqrt(this->head_dim), 1);
+                Attention(q, k, v, attentionMask, qkv, q.dims[0] / k.dims[0], 1.0 / sqrt(this->head_dim), 1);
                 PermuteSelf(qkv, {1, 0, 2});
                 qkv.Reshape({seqlen, bsz, -1});
                 PermuteSelf(qkv, {1, 0, 2});