diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index eab1d29e..289b7d98 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -306,15 +306,24 @@ namespace fastllm {
         this->name = ori.name;
         this->isKVCache = ori.isKVCache;
         this->cacheUid = ori.cacheUid;
+        this->dataDevice = ori.dataDevice;
 
         // std::cout<<"调用拷贝构造"<<std::endl;
         if (ori.expansionDims != this->expansionDims || ori.dims != this->dims || this->cpuData == nullptr || ori.dataType != this->dataType) {
             if (ori.dims.size() == 0) {
-                delete[] this->cpuData;
                 this->dataType = ori.dataType;
                 this->UpdateUnitSize();
                 this->dims.resize(0);
-                this->cpuData = nullptr;
+
+                if (this->dataDevice == DataDevice::CPU) {
+                    delete[] this->cpuData;
+                    this->cpuData = nullptr;
+                } else if (this->dataDevice == DataDevice::CUDA) {
+#ifdef USE_CUDA
+                    FastllmCudaFree(this->cudaData);
+                    this->cudaData = nullptr;
+#endif
+                }
                 return;
             }
             this->dataType = ori.dataType;
@@ -327,7 +336,14 @@ namespace fastllm {
                 this->Allocate();
             }
         }
-        std::memcpy(this->cpuData, ori.cpuData, this->GetBytes());
+
+        if (this->dataDevice == DataDevice::CPU) {
+            std::memcpy(this->cpuData, ori.cpuData, this->GetBytes());
+        } else if (this->dataDevice == DataDevice::CUDA) {
+#ifdef USE_CUDA
+            FastllmCudaCopyFromDeviceToDevice(this->cudaData, ori.cudaData, this->GetBytes());
+#endif
+        }
     }
 
     struct BF16ToFP16Manager {
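
For reference: assuming the function being patched is fastllm's Data::CopyFrom (it copies from ori into *this), the change makes the copy branch on this->dataDevice, which is now taken from ori.dataDevice, instead of always freeing and writing cpuData. Copying between two CUDA-resident Data objects therefore stays on the GPU via FastllmCudaCopyFromDeviceToDevice. Below is a minimal usage sketch; the Data constructor and ToDevice calls are assumptions about the public fastllm API, not something taken from this patch.

// Illustrative sketch only; constructor and ToDevice signatures are assumed.
#include "fastllm.h"

int main() {
    // Build a small FLOAT32 tensor on the host and an empty destination.
    fastllm::Data src(fastllm::DataType::FLOAT32, {2, 3},
                      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
    fastllm::Data dst(fastllm::DataType::FLOAT32);

#ifdef USE_CUDA
    // Move both tensors to the GPU so their dataDevice is DataDevice::CUDA.
    src.ToDevice(fastllm::DataDevice::CUDA);
    dst.ToDevice(fastllm::DataDevice::CUDA);
#endif

    // With this patch the copy path follows the device:
    //   CPU  -> std::memcpy into cpuData
    //   CUDA -> FastllmCudaCopyFromDeviceToDevice into cudaData
    dst.CopyFrom(src);
    return 0;
}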