From f6c25e0996def00152a14d4767f0f6854669e05d Mon Sep 17 00:00:00 2001 From: Sheng Huang Date: Tue, 15 Oct 2024 14:55:34 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=B8=8D=E5=90=8Cdevice?= =?UTF-8?q?=E4=B9=8B=E9=97=B4=E4=BD=BF=E7=94=A8CopyFrom=E4=BC=9A=E5=87=BA?= =?UTF-8?q?=E9=94=99=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/fastllm.cpp | 1 + src/models/minicpm3.cpp | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/fastllm.cpp b/src/fastllm.cpp index 7e6cb74..59d21da 100644 --- a/src/fastllm.cpp +++ b/src/fastllm.cpp @@ -304,6 +304,7 @@ namespace fastllm { } void Data::CopyFrom(const Data &ori) { + this->ToDevice(ori.dataDevice); this->name = ori.name; this->isKVCache = ori.isKVCache; this->cacheUid = ori.cacheUid; diff --git a/src/models/minicpm3.cpp b/src/models/minicpm3.cpp index d3e0940..ef92c43 100644 --- a/src/models/minicpm3.cpp +++ b/src/models/minicpm3.cpp @@ -151,7 +151,6 @@ namespace fastllm { Cat(q_nope, q_rope, -1, query_states); k_rope.Reshape({bsz, seqlen * qk_rope_head_dim}); - k_rope_expand.ToDevice(DataDevice::CUDA); k_rope_expand.CopyFrom(k_rope); k_rope_expand.Expansion({bsz, num_attention_heads * seqlen * qk_rope_head_dim}); for (int i = 1; i < num_attention_heads; i++) @@ -353,7 +352,6 @@ namespace fastllm { Cat(q_nope, q_rope, -1, query_states); k_rope.Reshape({bsz, seqlen * qk_rope_head_dim}); - k_rope_expand.ToDevice(DataDevice::CUDA); k_rope_expand.CopyFrom(k_rope); k_rope_expand.Expansion({bsz, num_attention_heads * seqlen * qk_rope_head_dim}); for (int i = 1; i < num_attention_heads; i++)