Commit

fix xlmroberta
黄宇扬 committed Sep 26, 2024
1 parent 5912fd3 commit 3d61022
Showing 1 changed file with 6 additions and 8 deletions.
src/models/xlmroberta.cpp: 14 changes (6 additions & 8 deletions)
@@ -50,18 +50,18 @@ namespace fastllm {
         std::vector <float> ids = std::vector <float> (batch * len, 0.0f);
         std::vector <float> seqLens = std::vector <float> (batch, 0.0f);
         std::vector <float> token_type_ids = std::vector <float> (batch * len, 0.0f);
-        std::vector <float> attention_mask = std::vector <float> (batch * len, 1);
+        std::vector <float> attention_mask = std::vector <float> (batch * len * len, 1);
         std::vector <float> position_ids = std::vector <float> (batch * len, 0.0f);
         for (int i = 0; i < batch; i++) {
             seqLens[i] = tokens[i].size();
             for (int j = 0; j < tokens[i].size(); j++) {
                 ids[i * len + j] = tokens[i][j];
-                attention_mask[i * len + j] = 0;
                 position_ids[i * len + j] = 2 + j;
+                std::fill(&attention_mask[i * len * len + j * len], &attention_mask[i * len * len + j * len + tokens[i].size()], 0.0f);
             }
         }
         inputIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids));
-        attentionMask.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask));
+        attentionMask.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len, len}, attention_mask));
         tokenTypeIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids));
         positionIds.CopyFrom(fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids));
     }
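For context: the change above replaces the flat {batch, len} padding mask with one len x len matrix per batch element, where row j lists which key positions query j may attend to. Below is a standalone sketch (not part of the commit) of the same construction, under the assumption, suggested by the initialization to 1 and the fill with 0, that 1.0f means "masked" and 0.0f means "attend"; the toy token ids are made up for illustration.

    // mask_sketch.cpp -- illustrative only, not fastllm code.
    // Builds a {batch, len, len} mask the way the diff above does, then prints it.
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        // Toy token ids; values here are hypothetical.
        std::vector<std::vector<int>> tokens = {{101, 7592, 102}, {101, 102}};
        int batch = (int) tokens.size();
        int len = 3; // padded length = longest sequence in the batch
        std::vector<float> attention_mask(batch * len * len, 1.0f); // 1 = masked
        for (int i = 0; i < batch; i++) {
            for (int j = 0; j < (int) tokens[i].size(); j++) {
                // Row j of sample i: unmask the first tokens[i].size() key positions.
                std::fill(attention_mask.data() + i * len * len + j * len,
                          attention_mask.data() + i * len * len + j * len + tokens[i].size(),
                          0.0f);
            }
        }
        for (int i = 0; i < batch; i++) {
            printf("sample %d:\n", i);
            for (int j = 0; j < len; j++) {
                for (int k = 0; k < len; k++) {
                    printf("%.0f ", attention_mask[(i * len + j) * len + k]);
                }
                printf("\n");
            }
        }
        return 0;
    }

For the shorter second sequence this prints rows "0 0 1", "0 0 1", "1 1 1": padding columns stay masked, and the fully padded row stays all ones, which is harmless because its output is discarded.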
@@ -112,13 +112,11 @@ namespace fastllm {
         PermuteSelf(k, {0, 2, 1, 3});
         PermuteSelf(v, {0, 2, 1, 3});

-        if (false) {
-            // TODO: The AttentionMask used here is not a causal mask, so the Attention function cannot be called directly
-            // The AttentionMask needs to be modified later so that the Attention function can be called directly
+        if (true) {
             q.Reshape({-1, q.dims[2], q.dims[3]});
             k.Reshape({-1, k.dims[2], k.dims[3]});
             v.Reshape({-1, v.dims[2], v.dims[3]});
-            Attention(q, k, v, Data(), qkv, q.dims[0] / k.dims[0], 1.0 / sqrt(this->head_dim), 1);
+            Attention(q, k, v, attentionMask, qkv, q.dims[0] / k.dims[0], 1.0 / sqrt(this->head_dim), 1);
             PermuteSelf(qkv, {1, 0, 2});
             qkv.Reshape({seqlen, bsz, -1});
             PermuteSelf(qkv, {1, 0, 2});
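With the full {batch, len, len} mask built in the first hunk, the generic Attention kernel can now consume it directly, which is why the TODO and the disabled branch go away. As a reading aid, here is a minimal reference sketch of the masked attention the call above is assumed to compute per head; the real fastllm kernel lives elsewhere in the repo and is more general and optimized, so treat this only as a statement of the assumed mask semantics (1.0f = masked, 0.0f = attend).

    // attention_sketch.cpp -- illustrative only, not fastllm's kernel.
    // Assumed per-head semantics of the Attention(...) call above:
    // out = softmax(q . k^T / sqrt(head_dim)) . v, with mask==1 positions excluded.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    // q, k, v: [len][head_dim]; mask: [len][len] with 1.0f = masked, 0.0f = attend.
    std::vector<std::vector<float>> MaskedAttention(
            const std::vector<std::vector<float>> &q,
            const std::vector<std::vector<float>> &k,
            const std::vector<std::vector<float>> &v,
            const std::vector<std::vector<float>> &mask) {
        int len = (int) q.size(), dim = (int) q[0].size();
        float scale = 1.0f / std::sqrt((float) dim);
        std::vector<std::vector<float>> out(len, std::vector<float>(dim, 0.0f));
        for (int i = 0; i < len; i++) {
            std::vector<float> score(len);
            float maxv = -1e30f;
            for (int j = 0; j < len; j++) {
                float s = 0.0f;
                for (int d = 0; d < dim; d++) s += q[i][d] * k[j][d];
                // Masked positions get a huge negative score -> ~0 softmax weight.
                score[j] = (mask[i][j] > 0.5f) ? -1e30f : s * scale;
                maxv = std::max(maxv, score[j]);
            }
            float sum = 0.0f;
            for (int j = 0; j < len; j++) {
                score[j] = std::exp(score[j] - maxv); // stable softmax
                sum += score[j];
            }
            // Fully masked rows degrade to a uniform average; they correspond to
            // padding queries whose outputs are discarded anyway.
            for (int j = 0; j < len; j++)
                for (int d = 0; d < dim; d++) out[i][d] += (score[j] / sum) * v[j][d];
        }
        return out;
    }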
