From 51d9686a697550d25d1e9c77492c31c8d2c55853 Mon Sep 17 00:00:00 2001
From: cypof
Date: Fri, 20 Feb 2015 14:57:28 -0800
Subject: [PATCH] Data queues, prefetching and multi-source

---
 Makefile                                |   4 +-
 include/caffe/common.hpp                |   4 +-
 include/caffe/data_layers.hpp           |  92 +++++++--
 include/caffe/internal_thread.hpp       |   9 +-
 include/caffe/syncedmem.hpp             |   4 +
 include/caffe/util/blocking_queue.hpp   |  50 +++++
 src/caffe/common.cpp                    |  10 +
 src/caffe/internal_thread.cpp           |   9 +-
 src/caffe/layers/base_data_layer.cpp    |  97 +++++++---
 src/caffe/layers/base_data_layer.cu     |  15 +-
 src/caffe/layers/data_layer.cpp         | 244 ++++++++++++++++++------
 src/caffe/layers/image_data_layer.cpp   |  26 +--
 src/caffe/layers/window_data_layer.cpp  |  19 +-
 src/caffe/proto/caffe.proto             |  13 +-
 src/caffe/syncedmem.cpp                 |  12 ++
 src/caffe/test/test_data_layer.cpp      | 175 +++++++++++++----
 src/caffe/test/test_internal_thread.cpp |   2 +-
 src/caffe/util/blocking_queue.cpp       |  87 +++++++++
 src/caffe/util/db.cpp                   |  12 +-
 src/caffe/util/upgrade_proto.cpp        |   2 +-
 20 files changed, 700 insertions(+), 186 deletions(-)
 create mode 100644 include/caffe/util/blocking_queue.hpp
 create mode 100644 src/caffe/util/blocking_queue.cpp

diff --git a/Makefile b/Makefile
index 2a75d66e02a..cb23d898f82 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
 PROJECT := caffe

-CONFIG_FILE := Makefile.config
+CONFIG_FILE ?= Makefile.config
 include $(CONFIG_FILE)

 BUILD_DIR_LINK := $(BUILD_DIR)
@@ -267,6 +267,8 @@ endif
 # Debugging
 ifeq ($(DEBUG), 1)
 	COMMON_FLAGS += -DDEBUG -g -O0
+	# Compile issue in DEBUG on MAC (https://svn.boost.org/trac/boost/ticket/9392)
+	COMMON_FLAGS += -DBOOST_NOINLINE='__attribute__ ((noinline))'
 	NVCCFLAGS += -G
 else
 	COMMON_FLAGS += -DNDEBUG -O2
diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp
index 890673cd7e6..a69c3bb0f84 100644
--- a/include/caffe/common.hpp
+++ b/include/caffe/common.hpp
@@ -141,7 +141,8 @@ class Caffe {
   // freed in a non-pinned way, which may cause problems - I haven't verified
   // it personally but better to note it here in the header file.
   inline static void set_mode(Brew mode) { Get().mode_ = mode; }
-  // Sets the random seed of both boost and curand
+  // Random seed of both boost and curand
+  static unsigned int get_random_seed();
   static void set_random_seed(const unsigned int seed);
   // Sets the device. Since we have cublas and curand stuff, set device also
   // requires us to reset those values.
@@ -155,6 +156,7 @@ class Caffe {
   curandGenerator_t curand_generator_;
 #endif
   shared_ptr<RNG> random_generator_;
+  unsigned int random_generator_seed_;

   Brew mode_;
   static shared_ptr<Caffe> singleton_;
diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp
index 1f154408c27..2ccb43849c8 100644
--- a/include/caffe/data_layers.hpp
+++ b/include/caffe/data_layers.hpp
@@ -1,11 +1,15 @@
 #ifndef CAFFE_DATA_LAYERS_HPP_
 #define CAFFE_DATA_LAYERS_HPP_

+#include <map>
 #include <string>
 #include <utility>
 #include <vector>

-#include "boost/scoped_ptr.hpp"
+#include "boost/random/mersenne_twister.hpp"
+#include "boost/random/uniform_real.hpp"
+#include "boost/random/variate_generator.hpp"
+#include "boost/weak_ptr.hpp"
 #include "hdf5.h"

 #include "caffe/blob.hpp"
@@ -16,10 +20,16 @@
 #include "caffe/layer.hpp"
 #include "caffe/net.hpp"
 #include "caffe/proto/caffe.pb.h"
+#include "caffe/util/blocking_queue.hpp"
 #include "caffe/util/db.hpp"

 namespace caffe {

+using boost::weak_ptr;
+using boost::mt19937;
+using boost::uniform_real;
+using boost::variate_generator;
+
 /**
  * @brief Provides base for data layers that feed blobs to the Net.
  *
@@ -52,12 +62,17 @@ class BaseDataLayer : public Layer<Dtype> {
   bool output_labels_;
 };

+template <typename Dtype>
+class Batch {
+ public:
+  Blob<Dtype> data_, label_;
+};
+
 template <typename Dtype>
 class BasePrefetchingDataLayer :
     public BaseDataLayer<Dtype>, public InternalThread {
  public:
-  explicit BasePrefetchingDataLayer(const LayerParameter& param)
-      : BaseDataLayer<Dtype>(param) {}
+  explicit BasePrefetchingDataLayer(const LayerParameter& param);
   virtual ~BasePrefetchingDataLayer() {}
   // LayerSetUp: implements common data layer setup functionality, and calls
   // DataLayerSetUp to do special data layer setup for individual layer types.
@@ -70,22 +85,63 @@ class BasePrefetchingDataLayer :
   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);

-  virtual void CreatePrefetchThread();
-  virtual void JoinPrefetchThread();
-  // The thread's function
-  virtual void InternalThreadEntry() {}
+  // Number of batches prefetched ahead (pushed to GPU memory
+  // asynchronously when running in GPU mode)
+  static const int PREFETCH_COUNT = 3;

  protected:
-  Blob<Dtype> prefetch_data_;
-  Blob<Dtype> prefetch_label_;
+  virtual void InternalThreadEntry();
+  virtual void load_batch(Batch<Dtype>* batch) = 0;
+
+  Batch<Dtype> prefetch_[PREFETCH_COUNT];
+  blocking_queue<Batch<Dtype>*> prefetch_free_;
+  blocking_queue<Batch<Dtype>*> prefetch_full_;
+  int device_;
+
   Blob<Dtype> transformed_data_;
 };

+// Prefetches datums to host memory that can be read by multiple data layers.
+class DataLoader {
+ public:
+  DataLoader(const DataParameter& param, int index);
+  ~DataLoader();
+
+  inline blocking_queue<Datum*>& free() {
+    return body_.get()->free_;
+  }
+  inline blocking_queue<Datum*>& full() {
+    return body_.get()->full_;
+  }
+
+ protected:
+  class Body: public InternalThread {
+   public:
+    Body(const DataParameter& param, int index);
+    ~Body();
+
+    void InternalThreadEntry();
+
+    shared_ptr<db::DB> db_;
+    shared_ptr<db::Cursor> cursor_;
+
+    blocking_queue<Datum*> free_;
+    blocking_queue<Datum*> full_;
+
+    DISABLE_COPY_AND_ASSIGN(Body);
+  };
+
+  static map<const string, weak_ptr<Body> > instances_;
+
+  const string source_;
+  shared_ptr<Body> body_;
+
+  DISABLE_COPY_AND_ASSIGN(DataLoader);
+};
+
 template <typename Dtype>
-class DataLayer : public BasePrefetchingDataLayer<Dtype> {
+class DataLayer: public BasePrefetchingDataLayer<Dtype> {
  public:
-  explicit DataLayer(const LayerParameter& param)
-      : BasePrefetchingDataLayer<Dtype>(param) {}
+  explicit DataLayer(const LayerParameter& param);
   virtual ~DataLayer();
   virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
@@ -96,10 +152,12 @@ class DataLayer : public BasePrefetchingDataLayer<Dtype> {
   virtual inline int MaxTopBlobs() const { return 2; }

  protected:
-  virtual void InternalThreadEntry();
+  virtual void load_batch(Batch<Dtype>* batch);
+  DataLoader* next_loader();

-  shared_ptr<db::DB> db_;
-  shared_ptr<db::Cursor> cursor_;
+  vector<shared_ptr<DataLoader> > loaders_;
+  mt19937 rand_engine_;
+  uniform_real<float> rand_;
 };

 /**
@@ -236,7 +294,7 @@ class ImageDataLayer : public BasePrefetchingDataLayer<Dtype> {
  protected:
   shared_ptr<Caffe::RNG> prefetch_rng_;
   virtual void ShuffleImages();
-  virtual void InternalThreadEntry();
+  virtual void load_batch(Batch<Dtype>* batch);

   vector<std::pair<std::string, int> > lines_;
   int lines_id_;
@@ -308,7 +366,7 @@ class WindowDataLayer : public BasePrefetchingDataLayer<Dtype> {
  protected:
   virtual unsigned int PrefetchRand();
-  virtual void InternalThreadEntry();
+  virtual void load_batch(Batch<Dtype>* batch);

   shared_ptr<Caffe::RNG> prefetch_rng_;
   vector<std::pair<std::string, vector<int> > > image_database_;
diff --git a/include/caffe/internal_thread.hpp b/include/caffe/internal_thread.hpp
index 815ca54605e..11c24ca98c7 100644
--- a/include/caffe/internal_thread.hpp
+++ b/include/caffe/internal_thread.hpp
@@ -18,23 +18,28 @@ namespace caffe {
  */
 class InternalThread {
  public:
-  InternalThread() : thread_() {}
+  InternalThread() : thread_(), must_stop_() {}
   virtual ~InternalThread();

   /** Returns true if the thread was successfully started. **/
   bool StartInternalThread();

   /** Will not return until the internal thread has exited. */
-  bool WaitForInternalThreadToExit();
+  bool StopInternalThread();

   bool is_started() const;

+  bool must_stop() {
+    return must_stop_;
+  }
+
  protected:
   /* Implement this method in your subclass
      with the code you want your thread to run. */
   virtual void InternalThreadEntry() {}

   shared_ptr<boost::thread> thread_;
+  bool must_stop_;
 };

 }  // namespace caffe
diff --git a/include/caffe/syncedmem.hpp b/include/caffe/syncedmem.hpp
index 2564e0716ef..a12709be954 100644
--- a/include/caffe/syncedmem.hpp
+++ b/include/caffe/syncedmem.hpp
@@ -56,6 +56,10 @@ class SyncedMemory {
   SyncedHead head() { return head_; }
   size_t size() { return size_; }

+#ifndef CPU_ONLY
+  void async_gpu_push(const cudaStream_t& stream);
+#endif
+
  private:
   void to_cpu();
   void to_gpu();
diff --git a/include/caffe/util/blocking_queue.hpp b/include/caffe/util/blocking_queue.hpp
new file mode 100644
index 00000000000..96e83a1f105
--- /dev/null
+++ b/include/caffe/util/blocking_queue.hpp
@@ -0,0 +1,50 @@
+#ifndef CAFFE_UTIL_BLOCKING_QUEUE_H_
+#define CAFFE_UTIL_BLOCKING_QUEUE_H_
+
+#include <queue>
+#include <string>
+
+#include "caffe/common.hpp"
+
+namespace caffe {
+
+template<typename T>
+class blocking_queue {
+ public:
+  explicit blocking_queue();
+  virtual ~blocking_queue();
+
+  void push(const T& t);
+
+  bool empty() const;
+
+  bool try_pop(T* t);
+
+  T pop(const string& log_on_wait = "");
+
+  // Return element without removing it
+  T peek();
+
+  inline uint64_t pops() {
+    return pops_;
+  }
+
+ protected:
+  /**
+   Move synchronization fields out instead of including boost/thread.hpp
+   to avoid boost/NVCC issues (#1009, #1010) on OSX. The same build also
+   fails on Linux with CUDA 7.0.18.
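+   The sync wrapper (a boost mutex plus a condition variable) is defined
+   in src/caffe/util/blocking_queue.cpp, so this header never needs to
+   pull in boost/thread.hpp itself.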
+ */ + class sync; + + std::queue queue_; + shared_ptr sync_; + time_t last_wait_log_; + uint64_t pops_; + +DISABLE_COPY_AND_ASSIGN(blocking_queue); +}; + +} // namespace caffe + +#endif diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index af96cac40aa..ff73344b225 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -46,9 +46,14 @@ Caffe::Caffe() Caffe::~Caffe() { } +unsigned int Caffe::get_random_seed() { + return Get().random_generator_seed_; +} + void Caffe::set_random_seed(const unsigned int seed) { // RNG seed Get().random_generator_.reset(new RNG(seed)); + Get().random_generator_seed_ = seed; } void Caffe::SetDevice(const int device_id) { @@ -108,6 +113,10 @@ Caffe::~Caffe() { } } +unsigned int Caffe::get_random_seed() { + return Get().random_generator_seed_; +} + void Caffe::set_random_seed(const unsigned int seed) { // Curand seed static bool g_curand_availability_logged = false; @@ -124,6 +133,7 @@ void Caffe::set_random_seed(const unsigned int seed) { } // RNG seed Get().random_generator_.reset(new RNG(seed)); + Get().random_generator_seed_ = seed; } void Caffe::SetDevice(const int device_id) { diff --git a/src/caffe/internal_thread.cpp b/src/caffe/internal_thread.cpp index c2d19d433b4..f9d6701a0a1 100644 --- a/src/caffe/internal_thread.cpp +++ b/src/caffe/internal_thread.cpp @@ -4,7 +4,7 @@ namespace caffe { InternalThread::~InternalThread() { - WaitForInternalThreadToExit(); + StopInternalThread(); } bool InternalThread::is_started() const { @@ -13,9 +13,10 @@ bool InternalThread::is_started() const { bool InternalThread::StartInternalThread() { - if (!WaitForInternalThreadToExit()) { + if (!StopInternalThread()) { return false; } + must_stop_ = false; try { thread_.reset( new boost::thread(&InternalThread::InternalThreadEntry, this)); @@ -26,8 +27,10 @@ bool InternalThread::StartInternalThread() { } /** Will not return until the internal thread has exited. */ -bool InternalThread::WaitForInternalThreadToExit() { +bool InternalThread::StopInternalThread() { + must_stop_ = true; if (is_started()) { + thread_->interrupt(); try { thread_->join(); } catch (...) { diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp index 352200915d7..c2bab19377f 100644 --- a/src/caffe/layers/base_data_layer.cpp +++ b/src/caffe/layers/base_data_layer.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -28,54 +29,97 @@ void BaseDataLayer::LayerSetUp(const vector*>& bottom, data_transformer_->InitRand(); } +template +BasePrefetchingDataLayer::BasePrefetchingDataLayer( + const LayerParameter& param) + : BaseDataLayer(param), + prefetch_free_(), prefetch_full_(), device_() { + for (int i = 0; i < PREFETCH_COUNT; ++i) + prefetch_free_.push(&prefetch_[i]); +} + template void BasePrefetchingDataLayer::LayerSetUp( const vector*>& bottom, const vector*>& top) { BaseDataLayer::LayerSetUp(bottom, top); - // Now, start the prefetch thread. Before calling prefetch, we make two - // cpu_data calls so that the prefetch thread does not accidentally make - // simultaneous cudaMalloc calls when the main thread is running. In some - // GPUs this seems to cause failures if we do not so. - this->prefetch_data_.mutable_cpu_data(); - if (this->output_labels_) { - this->prefetch_label_.mutable_cpu_data(); + + // Before starting the prefetch thread, we make cpu_data and gpu_data + // calls so that the prefetch thread does not accidentally make simultaneous + // cudaMalloc calls when the main thread is running. 
+  // In some GPUs this seems to cause failures if we do not do so.
+  for (int i = 0; i < PREFETCH_COUNT; ++i) {
+    prefetch_[i].data_.mutable_cpu_data();
+    if (this->output_labels_) {
+      prefetch_[i].label_.mutable_cpu_data();
+    }
+  }
+  switch (Caffe::mode()) {
+    case Caffe::CPU:
+      device_ = -1;
+      break;
+    case Caffe::GPU:
+#ifndef CPU_ONLY
+      for (int i = 0; i < PREFETCH_COUNT; ++i) {
+        prefetch_[i].data_.mutable_gpu_data();
+        if (this->output_labels_) {
+          prefetch_[i].label_.mutable_gpu_data();
+        }
+      }
+      CUDA_CHECK(cudaGetDevice(&device_));
+#endif
+      break;
   }
-  DLOG(INFO) << "Initializing prefetch";
-  this->CreatePrefetchThread();
-  DLOG(INFO) << "Prefetch initialized.";
-}
-
-template <typename Dtype>
-void BasePrefetchingDataLayer<Dtype>::CreatePrefetchThread() {
+
+  DLOG(INFO) << "Initializing prefetch";
   this->data_transformer_->InitRand();
   CHECK(StartInternalThread()) << "Thread execution failed";
+  DLOG(INFO) << "Prefetch initialized.";
 }

 template <typename Dtype>
-void BasePrefetchingDataLayer<Dtype>::JoinPrefetchThread() {
-  CHECK(WaitForInternalThreadToExit()) << "Thread joining failed";
+void BasePrefetchingDataLayer<Dtype>::InternalThreadEntry() {
+#ifndef CPU_ONLY
+  cudaStream_t stream;
+  if (device_ >= 0) {
+    CUDA_CHECK(cudaSetDevice(device_));
+    CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
+  }
+#endif
+
+  try {
+    while (!must_stop()) {
+      Batch<Dtype>* batch = prefetch_free_.pop();
+      load_batch(batch);
+#ifndef CPU_ONLY
+      if (device_ >= 0) {
+        batch->data_.data().get()->async_gpu_push(stream);
+        CUDA_CHECK(cudaStreamSynchronize(stream));
+      }
+#endif
+      prefetch_full_.push(batch);
+    }
+  } catch (boost::thread_interrupted&) {
+    // Interrupted exception is expected on shutdown
+  }
+#ifndef CPU_ONLY
+  // Release the stream created above, it would otherwise leak on restart
+  if (device_ >= 0) {
+    CUDA_CHECK(cudaStreamDestroy(stream));
+  }
+#endif
 }

 template <typename Dtype>
 void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
-  // First, join the thread
-  JoinPrefetchThread();
-  DLOG(INFO) << "Thread joined";
+  Batch<Dtype>* batch = prefetch_full_.pop("Data layer prefetch queue empty");
   // Reshape to loaded data.
-  top[0]->Reshape(this->prefetch_data_.num(), this->prefetch_data_.channels(),
-      this->prefetch_data_.height(), this->prefetch_data_.width());
+  top[0]->Reshape(batch->data_.num(), batch->data_.channels(),
+      batch->data_.height(), batch->data_.width());
   // Copy the data
-  caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
+  caffe_copy(batch->data_.count(), batch->data_.cpu_data(),
       top[0]->mutable_cpu_data());
   DLOG(INFO) << "Prefetch copied";
   if (this->output_labels_) {
-    caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
-        top[1]->mutable_cpu_data());
+    caffe_copy(batch->label_.count(), batch->label_.cpu_data(),
+        top[1]->mutable_cpu_data());
   }
-  // Start a new prefetch thread
-  DLOG(INFO) << "CreatePrefetchThread";
-  CreatePrefetchThread();
+
+  prefetch_free_.push(batch);
 }

 #ifdef CPU_ONLY
 STUB_GPU_FORWARD(BasePrefetchingDataLayer, Forward);
 #endif

 INSTANTIATE_CLASS(BaseDataLayer);
+INSTANTIATE_CLASS(Batch);
 INSTANTIATE_CLASS(BasePrefetchingDataLayer);

 }  // namespace caffe
diff --git a/src/caffe/layers/base_data_layer.cu b/src/caffe/layers/base_data_layer.cu
index 775f6c47f7e..52085d007a7 100644
--- a/src/caffe/layers/base_data_layer.cu
+++ b/src/caffe/layers/base_data_layer.cu
@@ -7,20 +7,19 @@ namespace caffe {
 template <typename Dtype>
 void BasePrefetchingDataLayer<Dtype>::Forward_gpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
-  // First, join the thread
-  JoinPrefetchThread();
+  Batch<Dtype>* batch = prefetch_full_.pop("Data layer prefetch queue empty");
   // Reshape to loaded data.
-  top[0]->Reshape(this->prefetch_data_.num(), this->prefetch_data_.channels(),
-      this->prefetch_data_.height(), this->prefetch_data_.width());
+  top[0]->Reshape(batch->data_.num(), batch->data_.channels(),
+      batch->data_.height(), batch->data_.width());
   // Copy the data
-  caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
+  caffe_copy(batch->data_.count(), batch->data_.gpu_data(),
       top[0]->mutable_gpu_data());
   if (this->output_labels_) {
-    caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
+    caffe_copy(batch->label_.count(), batch->label_.gpu_data(),
         top[1]->mutable_gpu_data());
   }
-  // Start a new prefetch thread
-  CreatePrefetchThread();
+
+  prefetch_free_.push(batch);
 }

 INSTANTIATE_LAYER_GPU_FORWARD(BasePrefetchingDataLayer);
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index 8877caf89c8..aee892e17b7 100644
--- a/src/caffe/layers/data_layer.cpp
+++ b/src/caffe/layers/data_layer.cpp
@@ -1,7 +1,10 @@
+#include <boost/thread.hpp>
 #include <opencv2/core/core.hpp>

 #include <stdint.h>

+#include <map>
+#include <string>
 #include <utility>
 #include <vector>
@@ -16,100 +19,193 @@

 namespace caffe {

-template <typename Dtype>
-DataLayer<Dtype>::~DataLayer() {
-  this->JoinPrefetchThread();
+map<const string, weak_ptr<DataLoader::Body> > DataLoader::instances_;
+static boost::mutex data_loader_instances_mutex_;
+
+DataLoader::DataLoader(const DataParameter& param, int index):
+    source_(param.source(index)) {
+  // Make sure only one Body is created per source
+  boost::mutex::scoped_lock lock(data_loader_instances_mutex_);
+  weak_ptr<Body> body = instances_[source_];
+  body_ = body.lock();
+  if (!body_) {
+    body_.reset(new Body(param, index));
+    instances_[source_] = weak_ptr<Body>(body_);
+  }
 }

-template <typename Dtype>
-void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
-    const vector<Blob<Dtype>*>& top) {
+DataLoader::~DataLoader() {
+  boost::mutex::scoped_lock lock(data_loader_instances_mutex_);
+  body_.reset();
+  if (instances_[source_].expired())
+    instances_.erase(source_);
+}
+
+DataLoader::Body::Body(const DataParameter& param, int index) {
   // Initialize DB
-  db_.reset(db::GetDB(this->layer_param_.data_param().backend()));
-  db_->Open(this->layer_param_.data_param().source(), db::READ);
+  DataParameter_DB backend = param.backend_size() ?
+      param.backend(index) : DataParameter::LEVELDB;
+  db_.reset(db::GetDB(backend));
+  db_->Open(param.source(index), db::READ);
   cursor_.reset(db_->NewCursor());

   // Check if we should randomly skip a few data points
-  if (this->layer_param_.data_param().rand_skip()) {
-    unsigned int skip = caffe_rng_rand() %
-        this->layer_param_.data_param().rand_skip();
+  if (param.rand_skip()) {
+    unsigned int skip = caffe_rng_rand() % param.rand_skip();
     LOG(INFO) << "Skipping first " << skip << " data points.";
     while (skip-- > 0) {
       cursor_->Next();
     }
   }
-  // Read a data point, and use it to initialize the top blob.
-  Datum datum;
-  datum.ParseFromString(cursor_->value());
+
+  // Add prefetch datums to the free queue
+  int prefetch = param.prefetch() * param.batch_size();
+  for (int i = 0; i < prefetch; ++i) {
+    free_.push(new Datum());
+  }
+
+  CHECK(StartInternalThread()) << "DataLoader thread start failed";
+}
+
+DataLoader::Body::~Body() {
+  CHECK(StopInternalThread()) << "DataLoader thread stop failed";
+  Datum* datum;
+  while (free_.try_pop(&datum)) {
+    delete datum;
+  }
+  while (full_.try_pop(&datum)) {
+    delete datum;
+  }
+}
+
+void DataLoader::Body::InternalThreadEntry() {
+  try {
+    while (!must_stop()) {
+      Datum* datum = free_.pop();
+      // TODO deserialize in-place instead of copy?
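+      // Datums are recycled rather than reallocated: consumers hand them
+      // back through free_, and this loop refills each one from the DB
+      // cursor before queuing it on full_.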
+      datum->ParseFromString(cursor_->value());
+      full_.push(datum);
+
+      // go to the next iter
+      cursor_->Next();
+      if (!cursor_->valid()) {
+        DLOG(INFO) << "Restarting data prefetching from start.";
+        cursor_->SeekToFirst();
+      }
+    }
+  } catch (boost::thread_interrupted&) {
+    // Interrupted exception is expected on shutdown
+  }
+}
+
+static unsigned int get_datalayer_specific_random_seed() {
+  unsigned int seed = Caffe::get_random_seed();
+  if (!seed) {
+    seed = caffe_rng_rand();
+  }
+  return seed + 87267527;
+}
+
+template <typename Dtype>
+DataLayer<Dtype>::DataLayer(const LayerParameter& param)
+    : BasePrefetchingDataLayer<Dtype>(param),
+      rand_engine_(get_datalayer_specific_random_seed()) {
+  const DataParameter& data = param.data_param();
+  if (data.backend_size()) {
+    CHECK(data.source().size() == data.backend().size())
+        << "Invalid DataParameter, there should be one backend per source";
+  }
+  if (data.probability_size()) {
+    CHECK(data.source().size() == data.probability().size())
+        << "Invalid DataParameter, there should be one probability per source";
+    float sum = 0;
+    for (int i = 0; i < data.probability().size(); ++i) {
+      sum += data.probability(i);
+    }
+    CHECK(fabsf(sum - 1.0f) < 1e-6f)
+        << "Invalid DataParameter, probabilities do not sum to 1";
+  }
+  for (int i = 0; i < data.source().size(); ++i) {
+    DataLoader* ld = new DataLoader(data, i);
+    loaders_.push_back(shared_ptr<DataLoader>(ld));
+  }
+}
+
+template <typename Dtype>
+DataLayer<Dtype>::~DataLayer() {
+  CHECK(this->StopInternalThread()) << "Stop thread failed";
+}
+
+template <typename Dtype>
+void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  // Look at first data point to initialize the top blob.
+  Datum* datum = loaders_[0].get()->full().peek();
   bool force_color = this->layer_param_.data_param().force_encoded_color();
-  if ((force_color && DecodeDatum(&datum, true)) ||
-      DecodeDatumNative(&datum)) {
+  if ((force_color && DecodeDatum(datum, true)) ||
+      DecodeDatumNative(datum)) {
     LOG(INFO) << "Decoding Datum";
   }
   // image
-  int crop_size = this->layer_param_.transform_param().crop_size();
+  const int crop_size = this->layer_param_.transform_param().crop_size();
+  const int batch_size = this->layer_param_.data_param().batch_size();
   if (crop_size > 0) {
-    top[0]->Reshape(this->layer_param_.data_param().batch_size(),
-        datum.channels(), crop_size, crop_size);
-    this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(),
-        datum.channels(), crop_size, crop_size);
-    this->transformed_data_.Reshape(1, datum.channels(), crop_size, crop_size);
+    top[0]->Reshape(batch_size, datum->channels(), crop_size, crop_size);
+    for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
+      this->prefetch_[i].data_.Reshape(batch_size, datum->channels(),
+          crop_size, crop_size);
+    }
+    this->transformed_data_.Reshape(1, datum->channels(),
+        crop_size, crop_size);
   } else {
-    top[0]->Reshape(
-        this->layer_param_.data_param().batch_size(), datum.channels(),
-        datum.height(), datum.width());
-    this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(),
-        datum.channels(), datum.height(), datum.width());
-    this->transformed_data_.Reshape(1, datum.channels(),
-        datum.height(), datum.width());
+    top[0]->Reshape(batch_size, datum->channels(),
+        datum->height(), datum->width());
+    for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
+      this->prefetch_[i].data_.Reshape(batch_size, datum->channels(),
+          datum->height(), datum->width());
+    }
+    this->transformed_data_.Reshape(1, datum->channels(),
+        datum->height(), datum->width());
   }
   LOG(INFO) << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();
   // label
   if (this->output_labels_) {
-    top[1]->Reshape(this->layer_param_.data_param().batch_size(), 1, 1, 1);
-    this->prefetch_label_.Reshape(this->layer_param_.data_param().batch_size(),
-        1, 1, 1);
+    top[1]->Reshape(batch_size, 1, 1, 1);
+    for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
+      this->prefetch_[i].label_.Reshape(batch_size, 1, 1, 1);
+    }
   }
 }

-// This function is used to create a thread that prefetches the data.
+// This function is called on prefetch thread
 template <typename Dtype>
-void DataLayer<Dtype>::InternalThreadEntry() {
+void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
   CPUTimer batch_timer;
   batch_timer.Start();
   double read_time = 0;
   double trans_time = 0;
   CPUTimer timer;
-  CHECK(this->prefetch_data_.count());
+  CHECK(batch->data_.count());
   CHECK(this->transformed_data_.count());

-  // Reshape on single input batches for inputs of varying dimension.
   const int batch_size = this->layer_param_.data_param().batch_size();
   const int crop_size = this->layer_param_.transform_param().crop_size();
-  if (batch_size == 1 && crop_size == 0) {
-    Datum datum;
-    datum.ParseFromString(cursor_->value());
-    this->prefetch_data_.Reshape(1, datum.channels(),
-        datum.height(), datum.width());
-    this->transformed_data_.Reshape(1, datum.channels(),
-        datum.height(), datum.width());
-  }
-
-  Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
-  Dtype* top_label = NULL;  // suppress warnings about uninitialized variables
-
-  if (this->output_labels_) {
-    top_label = this->prefetch_label_.mutable_cpu_data();
-  }
   bool force_color = this->layer_param_.data_param().force_encoded_color();
   for (int item_id = 0; item_id < batch_size; ++item_id) {
     timer.Start();
-    // get a blob
-    Datum datum;
-    datum.ParseFromString(cursor_->value());
+    DataLoader* loader = next_loader();
+    const Datum& datum = *(loader->full().pop("Waiting on data loader"));
+
+    // Reshape on single input batches for inputs of varying dimension.
+    if (batch_size == 1 && crop_size == 0) {
+      batch->data_.Reshape(1, datum.channels(),
+          datum.height(), datum.width());
+      this->transformed_data_.Reshape(1, datum.channels(),
+          datum.height(), datum.width());
+    }

     cv::Mat cv_img;
     if (datum.encoded()) {
@@ -129,7 +225,8 @@ void DataLayer<Dtype>::InternalThreadEntry() {
     timer.Start();
     // Apply data transformations (mirror, scale, crop...)
-    int offset = this->prefetch_data_.offset(item_id);
+    Dtype* top_data = batch->data_.mutable_cpu_data();
+    int offset = batch->data_.offset(item_id);
     this->transformed_data_.set_cpu_data(top_data + offset);
     if (datum.encoded()) {
       this->data_transformer_->Transform(cv_img, &(this->transformed_data_));
@@ -137,15 +234,11 @@ void DataLayer<Dtype>::InternalThreadEntry() {
       this->data_transformer_->Transform(datum, &(this->transformed_data_));
     }
     if (this->output_labels_) {
-      top_label[item_id] = datum.label();
+      batch->label_.mutable_cpu_data()[item_id] = datum.label();
     }
     trans_time += timer.MicroSeconds();
-    // go to the next iter
-    cursor_->Next();
-    if (!cursor_->valid()) {
-      DLOG(INFO) << "Restarting data prefetching from start.";
-      cursor_->SeekToFirst();
-    }
+
+    loader->free().push(const_cast<Datum*>(&datum));
   }
   batch_timer.Stop();
   DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
@@ -153,6 +246,33 @@ void DataLayer<Dtype>::InternalThreadEntry() {
   DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
 }

+// This function is called on prefetch thread
+template <typename Dtype>
+DataLoader* DataLayer<Dtype>::next_loader() {
+  const DataParameter& data = this->layer_param().data_param();
+  // Default case without probabilities, try to find a loader with
+  // data ready, or return the first one
+  if (data.probability_size() == 0) {
+    for (int i = 0; i < loaders_.size(); ++i) {
+      DataLoader* loader = loaders_[i].get();
+      if (!loader->full().empty()) {
+        return loader;
+      }
+    }
+  } else {
+    // Pick a loader randomly with the given probability
+    float rand = rand_(rand_engine_);
+    for (int i = 0; i < data.probability().size(); ++i) {
+      rand -= data.probability(i);
+      if (rand < 0) {
+        return loaders_[i].get();
+      }
+    }
+  }
+  // If no data ready, or rounding error on probabilities
+  return loaders_[0].get();
+}
+
 INSTANTIATE_CLASS(DataLayer);
 REGISTER_LAYER_CLASS(Data);
diff --git a/src/caffe/layers/image_data_layer.cpp b/src/caffe/layers/image_data_layer.cpp
index f9046e1b3a1..164e40f6f03 100644
--- a/src/caffe/layers/image_data_layer.cpp
+++ b/src/caffe/layers/image_data_layer.cpp
@@ -17,7 +17,7 @@ namespace caffe {

 template <typename Dtype>
 ImageDataLayer<Dtype>::~ImageDataLayer() {
-  this->JoinPrefetchThread();
+  CHECK(this->StopInternalThread()) << "Stop thread failed";
 }

 template <typename Dtype>
@@ -70,11 +70,14 @@ void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   const int batch_size = this->layer_param_.image_data_param().batch_size();
   if (crop_size > 0) {
     top[0]->Reshape(batch_size, channels, crop_size, crop_size);
-    this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size);
+    for (int i = 0; i < this->PREFETCH_COUNT; ++i)
+      this->prefetch_[i].data_.Reshape(batch_size, channels,
+          crop_size, crop_size);
     this->transformed_data_.Reshape(1, channels, crop_size, crop_size);
   } else {
     top[0]->Reshape(batch_size, channels, height, width);
-    this->prefetch_data_.Reshape(batch_size, channels, height, width);
+    for (int i = 0; i < this->PREFETCH_COUNT; ++i)
+      this->prefetch_[i].data_.Reshape(batch_size, channels, height, width);
     this->transformed_data_.Reshape(1, channels, height, width);
   }
   LOG(INFO) << "output data size: " << top[0]->num() << ","
@@ -82,7 +85,8 @@ void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
       << top[0]->width();
   // label
   top[1]->Reshape(batch_size, 1, 1, 1);
-  this->prefetch_label_.Reshape(batch_size, 1, 1, 1);
+  for (int i = 0; i < this->PREFETCH_COUNT; ++i)
+    this->prefetch_[i].label_.Reshape(batch_size, 1, 1, 1);
 }

 template <typename Dtype>
@@ -92,15 +96,15 @@ void ImageDataLayer<Dtype>::ShuffleImages() {
   shuffle(lines_.begin(), lines_.end(), prefetch_rng);
 }

-// This function is used to create a thread that prefetches the data.
+// This function is called on prefetch thread
 template <typename Dtype>
-void ImageDataLayer<Dtype>::InternalThreadEntry() {
+void ImageDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
   CPUTimer batch_timer;
   batch_timer.Start();
   double read_time = 0;
   double trans_time = 0;
   CPUTimer timer;
-  CHECK(this->prefetch_data_.count());
+  CHECK(batch->data_.count());
   CHECK(this->transformed_data_.count());
   ImageDataParameter image_data_param = this->layer_param_.image_data_param();
   const int batch_size = image_data_param.batch_size();
@@ -114,14 +118,14 @@ void ImageDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
   if (batch_size == 1 && crop_size == 0 && new_height == 0 && new_width == 0) {
     cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
         0, 0, is_color);
-    this->prefetch_data_.Reshape(1, cv_img.channels(),
+    batch->data_.Reshape(1, cv_img.channels(),
         cv_img.rows, cv_img.cols);
     this->transformed_data_.Reshape(1, cv_img.channels(),
         cv_img.rows, cv_img.cols);
   }

-  Dtype* prefetch_data = this->prefetch_data_.mutable_cpu_data();
-  Dtype* prefetch_label = this->prefetch_label_.mutable_cpu_data();
+  Dtype* prefetch_data = batch->data_.mutable_cpu_data();
+  Dtype* prefetch_label = batch->label_.mutable_cpu_data();

   // datum scales
   const int lines_size = lines_.size();
@@ -135,7 +139,7 @@ void ImageDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
     read_time += timer.MicroSeconds();
     timer.Start();
     // Apply transformations (mirror, crop...) to the image
-    int offset = this->prefetch_data_.offset(item_id);
+    int offset = batch->data_.offset(item_id);
     this->transformed_data_.set_cpu_data(prefetch_data + offset);
     this->data_transformer_->Transform(cv_img, &(this->transformed_data_));
     trans_time += timer.MicroSeconds();
diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp
index 36e41560327..c8ce8750e14 100644
--- a/src/caffe/layers/window_data_layer.cpp
+++ b/src/caffe/layers/window_data_layer.cpp
@@ -27,7 +27,7 @@ namespace caffe {

 template <typename Dtype>
 WindowDataLayer<Dtype>::~WindowDataLayer() {
-  this->JoinPrefetchThread();
+  CHECK(this->StopInternalThread()) << "Stop thread failed";
 }

 template <typename Dtype>
@@ -171,14 +171,17 @@ void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK_GT(crop_size, 0);
   const int batch_size = this->layer_param_.window_data_param().batch_size();
   top[0]->Reshape(batch_size, channels, crop_size, crop_size);
-  this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size);
+  for (int i = 0; i < this->PREFETCH_COUNT; ++i)
+    this->prefetch_[i].data_.Reshape(
+        batch_size, channels, crop_size, crop_size);

   LOG(INFO) << "output data size: " << top[0]->num() << ","
       << top[0]->channels() << "," << top[0]->height() << ","
       << top[0]->width();
   // label
   top[1]->Reshape(batch_size, 1, 1, 1);
-  this->prefetch_label_.Reshape(batch_size, 1, 1, 1);
+  for (int i = 0; i < this->PREFETCH_COUNT; ++i)
+    this->prefetch_[i].label_.Reshape(batch_size, 1, 1, 1);

   // data mean
   has_mean_file_ = this->transform_param_.has_mean_file();
@@ -216,9 +219,9 @@ unsigned int WindowDataLayer<Dtype>::PrefetchRand() {
   return (*prefetch_rng)();
 }

-// Thread fetching the data
+// This function is called on prefetch thread
 template <typename Dtype>
-void WindowDataLayer<Dtype>::InternalThreadEntry() {
+void WindowDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
   // At each iteration, sample N windows where N*p are foreground (object)
   // windows and N*(1-p) are background (non-object) windows
   CPUTimer batch_timer;
   batch_timer.Start();
   double read_time = 0;
   double trans_time = 0;
   CPUTimer timer;
-  Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
-  Dtype* top_label = this->prefetch_label_.mutable_cpu_data();
+  Dtype* top_data = batch->data_.mutable_cpu_data();
+  Dtype* top_label = batch->label_.mutable_cpu_data();
   const Dtype scale = this->layer_param_.window_data_param().scale();
   const int batch_size = this->layer_param_.window_data_param().batch_size();
   const int context_pad = this->layer_param_.window_data_param().context_pad();
@@ -251,7 +254,7 @@ void WindowDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
   bool use_square = (crop_mode == "square") ? true : false;

   // zero out batch
-  caffe_set(this->prefetch_data_.count(), Dtype(0), top_data);
+  caffe_set(batch->data_.count(), Dtype(0), top_data);

   const int num_fg = static_cast<int>(static_cast<float>(batch_size)
       * fg_fraction);
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 84b475ce3cd..9c47fc4761f 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -404,13 +404,14 @@ message ConvolutionParameter {
 }

 // Message that stores parameters used by DataLayer
+// next available ID: 12 (last added: probability)
 message DataParameter {
   enum DB {
     LEVELDB = 0;
     LMDB = 1;
   }
   // Specify the data source.
-  optional string source = 1;
+  repeated string source = 1;
   // Specify the batch size.
   optional uint32 batch_size = 4;
   // The rand_skip variable is for the data layer to skip a few data points
@@ -418,7 +419,7 @@ message DataParameter {
   // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
   // be larger than the number of keys in the database.
   optional uint32 rand_skip = 7 [default = 0];
-  optional DB backend = 8 [default = LEVELDB];
+  repeated DB backend = 8;
   // DEPRECATED. See TransformationParameter. For data pre-processing, we can do
   // simple scaling and subtracting the data mean, if provided. Note that the
   // mean subtraction is always carried out before scaling.
@@ -432,6 +433,14 @@ message DataParameter {
   optional bool mirror = 6 [default = false];
   // Force the encoded image to have 3 color channels
   optional bool force_encoded_color = 9 [default = false];
+  // Prefetch queue size (number of batches to prefetch to host memory from
+  // each source; increase it if read bandwidth is uneven).
+  optional uint32 prefetch = 10 [default = 4];
+  // If multiple sources are given, a probability can be set on each source.
+  // Samples will be picked from each source with the given probability, and
+  // the label will be set to the source index. This allows experimenting with
+  // different class ratios at runtime without rebuilding datasets.
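+  // For example, to draw roughly 90% of samples from one source and 10%
+  // from another (illustrative snippet, source names are arbitrary):
+  //   data_param {
+  //     batch_size: 64
+  //     source: "train_a_lmdb"
+  //     source: "train_b_lmdb"
+  //     backend: LMDB
+  //     backend: LMDB
+  //     probability: 0.9
+  //     probability: 0.1
+  //   }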
+  repeated float probability = 11;
 }

 // Message that stores parameters used by DropoutLayer
diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp
index 7617ccfb27f..0da7a3bac79 100644
--- a/src/caffe/syncedmem.cpp
+++ b/src/caffe/syncedmem.cpp
@@ -108,6 +108,18 @@ void* SyncedMemory::mutable_gpu_data() {
 #endif
 }

+#ifndef CPU_ONLY
+void SyncedMemory::async_gpu_push(const cudaStream_t& stream) {
+  CHECK(head_ == HEAD_AT_CPU);
+  if (gpu_ptr_ == NULL) {
+    CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
+  }
+  const cudaMemcpyKind put = cudaMemcpyHostToDevice;
+  CUDA_CHECK(cudaMemcpyAsync(gpu_ptr_, cpu_ptr_, size_, put, stream));
+  // Assume caller will synchronize on the stream before use
+  head_ = SYNCED;
+}
+#endif

 }  // namespace caffe
diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp
index afe2a40d227..3323cc6e302 100644
--- a/src/caffe/test/test_data_layer.cpp
+++ b/src/caffe/test/test_data_layer.cpp
@@ -24,14 +24,16 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
  protected:
   DataLayerTest()
-      : backend_(DataParameter_DB_LEVELDB),
-        blob_top_data_(new Blob<Dtype>()),
+      : blob_top_data_(new Blob<Dtype>()),
         blob_top_label_(new Blob<Dtype>()),
         seed_(1701) {}
   virtual void SetUp() {
-    filename_.reset(new string());
-    MakeTempDir(filename_.get());
-    *filename_ += "/db";
+    for (int i = 0; i < BACKENDS_COUNT; ++i) {
+      backends_[i] = DataParameter_DB_LEVELDB;
+      filenames_[i].reset(new string());
+      MakeTempDir(filenames_[i].get());
+      *filenames_[i] += "/db";
+    }
     blob_top_vec_.push_back(blob_top_data_);
     blob_top_vec_.push_back(blob_top_label_);
   }

   // Fill the DB with data: if unique_pixels, each pixel is unique but
   // all images are the same; else each image is unique but all pixels within
   // an image are the same.
-  void Fill(const bool unique_pixels, DataParameter_DB backend) {
-    backend_ = backend;
-    LOG(INFO) << "Using temporary dataset " << *filename_;
+  void Fill(const bool unique_pixels, DataParameter_DB backend,
+      int index = 0) {
+    backends_[index] = backend;
+    LOG(INFO) << "Using temporary dataset " << *(filenames_[index]);
     scoped_ptr<db::DB> db(db::GetDB(backend));
-    db->Open(*filename_, db::NEW);
+    db->Open(*(filenames_[index]), db::NEW);
     scoped_ptr<db::Transaction> txn(db->NewTransaction());
-    for (int i = 0; i < 5; ++i) {
+    for (int i = 0; i < BATCH_SIZE; ++i) {
       Datum datum;
       datum.set_label(i);
       datum.set_channels(2);
@@ -71,9 +73,9 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
     LayerParameter param;
     param.set_phase(TRAIN);
     DataParameter* data_param = param.mutable_data_param();
-    data_param->set_batch_size(5);
-    data_param->set_source(filename_->c_str());
-    data_param->set_backend(backend_);
+    data_param->set_batch_size(BATCH_SIZE);
+    data_param->add_source(filenames_[0]->c_str());
+    data_param->add_backend(backends_[0]);

     TransformationParameter* transform_param =
         param.mutable_transform_param();
@@ -92,10 +94,10 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {

     for (int iter = 0; iter < 100; ++iter) {
       layer.Forward(blob_bottom_vec_, blob_top_vec_);
-      for (int i = 0; i < 5; ++i) {
+      for (int i = 0; i < BATCH_SIZE; ++i) {
         EXPECT_EQ(i, blob_top_label_->cpu_data()[i]);
       }
-      for (int i = 0; i < 5; ++i) {
+      for (int i = 0; i < BATCH_SIZE; ++i) {
         for (int j = 0; j < 24; ++j) {
           EXPECT_EQ(scale * i, blob_top_data_->cpu_data()[i * 24 + j])
               << "debug: iter " << iter << " i " << i << " j " << j;
@@ -107,9 +109,9 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
   void TestReshape(DataParameter_DB backend) {
     const int num_inputs = 5;
     // Save data of varying shapes.
-    LOG(INFO) << "Using temporary dataset " << *filename_;
+    LOG(INFO) << "Using temporary dataset " << *filenames_[0];
     scoped_ptr<db::DB> db(db::GetDB(backend));
-    db->Open(*filename_, db::NEW);
+    db->Open(*filenames_[0], db::NEW);
     scoped_ptr<db::Transaction> txn(db->NewTransaction());
     for (int i = 0; i < num_inputs; ++i) {
       Datum datum;
@@ -136,8 +138,8 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
     param.set_phase(TEST);
     DataParameter* data_param = param.mutable_data_param();
     data_param->set_batch_size(1);
-    data_param->set_source(filename_->c_str());
-    data_param->set_backend(backend);
+    data_param->add_source(filenames_[0]->c_str());
+    data_param->add_backend(backend);

     DataLayer<Dtype> layer(param);
     layer.SetUp(blob_bottom_vec_, blob_top_vec_);
@@ -171,14 +173,15 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {

   void TestReadCrop(Phase phase) {
     const Dtype scale = 3;
+    const int batch = BATCH_SIZE;
     LayerParameter param;
     param.set_phase(phase);
     Caffe::set_random_seed(1701);

     DataParameter* data_param = param.mutable_data_param();
-    data_param->set_batch_size(5);
-    data_param->set_source(filename_->c_str());
-    data_param->set_backend(backend_);
+    data_param->set_batch_size(batch);
+    data_param->add_source(filenames_[0]->c_str());
+    data_param->add_backend(backends_[0]);

     TransformationParameter* transform_param =
         param.mutable_transform_param();
@@ -187,22 +190,22 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
     DataLayer<Dtype> layer(param);
     layer.SetUp(blob_bottom_vec_, blob_top_vec_);
-    EXPECT_EQ(blob_top_data_->num(), 5);
+    EXPECT_EQ(blob_top_data_->num(), batch);
     EXPECT_EQ(blob_top_data_->channels(), 2);
     EXPECT_EQ(blob_top_data_->height(), 1);
     EXPECT_EQ(blob_top_data_->width(), 1);
-    EXPECT_EQ(blob_top_label_->num(), 5);
+    EXPECT_EQ(blob_top_label_->num(), batch);
     EXPECT_EQ(blob_top_label_->channels(), 1);
     EXPECT_EQ(blob_top_label_->height(), 1);
     EXPECT_EQ(blob_top_label_->width(), 1);

     for (int iter = 0; iter < 2; ++iter) {
       layer.Forward(blob_bottom_vec_, blob_top_vec_);
-      for (int i = 0; i < 5; ++i) {
+      for (int i = 0; i < batch; ++i) {
         EXPECT_EQ(i, blob_top_label_->cpu_data()[i]);
       }
       int num_with_center_value = 0;
-      for (int i = 0; i < 5; ++i) {
+      for (int i = 0; i < batch; ++i) {
         for (int j = 0; j < 2; ++j) {
           const Dtype center_value = scale * (j ? 17 : 5);
           num_with_center_value +=
               (center_value == blob_top_data_->cpu_data()[i * 2 + j]);
@@ -227,9 +230,9 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
     LayerParameter param;
     param.set_phase(TRAIN);
     DataParameter* data_param = param.mutable_data_param();
-    data_param->set_batch_size(5);
-    data_param->set_source(filename_->c_str());
-    data_param->set_backend(backend_);
+    data_param->set_batch_size(BATCH_SIZE);
+    data_param->add_source(filenames_[0]->c_str());
+    data_param->add_backend(backends_[0]);

     TransformationParameter* transform_param =
         param.mutable_transform_param();
@@ -244,11 +247,11 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
     layer1.SetUp(blob_bottom_vec_, blob_top_vec_);
     for (int iter = 0; iter < 2; ++iter) {
       layer1.Forward(blob_bottom_vec_, blob_top_vec_);
-      for (int i = 0; i < 5; ++i) {
+      for (int i = 0; i < BATCH_SIZE; ++i) {
         EXPECT_EQ(i, blob_top_label_->cpu_data()[i]);
       }
       vector<Dtype> iter_crop_sequence;
-      for (int i = 0; i < 5; ++i) {
+      for (int i = 0; i < BATCH_SIZE; ++i) {
         for (int j = 0; j < 2; ++j) {
           iter_crop_sequence.push_back(
               blob_top_data_->cpu_data()[i * 2 + j]);
@@ -265,10 +268,10 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
     layer2.SetUp(blob_bottom_vec_, blob_top_vec_);
     for (int iter = 0; iter < 2; ++iter) {
       layer2.Forward(blob_bottom_vec_, blob_top_vec_);
-      for (int i = 0; i < 5; ++i) {
+      for (int i = 0; i < BATCH_SIZE; ++i) {
         EXPECT_EQ(i, blob_top_label_->cpu_data()[i]);
       }
-      for (int i = 0; i < 5; ++i) {
+      for (int i = 0; i < BATCH_SIZE; ++i) {
         for (int j = 0; j < 2; ++j) {
           EXPECT_EQ(crop_sequence[iter][i * 2 + j],
                     blob_top_data_->cpu_data()[i * 2 + j])
@@ -282,9 +285,9 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
     LayerParameter param;
     param.set_phase(TRAIN);
     DataParameter* data_param = param.mutable_data_param();
-    data_param->set_batch_size(5);
-    data_param->set_source(filename_->c_str());
-    data_param->set_backend(backend_);
+    data_param->set_batch_size(BATCH_SIZE);
+    data_param->add_source(filenames_[0]->c_str());
+    data_param->add_backend(backends_[0]);

     TransformationParameter* transform_param =
         param.mutable_transform_param();
@@ -300,11 +303,11 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
     layer1.SetUp(blob_bottom_vec_, blob_top_vec_);
     for (int iter = 0; iter < 2; ++iter) {
       layer1.Forward(blob_bottom_vec_, blob_top_vec_);
-      for (int i = 0; i < 5; ++i) {
+      for (int i = 0; i < BATCH_SIZE; ++i) {
         EXPECT_EQ(i, blob_top_label_->cpu_data()[i]);
       }
       vector<Dtype> iter_crop_sequence;
-      for (int i = 0; i < 5; ++i) {
+      for (int i = 0; i < BATCH_SIZE; ++i) {
         for (int j = 0; j < 2; ++j) {
           iter_crop_sequence.push_back(
               blob_top_data_->cpu_data()[i * 2 + j]);
@@ -321,11 +324,11 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
     layer2.SetUp(blob_bottom_vec_, blob_top_vec_);
     for (int iter = 0; iter < 2; ++iter) {
       layer2.Forward(blob_bottom_vec_, blob_top_vec_);
-      for (int i = 0; i < 5; ++i) {
+      for (int i = 0; i < BATCH_SIZE; ++i) {
         EXPECT_EQ(i, blob_top_label_->cpu_data()[i]);
       }
       int num_sequence_matches = 0;
-      for (int i = 0; i < 5; ++i) {
+      for (int i = 0; i < BATCH_SIZE; ++i) {
         for (int j = 0; j < 2; ++j) {
          num_sequence_matches += (crop_sequence[iter][i * 2 + j] ==
                                   blob_top_data_->cpu_data()[i * 2 + j]);
@@ -335,10 +338,90 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
     }
   }

+  void TestProbabilities() {
+    LayerParameter param;
+    DataParameter* data_param = param.mutable_data_param();
+    data_param->set_batch_size(BATCH_SIZE);
+    for (int i = 0; i < BACKENDS_COUNT; ++i) {
+      data_param->add_source(filenames_[i]->c_str());
+      data_param->add_backend(
+          i % 2 ? DataParameter_DB_LEVELDB : DataParameter_DB_LMDB);
+      data_param->add_probability(0);
+    }
+
+    Caffe::set_random_seed(544432);
+    int counts[BACKENDS_COUNT];
+
+    // Balanced two
+    data_param->set_probability(0, .5f);
+    data_param->set_probability(1, .5f);
+    caffe_memset(sizeof(counts), 0, counts);
+    probabilities_run(param, counts);
+    EXPECT_EQ(58, counts[0]);
+    EXPECT_EQ(57, counts[1]);
+    EXPECT_EQ(0, counts[2]);
+
+    // Balanced three
+    data_param->set_probability(0, .33333333f);
+    data_param->set_probability(1, .33333333f);
+    data_param->set_probability(2, .33333333f);
+    caffe_memset(sizeof(counts), 0, counts);
+    probabilities_run(param, counts);
+    EXPECT_EQ(43, counts[0]);
+    EXPECT_EQ(27, counts[1]);
+    EXPECT_EQ(45, counts[2]);
+
+    // Only one
+    data_param->set_probability(0, 0);
+    data_param->set_probability(1, 0);
+    data_param->set_probability(2, 1);
+    caffe_memset(sizeof(counts), 0, counts);
+    probabilities_run(param, counts);
+    EXPECT_EQ(0, counts[0]);
+    EXPECT_EQ(0, counts[1]);
+    EXPECT_EQ(115, counts[2]);
+  }
+
+  void probabilities_run(LayerParameter param, int* counts) {
+    const int batch = BATCH_SIZE;
+    DataLayer<Dtype> layer(param);
+    layer.SetUp(blob_bottom_vec_, blob_top_vec_);
+    EXPECT_EQ(blob_top_data_->num(), batch);
+    EXPECT_EQ(blob_top_data_->channels(), 2);
+    EXPECT_EQ(blob_top_data_->height(), 3);
+    EXPECT_EQ(blob_top_data_->width(), 4);
+    EXPECT_EQ(blob_top_label_->num(), batch);
+    EXPECT_EQ(blob_top_label_->channels(), 1);
+    EXPECT_EQ(blob_top_label_->height(), 1);
+    EXPECT_EQ(blob_top_label_->width(), 1);
+
+    const int examples = 100;
+    for (int iter = 0; iter < examples / batch; ++iter) {
+      layer.Forward(blob_bottom_vec_, blob_top_vec_);
+    }
+
+    for (;;) {
+      int total = 0;
+      for (int i = 0; i < BACKENDS_COUNT; ++i) {
+        DataLoader loader(param.data_param(), i);
+        counts[i] = loader.full().pops();
+        total += counts[i];
+      }
+      // Wait until the prefetch queue refills, for reproducibility
+      const int prefetch = BasePrefetchingDataLayer<Dtype>::PREFETCH_COUNT;
+      if (total == examples + prefetch * batch) {
+        break;
+      }
+      usleep(1000);
+    }
+  }
+
   virtual ~DataLayerTest() { delete blob_top_data_; delete blob_top_label_; }

-  DataParameter_DB backend_;
-  shared_ptr<string> filename_;
+  static const int BATCH_SIZE = 5;
+  static const int BACKENDS_COUNT = 3;
+  DataParameter_DB backends_[BACKENDS_COUNT];
+  shared_ptr<string> filenames_[BACKENDS_COUNT];
   Blob<Dtype>* const blob_top_data_;
   Blob<Dtype>* const blob_top_label_;
   vector<Blob<Dtype>*> blob_bottom_vec_;
@@ -424,4 +507,12 @@ TYPED_TEST(DataLayerTest, TestReadCropTestLMDB) {
   this->TestReadCrop(TEST);
 }

+TYPED_TEST(DataLayerTest, TestProbabilities) {
+  const bool unique_pixels = true;  // all images the same; pixels different
+  this->Fill(unique_pixels, DataParameter_DB_LMDB, 0);
+  this->Fill(unique_pixels, DataParameter_DB_LEVELDB, 1);
+  this->Fill(unique_pixels, DataParameter_DB_LMDB, 2);
+  this->TestProbabilities();
+}
+
 }  // namespace caffe
diff --git a/src/caffe/test/test_internal_thread.cpp b/src/caffe/test/test_internal_thread.cpp
index 31882b6db1d..b495768a530 100644
--- a/src/caffe/test/test_internal_thread.cpp
+++ b/src/caffe/test/test_internal_thread.cpp
@@ -15,7 +15,7 @@ TEST_F(InternalThreadTest, TestStartAndExit) {
   EXPECT_FALSE(thread.is_started());
   EXPECT_TRUE(thread.StartInternalThread());
   EXPECT_TRUE(thread.is_started());
-  EXPECT_TRUE(thread.WaitForInternalThreadToExit());
+  EXPECT_TRUE(thread.StopInternalThread());
   EXPECT_FALSE(thread.is_started());
 }

diff --git a/src/caffe/util/blocking_queue.cpp b/src/caffe/util/blocking_queue.cpp
new file mode 100644
index 00000000000..db4c983e2b0
--- /dev/null
+++ b/src/caffe/util/blocking_queue.cpp
@@ -0,0 +1,87 @@
+#include <boost/thread.hpp>
+#include <string>
+
+#include "caffe/data_layers.hpp"
+#include "caffe/util/blocking_queue.hpp"
+
+namespace caffe {
+
+template<typename T>
+class blocking_queue<T>::sync {
+ public:
+  mutable boost::mutex mutex_;
+  boost::condition_variable condition_;
+};
+
+template<typename T>
+blocking_queue<T>::blocking_queue()
+    : sync_(new sync()),
+      last_wait_log_(time(0)),
+      pops_() {
+}
+
+template<typename T>
+blocking_queue<T>::~blocking_queue() {
+}
+
+template<typename T>
+void blocking_queue<T>::push(const T& t) {
+  boost::mutex::scoped_lock lock(sync_.get()->mutex_);
+  queue_.push(t);
+  lock.unlock();
+  sync_.get()->condition_.notify_one();
+}
+
+template<typename T>
+bool blocking_queue<T>::empty() const {
+  boost::mutex::scoped_lock lock(sync_.get()->mutex_);
+  return queue_.empty();
+}
+
+template<typename T>
+bool blocking_queue<T>::try_pop(T* t) {
+  boost::mutex::scoped_lock lock(sync_.get()->mutex_);
+
+  if (queue_.empty())
+    return false;
+
+  *t = queue_.front();
+  queue_.pop();
+  return true;
+}
+
+template<typename T>
+T blocking_queue<T>::pop(const string& log_on_wait) {
+  boost::mutex::scoped_lock lock(sync_.get()->mutex_);
+
+  while (queue_.empty()) {
+    if (!log_on_wait.empty()) {
+      time_t now = time(0);
+      if (now - last_wait_log_ > 5) {
+        last_wait_log_ = now;
+        LOG(INFO) << log_on_wait;
+      }
+    }
+    sync_.get()->condition_.wait(lock);
+  }
+
+  T t = queue_.front();
+  queue_.pop();
+  pops_++;
+  return t;
+}
+
+template<typename T>
+T blocking_queue<T>::peek() {
+  boost::mutex::scoped_lock lock(sync_.get()->mutex_);
+
+  while (queue_.empty())
+    sync_.get()->condition_.wait(lock);
+
+  return queue_.front();
+}
+
+template class blocking_queue<Batch<float>*>;
+template class blocking_queue<Batch<double>*>;
+template class blocking_queue<Datum*>;
+
+}  // namespace caffe
diff --git a/src/caffe/util/db.cpp b/src/caffe/util/db.cpp
index 7f7018107ec..fa8efa4a377 100644
--- a/src/caffe/util/db.cpp
+++ b/src/caffe/util/db.cpp
@@ -28,7 +28,17 @@ void LMDB::Open(const string& source, Mode mode) {
   }
   int flags = 0;
   if (mode == READ) {
-    flags = MDB_RDONLY | MDB_NOTLS;
+    // No locking, assume the DB is not written to at the same time, otherwise
+    // LMDB tries to lock the file, which fails if the filesystem is read-only
+    flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;
+  }
+  // Allow the DB to be a stand-alone file
+  {
+    struct stat st_buf;
+    if (stat(source.c_str(), &st_buf) == 0 && S_ISREG(st_buf.st_mode)) {
+      flags |= MDB_NOSUBDIR;
+    }
   }
   MDB_CHECK(mdb_env_open(mdb_env_, source.c_str(), flags, 0664));
   LOG(INFO) << "Opened lmdb " << source;
diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp
index 38a06026adf..38f350e6b96 100644
--- a/src/caffe/util/upgrade_proto.cpp
+++ b/src/caffe/util/upgrade_proto.cpp
@@ -303,7 +303,7 @@ bool UpgradeV0LayerParameter(const V1LayerParameter& v0_layer_connection,
   }
   if (v0_layer_param.has_source()) {
     if (type == "data") {
-      layer_param->mutable_data_param()->set_source(v0_layer_param.source());
+      layer_param->mutable_data_param()->add_source(v0_layer_param.source());
     } else if (type == "hdf5_data") {
       layer_param->mutable_hdf5_data_param()->set_source(
          v0_layer_param.source());