Skip to content

Commit

Permalink
Tweaks to track device in syncedmem
Browse files (browse the repository at this point in the history)
  • Loading branch information
mhouston authored and cypof committed Jul 14, 2015
1 parent abdb736 commit e3f59d3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
5 changes: 3 additions & 2 deletions include/caffe/syncedmem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ class SyncedMemory {
public:
SyncedMemory()
: cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED),
own_cpu_data_(false), own_gpu_data_(false) {}
own_cpu_data_(false), own_gpu_data_(false), gpu_device_(-1) {}
explicit SyncedMemory(size_t size)
: cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED),
own_cpu_data_(false), own_gpu_data_(false) {}
own_cpu_data_(false), own_gpu_data_(false), gpu_device_(-1) {}
~SyncedMemory();
const void* cpu_data();
void set_cpu_data(void* data);
Expand All @@ -73,6 +73,7 @@ class SyncedMemory {
SyncedHead head_;
bool own_cpu_data_;
bool own_gpu_data_;
int gpu_device_;

DISABLE_COPY_AND_ASSIGN(SyncedMemory);
}; // class SyncedMemory
Expand Down
15 changes: 15 additions & 0 deletions src/caffe/syncedmem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ SyncedMemory::~SyncedMemory() {

#ifndef CPU_ONLY
if (gpu_ptr_ && own_gpu_data_) {
int initial_device;
cudaGetDevice(&initial_device);
if (gpu_device_ != -1) {
CUDA_CHECK(cudaSetDevice(gpu_device_));
}
CUDA_CHECK(cudaFree(gpu_ptr_));
cudaSetDevice(initial_device);
}
#endif // CPU_ONLY
}
Expand Down Expand Up @@ -48,13 +54,15 @@ inline void SyncedMemory::to_gpu() {
#ifndef CPU_ONLY
switch (head_) {
case UNINITIALIZED:
CUDA_CHECK(cudaGetDevice(&gpu_device_));
CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
caffe_gpu_memset(size_, 0, gpu_ptr_);
head_ = HEAD_AT_GPU;
own_gpu_data_ = true;
break;
case HEAD_AT_CPU:
if (gpu_ptr_ == NULL) {
CUDA_CHECK(cudaGetDevice(&gpu_device_));
CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
own_gpu_data_ = true;
}
Expand Down Expand Up @@ -98,7 +106,13 @@ void SyncedMemory::set_gpu_data(void* data) {
#ifndef CPU_ONLY
CHECK(data);
if (own_gpu_data_) {
int initial_device;
cudaGetDevice(&initial_device);
if (gpu_device_ != -1) {
CUDA_CHECK(cudaSetDevice(gpu_device_));
}
CUDA_CHECK(cudaFree(gpu_ptr_));
cudaSetDevice(initial_device);
}
gpu_ptr_ = data;
head_ = HEAD_AT_GPU;
Expand Down Expand Up @@ -128,6 +142,7 @@ void* SyncedMemory::mutable_gpu_data() {
void SyncedMemory::async_gpu_push(const cudaStream_t& stream) {
CHECK(head_ == HEAD_AT_CPU);
if (gpu_ptr_ == NULL) {
CUDA_CHECK(cudaGetDevice(&gpu_device_));
CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
own_gpu_data_ = true;
}
Expand Down

0 comments on commit e3f59d3

Please sign in to comment.