diff --git a/src/fastllm.cpp b/src/fastllm.cpp index 7e6cb74..59d21da 100644 --- a/src/fastllm.cpp +++ b/src/fastllm.cpp @@ -304,6 +304,7 @@ namespace fastllm { } void Data::CopyFrom(const Data &ori) { + this->ToDevice(ori.dataDevice); this->name = ori.name; this->isKVCache = ori.isKVCache; this->cacheUid = ori.cacheUid; diff --git a/src/models/minicpm3.cpp b/src/models/minicpm3.cpp index d3e0940..ef92c43 100644 --- a/src/models/minicpm3.cpp +++ b/src/models/minicpm3.cpp @@ -151,7 +151,6 @@ namespace fastllm { Cat(q_nope, q_rope, -1, query_states); k_rope.Reshape({bsz, seqlen * qk_rope_head_dim}); - k_rope_expand.ToDevice(DataDevice::CUDA); k_rope_expand.CopyFrom(k_rope); k_rope_expand.Expansion({bsz, num_attention_heads * seqlen * qk_rope_head_dim}); for (int i = 1; i < num_attention_heads; i++) @@ -353,7 +352,6 @@ namespace fastllm { Cat(q_nope, q_rope, -1, query_states); k_rope.Reshape({bsz, seqlen * qk_rope_head_dim}); - k_rope_expand.ToDevice(DataDevice::CUDA); k_rope_expand.CopyFrom(k_rope); k_rope_expand.Expansion({bsz, num_attention_heads * seqlen * qk_rope_head_dim}); for (int i = 1; i < num_attention_heads; i++)