
SINGA-57 Improve Distributed Hogwild
The ClusterProto::sync_freq field controls the frequency of synchronization
between server groups.
After updating a Param (slice), the server checks the number of updates
since the last sync. It also checks the number of pending syncs (i.e.,
requests that have not yet received responses) to avoid sending too many
messages to stopped servers (such messages would occupy the memory of the
sending buffer). The server responds to every sync request with the latest
Param values.

Note: the distributed Hogwild framework currently does not support (there
is a bug with) multiple worker groups in one process. We recommend
replacing this cluster topology with in-memory Hogwild, i.e., launching
one worker group with multiple workers and one server group.
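The sync decision described above can be sketched as follows (a minimal illustration only; `ShouldSync` and `kMaxPendingSync` are hypothetical names, not SINGA's actual API):

```cpp
#include <cassert>

// Hypothetical cap on outstanding sync requests; SINGA's actual limit
// on pending syncs may differ.
const int kMaxPendingSync = 4;

// Decide whether to send a sync message for a Param slice: enough local
// updates since the last sync, and not too many requests still waiting
// for a response from the (possibly stopped) remote server group.
bool ShouldSync(int n_updates, int n_pending, int sync_freq) {
  return n_updates >= sync_freq && n_pending < kMaxPendingSync;
}
```

With `sync_freq = 2`, the server syncs after every second update, but holds off once `kMaxPendingSync` requests are unanswered.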
nudles committed Sep 15, 2015
1 parent d5d817e commit ed9e373
Showing 9 changed files with 190 additions and 220 deletions.
2 changes: 1 addition & 1 deletion examples/cifar10/job.conf
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: "cifar10-convnet"
train_steps: 1000
test_steps: 100
test_freq:300
disp_freq:30
disp_freq: 30
train_one_batch {
alg: kBP
}
Expand Down
62 changes: 48 additions & 14 deletions include/trainer/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class Server{
Server(int thread_id, int group_id, int server_id);
virtual ~Server();
void Setup(const UpdaterProto& proto,
std::unordered_map<int, ParamEntry*>* shard,
const std::vector<int>& slice2group);
const std::vector<int>& slice2group,
const std::vector<int>& slice2server);
void Run();
const int grp_id() const {
return grp_id_;
Expand All @@ -38,44 +38,78 @@ class Server{
/**
* Process GET request.
*
@return the original message or response message
@return the original message or a response message which contains the values
of the Param with the requested version.
*/
virtual Msg* HandleGet(Msg** msg);

/**
* Process Update request.
*
It waits until it has received the gradients from all workers in the same
worker group. After updating, it responds to each sender with the new Param
values. It may generate a sync message to the server group that maintains
the global version of the updated Param (slice).
*
Note: there is no per-worker-group counter for the number of received
update requests. Hence it is possible that the server conducts the
update when it receives x requests from group a and y requests from group
b, where x + y = group size. To avoid this problem, we can:
1. maintain a request list per worker group for each Param at the server side;
2. not span a worker group across multiple nodes, so that updates from
the same group are aggregated locally on the worker node and the server
conducts the update immediately after receiving the aggregated request; or
3. launch only one worker group.
*
@return the original message or response message
*/
const std::vector<Msg*> HandleUpdate(Msg **msg);
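The race in the note above can be made concrete with a small sketch (`UpdateCounter` is a hypothetical stand-in for the server's aggregation logic, not SINGA's actual code): with group size 2, one request from group a plus one from group b already triggers the update.

```cpp
#include <cassert>

// Hypothetical per-Param counter illustrating the problem: it only tracks
// the total number of received update requests, not which worker group
// each request came from.
struct UpdateCounter {
  int n_received = 0;
  // Returns true when the server would run the update -- possibly
  // prematurely, mixing requests from different worker groups.
  bool Add(int group_size) {
    if (++n_received == group_size) {
      n_received = 0;
      return true;
    }
    return false;
  }
};
```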

/**
* Process PUT request.
*
* @return the original message or response message. If we don't want need to
* @return the original message or response message. If we don't want to
* acknowledge the put request, then return nullptr.
*/
virtual Msg* HandlePut(Msg **msg);

/**
* TODO Process SYNC request.
*/
* Handle sync request from other server groups.
*
* It adds updates of Param (slice) from other server groups directly to
* local Param (slice). Currently, each Param (slice) has a master group,
* i.e., slice2group_[sliceid], which would receive such requests from all
* other server groups for the Param object.
*
* @param msg request msg containing the parameter updates
* @return response msg that contains the fresh parameter values.
*/
virtual Msg* HandleSyncRequest(Msg** msg);

/**
* Generate sync message which sends local mastered Param slice to other
* server groups
* @param param slice to be sync with others
* @return sync messages
* Handle sync response.
*
The response msg includes the latest values of a Param object for which
this server sent a sync request to the master/maintainer group.
The local Param values are replaced with the received Param values plus
the local updates accumulated since the sync request was sent.

@param msg response message
*/
const std::vector<Msg*> GenSyncMsgs(Param* param);
void HandleSyncResponse(Msg** msg);

protected:
int thread_id_, grp_id_, id_;
Updater* updater_;
std::unordered_map<int, ParamEntry*> *shard_;
std::vector<int> slice2group_;
std::unordered_map<int, std::shared_ptr<Blob<float>>> last_data_;
//!< map from slice ID to slice and deleted in the destructor
std::unordered_map<int, ParamEntry*> shard_;
std::vector<int> slice2group_, slice2server_;
//!< num of updates from last sync with master server group for a param/slice
std::vector<int> nUpdates_;
//!< num of sync requests that have not been responded
std::vector<int> nPendingSync_;
std::vector<Blob<float>> last_sync_;
std::unordered_map<int, std::vector<Msg*>> buffer_requests_;
};
} /* Server */
Expand Down
10 changes: 0 additions & 10 deletions include/trainer/trainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,6 @@ class Trainer{
const vector<Server*>& servers);

void Run(const vector<Worker*>& workers, const vector<Server*>& servers);
/**
* Generate msg to trigger synchronization with other server groups.
*
* @param server the local server index whom the message is sent to
* @param servers all local servers
* @return sync msg
*/
Msg* GenSyncReminderMsg(int server, const vector<Server*>& servers);
/**
* Display metrics to log (standard output)
*/
Expand Down Expand Up @@ -143,8 +135,6 @@ class Trainer{
int procs_id_;
Router *router_;
std::unordered_map<int, ParamEntry*> worker_shard_;
//!< map from slice ID to slice, used by servers and deleted in the destructor
std::unordered_map<int, ParamEntry*> server_shard_;
//!< map from slice to the server that updates it
vector<int> slice2server_;
};
Expand Down
16 changes: 11 additions & 5 deletions include/utils/cluster.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,8 @@ class Cluster {
const int worker_timeout() const { return cluster_.worker_timeout(); }
const int server_timeout() const { return cluster_.server_timeout(); }
*/
inline bool server_update() const { return cluster_.server_update(); }
inline bool share_memory() const { return cluster_.share_memory(); }
/**
* bandwidth Bytes/s
*/
inline int bandwidth() const { return cluster_.bandwidth(); }
inline int sync_freq() const { return cluster_.sync_freq(); }
inline int poll_time() const { return cluster_.poll_time(); }
ClusterRuntime* runtime() const { return cluster_rt_; }

Expand All @@ -106,6 +102,16 @@ class Cluster {
return procs_ids_.at(Hash(group_id, id, flag));
}
inline std::string hostip() const { return hostip_; }

/**
* @param pid process ID
* @param group_size num of executors in a group
* @param procs_size num of executors in a process
*
* @return a vector with 4 integers:
* [group start, group end), [start executor, end executor)
*/
const std::vector<int> ExecutorRng(int pid, int group_size, int procs_size);
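Assuming executors are numbered consecutively, `procs_size` per process and `group_size` per group, the range computation might look like this sketch (`ExecutorRngSketch` is illustrative; the actual implementation may lay executors out differently):

```cpp
#include <vector>

// Illustrative sketch only: maps a process ID to the groups and the
// per-group executor offsets it hosts, under the consecutive-layout
// assumption stated above.
std::vector<int> ExecutorRngSketch(int pid, int group_size, int procs_size) {
  int start = pid * procs_size;           // first global executor in this process
  int end = start + procs_size;           // one past the last
  int gstart = start / group_size;        // first group touched by this process
  int gend = (end - 1) / group_size + 1;  // one past the last group
  int estart = start % group_size;        // offset within the first group
  int eend = end % group_size == 0 ? group_size : end % group_size;
  return {gstart, gend, estart, eend};
}
```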
/**
* Register this process.
*
Expand Down
1 change: 0 additions & 1 deletion src/proto/common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ enum MsgType {
kRUpdate = 9;
kConnect = 10;
kMetric = 11;
kSyncReminder = 12;
};

enum EntityType {
Expand Down
7 changes: 3 additions & 4 deletions src/proto/job.proto
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,14 @@ message ClusterProto {
// servers and workers in different processes?
optional bool server_worker_separate = 20 [default = false];

// sync frequency between server groups
optional int32 sync_freq = 21 [default = 1];

// port number used by ZeroMQ
optional int32 start_port = 60 [default = 6723];
// conduct updates at server side; otherwise do it at worker side
optional bool server_update = 61 [default = true];
// share memory space between worker groups in one procs
optional bool share_memory = 62 [default = true];

// bandwidth of ethernet, Bytes per second, default is 1 Gbps
optional int32 bandwidth = 80 [default = 134217728];
// poll time in milliseconds
optional int32 poll_time = 81 [default = 100];
}
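With the new field, the cluster section of a job configuration could enable periodic inter-group sync like so (the value 10 is illustrative; other cluster fields are omitted):

```
cluster {
  # sync with other server groups every 10 updates (default is 1)
  sync_freq: 10
}
```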
Expand Down
