Commit

- Makes sure each solver accesses a different subset of the data
- Sequential reading of DB for performance
- Prefetches a configurable amount of data to host memory
- Distributes data to solvers in a round-robin way for determinism (see the toy sketch after this list)
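The round-robin hand-off in the last point is what makes multi-solver runs reproducible: with k solvers, solver i always receives records i, i+k, i+2k, ... in database order, independent of how the solver threads happen to be scheduled. A toy standalone sketch of that dealing pattern (for illustration only, not part of the commit; the real loop lives in DataReader::Body::InternalThreadEntry below):

#include <cstdio>
#include <vector>

// Toy illustration: dealing sequentially-read records to per-solver queues
// in a fixed round-robin order gives each solver a fixed, disjoint subset.
int main() {
  const int kSolvers = 2;   // hypothetical solver count
  const int kRecords = 8;   // hypothetical number of database records
  std::vector<std::vector<int> > per_solver(kSolvers);
  for (int r = 0; r < kRecords; ++r) {
    per_solver[r % kSolvers].push_back(r);  // round-robin assignment
  }
  for (int s = 0; s < kSolvers; ++s) {
    std::printf("solver %d:", s);
    for (size_t j = 0; j < per_solver[s].size(); ++j) {
      std::printf(" %d", per_solver[s][j]);
    }
    std::printf("\n");  // prints: solver 0: 0 2 4 6 / solver 1: 1 3 5 7
  }
  return 0;
}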
New file: caffe/data_reader.hpp
@@ -0,0 +1,82 @@
#ifndef CAFFE_DATA_READER_HPP_
#define CAFFE_DATA_READER_HPP_

#include <map>
#include <string>
#include <vector>

#include "caffe/common.hpp"
#include "caffe/internal_thread.hpp"
#include "caffe/util/blocking_queue.hpp"
#include "caffe/util/db.hpp"

namespace caffe {

/**
 * @brief Reads data from a source to queues available to data layers.
 * A single reading thread is created per source, even if multiple solvers
 * are running in parallel, e.g. for multi-GPU training. This makes sure
 * databases are read sequentially, and that each solver accesses a different
 * subset of the database. Data is distributed to solvers in a round-robin
 * way to keep parallel training deterministic.
 */
class DataReader {
 public:
  explicit DataReader(const LayerParameter& param);
  ~DataReader();

  inline BlockingQueue<Datum*>& free() const {
    return queue_pair_->free_;
  }
  inline BlockingQueue<Datum*>& full() const {
    return queue_pair_->full_;
  }

 protected:
  // Queue pairs are shared between a body and its readers
  class QueuePair {
   public:
    explicit QueuePair(int size);
    ~QueuePair();

    BlockingQueue<Datum*> free_;
    BlockingQueue<Datum*> full_;

    DISABLE_COPY_AND_ASSIGN(QueuePair);
  };

  // A single body is created per source
  class Body : public InternalThread {
   public:
    explicit Body(const LayerParameter& param);
    virtual ~Body();

   protected:
    void InternalThreadEntry();
    void read_one(db::Cursor* cursor, QueuePair* qp);

    const LayerParameter param_;
    BlockingQueue<shared_ptr<QueuePair> > new_queue_pairs_;

    friend class DataReader;

    DISABLE_COPY_AND_ASSIGN(Body);
  };

  // A source is uniquely identified by its layer name + path, in case
  // the same database is read from two different locations in the net.
  static inline string source_key(const LayerParameter& param) {
    return param.name() + ":" + param.data_param().source();
  }

  const shared_ptr<QueuePair> queue_pair_;
  shared_ptr<Body> body_;

  static map<const string, boost::weak_ptr<DataReader::Body> > bodies_;

  DISABLE_COPY_AND_ASSIGN(DataReader);
};

}  // namespace caffe

#endif  // CAFFE_DATA_READER_HPP_
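free() and full() above are the entire consumer-facing surface: the reading thread pops recycled Datum objects from free_, fills them from the database, and pushes them to full_; a consumer does the reverse. A minimal sketch of a consumer loop, assuming a hypothetical helper function (Caffe's actual data layers add data transformation and their own prefetch threads on top of this):

#include "caffe/data_reader.hpp"

namespace caffe {

// Hypothetical consumer, for illustration only (not part of this commit):
// pops deserialized Datum objects from the reader's full queue and recycles
// them through the free queue so the reading thread can refill them.
void ConsumeOneBatch(DataReader& reader, int batch_size) {
  for (int i = 0; i < batch_size; ++i) {
    // Blocks until the reading thread has pushed a deserialized item.
    Datum* datum = reader.full().pop();
    // ... copy *datum into the layer's prefetch buffer here ...
    // Return the slot so the reader can reuse the allocation.
    reader.free().push(datum);
  }
}

}  // namespace caffe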
New file: data_reader.cpp
@@ -0,0 +1,121 @@
#include <boost/thread.hpp>
#include <map>
#include <string>
#include <vector>

#include "caffe/common.hpp"
#include "caffe/data_layers.hpp"
#include "caffe/data_reader.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

using boost::weak_ptr;

map<const string, weak_ptr<DataReader::Body> > DataReader::bodies_;
static boost::mutex bodies_mutex_;

DataReader::DataReader(const LayerParameter& param)
    : queue_pair_(new QueuePair(  //
        param.data_param().prefetch() * param.data_param().batch_size())) {
  // Get or create a body
  boost::mutex::scoped_lock lock(bodies_mutex_);
  string key = source_key(param);
  weak_ptr<Body>& weak = bodies_[key];
  body_ = weak.lock();
  if (!body_) {
    body_.reset(new Body(param));
    bodies_[key] = weak_ptr<Body>(body_);
  }
  body_->new_queue_pairs_.push(queue_pair_);
}

DataReader::~DataReader() {
  string key = source_key(body_->param_);
  body_.reset();
  boost::mutex::scoped_lock lock(bodies_mutex_);
  if (bodies_[key].expired()) {
    bodies_.erase(key);
  }
}

//

DataReader::QueuePair::QueuePair(int size) {
  // Initialize the free queue with requested number of datums
  for (int i = 0; i < size; ++i) {
    free_.push(new Datum());
  }
}

DataReader::QueuePair::~QueuePair() {
  Datum* datum;
  while (free_.try_pop(&datum)) {
    delete datum;
  }
  while (full_.try_pop(&datum)) {
    delete datum;
  }
}

//

DataReader::Body::Body(const LayerParameter& param)
    : param_(param),
      new_queue_pairs_() {
  StartInternalThread();
}

DataReader::Body::~Body() {
  StopInternalThread();
}

void DataReader::Body::InternalThreadEntry() {
  shared_ptr<db::DB> db(db::GetDB(param_.data_param().backend()));
  db->Open(param_.data_param().source(), db::READ);
  shared_ptr<db::Cursor> cursor(db->NewCursor());
  vector<shared_ptr<QueuePair> > qps;
  try {
    // int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1;
    // TODO single solver until multi-gpu merge
    int solver_count = 1;

    // To ensure deterministic runs, only start running once all solvers
    // are ready. But solvers need to peek on one item during initialization,
    // so read one item, then wait for the next solver.
    for (int i = 0; i < solver_count; ++i) {
      shared_ptr<QueuePair> qp(new_queue_pairs_.pop());
      read_one(cursor.get(), qp.get());
      qps.push_back(qp);
    }
    // Main loop
    while (!must_stop()) {
      for (int i = 0; i < solver_count; ++i) {
        read_one(cursor.get(), qps[i].get());
      }
      // Check no additional readers have been created. This can happen if
      // more than one net is trained at a time per process, whether single
      // or multi solver. It might also happen if two data layers have same
      // name and same source.
      CHECK_EQ(new_queue_pairs_.size(), 0);
    }
  } catch (boost::thread_interrupted&) {
    // Interrupted exception is expected on shutdown
  }
}

void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) {
  Datum* datum = qp->free_.pop();
  // TODO deserialize in-place instead of copy?
  datum->ParseFromString(cursor->value());
  qp->full_.push(datum);

  // go to the next iter
  cursor->Next();
  if (!cursor->valid()) {
    DLOG(INFO) << "Restarting data prefetching from start.";
    cursor->SeekToFirst();
  }
}

}  // namespace caffe
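The constructor/destructor pair above is a small get-or-create registry: bodies_ holds weak pointers, so a Body lives exactly as long as at least one DataReader for its source does. A standalone sketch of the same idiom, using std::shared_ptr/std::weak_ptr for brevity (the commit itself uses the boost equivalents and guards the map with a mutex):

#include <cassert>
#include <map>
#include <memory>
#include <string>

// Standalone illustration (not from the commit) of the weak_ptr registry
// idiom behind DataReader::bodies_.
struct Body { /* one reading thread per source would live here */ };

static std::map<std::string, std::weak_ptr<Body> > registry;

std::shared_ptr<Body> GetOrCreate(const std::string& key) {
  std::weak_ptr<Body>& weak = registry[key];
  std::shared_ptr<Body> body = weak.lock();  // non-null if one already exists
  if (!body) {
    body.reset(new Body());
    weak = body;  // register without keeping the Body alive ourselves
  }
  return body;
}

int main() {
  std::shared_ptr<Body> a = GetOrCreate("data:/path/to/db");
  std::shared_ptr<Body> b = GetOrCreate("data:/path/to/db");
  assert(a == b);    // same layer name + source -> shared Body
  a.reset();
  b.reset();         // last owner gone...
  assert(registry["data:/path/to/db"].expired());  // ...so the Body is dead
  return 0;
}

DataReader's destructor adds one more step: once the entry has expired it is erased from the map, so the registry does not accumulate dead keys across repeatedly constructed nets.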
Review comment: The variable/class names in this file and the associated comments are pretty uninformative. I can't think of a better name than QueuePair (though it too is vague). However, 'Body' could mean many different things; 'DataSource' would be better, and something even more descriptive would be ideal if you can think of one.