Skip to content

Commit

Permalink
Creating WorkerSolver instead of Solver instances
Browse files: browse the repository at this point in the history
  • Loading branch information
cypof committed Jun 29, 2015
1 parent b554119 commit 9b10508
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
35 changes: 25 additions & 10 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace caffe {
/**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ComputeUpdateValue to compute a parameter update
* Requires implementation of ApplyUpdate to compute a parameter update
* given the current state of the Net parameters.
*/
template <typename Dtype>
Expand Down Expand Up @@ -55,8 +55,8 @@ class Solver {
}

protected:
// Get and apply the update value for the current iteration.
virtual void Iteration() {}
// Make and apply the update value for the current iteration.
virtual void ApplyUpdate() = 0;
// The Solver::Snapshot function implements the basic snapshotting utility
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
Expand All @@ -65,12 +65,8 @@ class Solver {
// The test routine
void TestAll();
void Test(const int test_net_id = 0);
virtual void SnapshotSolverState(SolverState* state) {
CHECK(false) << "Should be overriden";
}
virtual void RestoreSolverState(const SolverState& state) {
CHECK(false) << "Should be overriden";
}
virtual void SnapshotSolverState(SolverState* state) = 0;
virtual void RestoreSolverState(const SolverState& state) = 0;
void DisplayOutputBlobs(const int net_id);

SolverParameter param_;
Expand All @@ -86,6 +82,25 @@ class Solver {
DISABLE_COPY_AND_ASSIGN(Solver);
};

/**
 * @brief Solver variant used for the worker side of multi-GPU training.
 *
 * A worker only computes gradients: its update step is deliberately empty,
 * and it refuses any request to snapshot or restore solver state.
 */
template <typename Dtype>
class WorkerSolver : public Solver<Dtype> {
 public:
  // Hands the solver configuration to the Solver base class unchanged.
  explicit WorkerSolver(const SolverParameter& param)
      : Solver<Dtype>(param) {}

 protected:
  // Intentionally a no-op: a worker does not apply parameter updates itself.
  virtual void ApplyUpdate() {}

  // Snapshotting solver state is not supported on a worker; any call trips
  // the CHECK below.
  virtual void SnapshotSolverState(SolverState* state) {
    CHECK(false) << "Should not be called on worker";
  }

  // Restoring solver state is not supported on a worker; any call trips
  // the CHECK below.
  virtual void RestoreSolverState(const SolverState& state) {
    CHECK(false) << "Should not be called on worker";
  }
};

/**
* @brief Optimizes the parameters of a Net using
Expand All @@ -104,7 +119,7 @@ class SGDSolver : public Solver<Dtype> {
protected:
void PreSolve();
Dtype GetLearningRate();
virtual void Iteration();
virtual void ApplyUpdate();
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
virtual void ClipGradients();
Expand Down
2 changes: 1 addition & 1 deletion src/caffe/parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ P2PSync<Dtype>::P2PSync(shared_ptr<Solver<Dtype> > root_solver,
solver_ = root_solver;
} else {
Caffe::set_root_solver(false);
solver_.reset(new Solver<Dtype>(param));
solver_.reset(new WorkerSolver<Dtype>(param));
Caffe::set_root_solver(true);
}
this->configure(solver_.get());
Expand Down
4 changes: 2 additions & 2 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ void Solver<Dtype>::Step(int iters) {
callbacks_[i]->on_gradients_ready(&timer, &timing);
}
timer.Start();
Iteration();
ApplyUpdate();
timing << " apply: " << timer.MilliSeconds();

#ifdef BENCHMARK_SOLVER
Expand Down Expand Up @@ -502,7 +502,7 @@ void SGDSolver<Dtype>::ClipGradients() {
}

template <typename Dtype>
void SGDSolver<Dtype>::Iteration() {
void SGDSolver<Dtype>::ApplyUpdate() {
CHECK(Caffe::root_solver());
Dtype rate = GetLearningRate();
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
Expand Down

0 comments on commit 9b10508

Please sign in to comment.