Skip to content

Commit

Permalink
fix copyfrom for cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed Jul 13, 2024
1 parent 1952607 commit 112ce28
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/fastllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,15 +306,24 @@ namespace fastllm {
this->name = ori.name;
this->isKVCache = ori.isKVCache;
this->cacheUid = ori.cacheUid;
this->dataDevice = ori.dataDevice;

// std::cout<<"调用拷贝构造"<<std::endl;
if (ori.expansionDims != this->expansionDims || ori.dims != this->dims || this->cpuData == nullptr || ori.dataType != this->dataType) {
if (ori.dims.size() == 0) {
delete[] this->cpuData;
this->dataType = ori.dataType;
this->UpdateUnitSize();
this->dims.resize(0);
this->cpuData = nullptr;

if (this->dataDevice == DataDevice::CPU) {
delete[] this->cpuData;
this->cpuData = nullptr;
} else if (this->dataDevice == DataDevice::CUDA) {
#ifdef USE_CUDA
FastllmCudaFree(this->cudaData);
this->cudaData = nullptr;
#endif
}
return;
}
this->dataType = ori.dataType;
Expand All @@ -327,7 +336,14 @@ namespace fastllm {
this->Allocate();
}
}
std::memcpy(this->cpuData, ori.cpuData, this->GetBytes());

if (this->dataDevice == DataDevice::CPU) {
std::memcpy(this->cpuData, ori.cpuData, this->GetBytes());
} else if (this->dataDevice == DataDevice::CUDA) {
#ifdef USE_CUDA
FastllmCudaCopyFromDeviceToDevice(this->cudaData, ori.cudaData, this->GetBytes());
#endif
}
}

struct BF16ToFP16Manager {
Expand Down

0 comments on commit 112ce28

Please sign in to comment.