Skip to content

Commit

Permalink
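修复不同 device 之间使用 CopyFrom 会出错的 bug (Fix a bug where CopyFrom fails when the source and destination Data are on different devices)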
Browse files: browse the repository at this point in the history
  • Loading branch information
huangsheng-tf committed Oct 15, 2024
1 parent 79e40f5 commit f6c25e0
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 2 deletions.
1 change: 1 addition & 0 deletions src/fastllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ namespace fastllm {
}

void Data::CopyFrom(const Data &ori) {
this->ToDevice(ori.dataDevice);
this->name = ori.name;
this->isKVCache = ori.isKVCache;
this->cacheUid = ori.cacheUid;
Expand Down
2 changes: 0 additions & 2 deletions src/models/minicpm3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ namespace fastllm {
Cat(q_nope, q_rope, -1, query_states);

k_rope.Reshape({bsz, seqlen * qk_rope_head_dim});
k_rope_expand.ToDevice(DataDevice::CUDA);
k_rope_expand.CopyFrom(k_rope);
k_rope_expand.Expansion({bsz, num_attention_heads * seqlen * qk_rope_head_dim});
for (int i = 1; i < num_attention_heads; i++)
Expand Down Expand Up @@ -353,7 +352,6 @@ namespace fastllm {
Cat(q_nope, q_rope, -1, query_states);

k_rope.Reshape({bsz, seqlen * qk_rope_head_dim});
k_rope_expand.ToDevice(DataDevice::CUDA);
k_rope_expand.CopyFrom(k_rope);
k_rope_expand.Expansion({bsz, num_attention_heads * seqlen * qk_rope_head_dim});
for (int i = 1; i < num_attention_heads; i++)
Expand Down

0 comments on commit f6c25e0

Please sign in to comment.