From ed9e37369c69dd76078e8285bc33d6b04ba60e9f Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Tue, 15 Sep 2015 15:51:28 +0800 Subject: [PATCH] SINGA-57 Improve Distributed Hogwild The ClusterProto::sync_freq field controls the frequency of sync between server groups. After updating of Param (slice), the server checks the num of updates since last sync. It also checks the num of pending syncs (i.e., requests haven't received reponses) to avoid sending too many msgs to stopped servers (the msgs would be occupy the memory of the sending buffer) The server respones to every sync requests with the latest Param values. Note: current does not support (there is bug) multiple worker groups in one process for the distributed hogwild framework. We recommend to replace this cluster topology with in-memory hogwild, i.e., launching one worker group with multiple workers and one server group. --- examples/cifar10/job.conf | 2 +- include/trainer/server.h | 62 ++++++++++---- include/trainer/trainer.h | 10 --- include/utils/cluster.h | 16 ++-- src/proto/common.proto | 1 - src/proto/job.proto | 7 +- src/trainer/server.cc | 167 +++++++++++++++++--------------------- src/trainer/trainer.cc | 124 +++++++--------------------- src/utils/cluster.cc | 21 +++++ 9 files changed, 190 insertions(+), 220 deletions(-) diff --git a/examples/cifar10/job.conf b/examples/cifar10/job.conf index b36c45afa1..343d9698cf 100644 --- a/examples/cifar10/job.conf +++ b/examples/cifar10/job.conf @@ -2,7 +2,7 @@ name: "cifar10-convnet" train_steps: 1000 test_steps: 100 test_freq:300 -disp_freq:30 +disp_freq: 30 train_one_batch { alg: kBP } diff --git a/include/trainer/server.h b/include/trainer/server.h index 8cc37c5fa7..869d10abfc 100644 --- a/include/trainer/server.h +++ b/include/trainer/server.h @@ -23,8 +23,8 @@ class Server{ Server(int thread_id, int group_id, int server_id); virtual ~Server(); void Setup(const UpdaterProto& proto, - std::unordered_map* shard, - const std::vector& slice2group); + const std::vector& slice2group, + const std::vector& slice2server); void Run(); const int grp_id() const { return grp_id_; @@ -38,13 +38,30 @@ class Server{ /** * Process GET request. * - * @return the orignal message or response message + * @return the orignal message or a response message which contains the values + * of the Param with the request version. */ virtual Msg* HandleGet(Msg** msg); /** * Process Update request. * + * It waits until received the gradients from all workers from the same worker + * group. After updating, it responses to each sender with the new Param + * values. It may generate a sync message to the server group that maintains + * the global version of the updated Param (slice). + * + * Note: there is no counter for each worker group on the number of received + * update requests. Hence it is possible that the server would conduct the + * update when it receives x requests from group a and y requests from group + * b where x + y = group size. To avoid this problem, we can + * 1. maintain request list for each group for each Param at the server side + * 2. do not span a worker group among multiple nodes. then the updates from + * the same group would be locally aggregated on the worker node. And the + * server would conduct the update immediately after receiving the aggregated + * request. + * 3. launch only one worker group. + * * @return the orignal message or response message */ const std::vector HandleUpdate(Msg **msg); @@ -52,30 +69,47 @@ class Server{ /** * Process PUT request. * - * @return the original message or response message. If we don't want need to + * @return the original message or response message. If we don't want to * acknowledge the put request, then return nullptr. */ virtual Msg* HandlePut(Msg **msg); /** - * TODO Process SYNC request. - */ + * Handle sync request from other server groups. + * + * It adds updates of Param (slice) from other server groups directly to + * local Param (slice). Currently, each Param (slice) has a master group, + * i.e., slice2group_[sliceid], which would receive such requests from all + * other server groups for the Param object. + * + * @param msg request msg containing the parameter updates + * @return response msg that contains the fresh parameter values. + */ virtual Msg* HandleSyncRequest(Msg** msg); /** - * Generate sync message which sends local mastered Param slice to other - * server groups - * @param param slice to be sync with others - * @return sync messages + * Handle sync response. + * + * The response msg includes the latest values of a Param object, for which + * this server sent the sync request to the master/maintainer group. + * The local Param values are replaced with the addition result of local + * udpates since the sync request was sent and the received Param values. + * + * @param response message */ - const std::vector GenSyncMsgs(Param* param); + void HandleSyncResponse(Msg** msg); protected: int thread_id_,grp_id_, id_; Updater* updater_; - std::unordered_map *shard_; - std::vector slice2group_; - std::unordered_map>> last_data_; + //!< map from slice ID to slice and deleted in the destructor + std::unordered_map shard_; + std::vector slice2group_, slice2server_; + //!< num of updates from last sync with master server group for a param/slice + std::vector nUpdates_; + //!< num of sync requests that have not been responded + std::vector nPendingSync_; + std::vector> last_sync_; std::unordered_map> buffer_requests_; }; } /* Server */ diff --git a/include/trainer/trainer.h b/include/trainer/trainer.h index 8be526941a..ed50705adf 100644 --- a/include/trainer/trainer.h +++ b/include/trainer/trainer.h @@ -79,14 +79,6 @@ class Trainer{ const vector& servers); void Run(const vector& workers, const vector& servers); - /** - * Generate msg to trigger synchronization with other server groups. - * - * @param server the local server index whom the message is sent to - * @param servers all local servers - * @return sync msg - */ - Msg* GenSyncReminderMsg(int server, const vector& servers); /** * Display metrics to log (standard output) */ @@ -143,8 +135,6 @@ class Trainer{ int procs_id_; Router *router_; std::unordered_map worker_shard_; - //!< map from slice ID to slice, used by servers and deleted in the destructor - std::unordered_map server_shard_; //!< map from slice to the server that updates it vector slice2server_; }; diff --git a/include/utils/cluster.h b/include/utils/cluster.h index be0e0dedef..73474af3f5 100644 --- a/include/utils/cluster.h +++ b/include/utils/cluster.h @@ -90,12 +90,8 @@ class Cluster { const int worker_timeout() const { return cluster_.worker_timeout(); } const int server_timeout() const { return cluster_.server_timeout(); } */ - inline bool server_update() const { return cluster_.server_update(); } inline bool share_memory() const { return cluster_.share_memory(); } - /** - * bandwidth Bytes/s - */ - inline int bandwidth() const { return cluster_.bandwidth(); } + inline int sync_freq() const { return cluster_.sync_freq(); } inline int poll_time() const { return cluster_.poll_time(); } ClusterRuntime* runtime() const { return cluster_rt_; } @@ -106,6 +102,16 @@ class Cluster { return procs_ids_.at(Hash(group_id, id, flag)); } inline std::string hostip() const { return hostip_; } + + /** + * @param pid, processs ID + * @param group_size, num of executors in a group + * @param procs_size, num of executors in a procs + * + * @return a vector with 4 integers: + * [group start, group end), [start executor, end executor) + */ + const std::vector ExecutorRng(int pid, int group_size, int procs_size); /** * Register this process. * diff --git a/src/proto/common.proto b/src/proto/common.proto index 671d6ad6a2..e166299c22 100644 --- a/src/proto/common.proto +++ b/src/proto/common.proto @@ -13,7 +13,6 @@ enum MsgType { kRUpdate = 9; kConnect = 10; kMetric = 11; - kSyncReminder = 12; }; enum EntityType { diff --git a/src/proto/job.proto b/src/proto/job.proto index 7861eaef90..80998e1211 100644 --- a/src/proto/job.proto +++ b/src/proto/job.proto @@ -129,15 +129,14 @@ message ClusterProto { // servers and workers in different processes? optional bool server_worker_separate = 20 [default = false]; + // sync frequency between server groups + optional int32 sync_freq = 21 [default = 1]; + // port number used by ZeroMQ optional int32 start_port = 60 [default = 6723]; - // conduct updates at server side; otherwise do it at worker side - optional bool server_update = 61 [default = true]; // share memory space between worker groups in one procs optional bool share_memory = 62 [default = true]; - // bandwidth of ethernet, Bytes per second, default is 1 Gbps - optional int32 bandwidth = 80 [default = 134217728]; // poll time in milliseconds optional int32 poll_time = 81 [default = 100]; } diff --git a/src/trainer/server.cc b/src/trainer/server.cc index b4c386ffc4..601a837904 100644 --- a/src/trainer/server.cc +++ b/src/trainer/server.cc @@ -18,15 +18,22 @@ Server::Server(int thread_id,int group_id, int server_id): } void Server::Setup(const UpdaterProto& proto, - std::unordered_map* shard, - const vector& slice2group) { + const vector& slice2group, + const vector& slice2server) { updater_ = Updater::Create(proto); - shard_ = shard; slice2group_ = slice2group; + slice2server_ = slice2server; + nUpdates_.resize(slice2group_.size(), 0); + nPendingSync_.resize(slice2group_.size(), 0); + last_sync_.resize(slice2group_.size()); } Server::~Server() { delete updater_; + // free Params (i.e., slices) in server shard + for (auto entry : shard_) + for (auto param : entry.second->shares) + delete param; } void Stop(void * running) { @@ -35,6 +42,7 @@ void Stop(void * running) { void Server::Run() { LOG(ERROR) << "Server (group = " << grp_id_ <<", id = " << id_ << ") start"; + auto dealer = new Dealer(2*thread_id_); CHECK(dealer->Connect(kInprocRouterEndpoint)); Msg* ping = new Msg(Addr(grp_id_, id_, kServer), Addr(-1, -1, kStub)); @@ -44,13 +52,10 @@ void Server::Run() { auto cluster = Cluster::Get(); bool running = true; CHECK(cluster->runtime()->WatchSGroup(grp_id_, id_, Stop, &running)); - - int nserver_grps = cluster->nserver_groups(); - vector master_params; - size_t syncEntry=0; Poller poll(dealer); // start recv loop and process requests while (running) { + // must use poller here; otherwise Receive() gets stuck after workers stop. auto *sock = poll.Wait(cluster->poll_time()); if (poll.Terminated()) { LOG(ERROR) << "Connection broken!"; @@ -58,34 +63,18 @@ void Server::Run() { } else if (sock == nullptr) { continue; } - Msg* msg=dealer->Receive(); - if (msg==nullptr) break; - Msg* response=nullptr; - int type=msg->type(); + Msg* msg = dealer->Receive(); + if (msg == nullptr) break; // interrupted + Msg* response = nullptr; + int type = msg->type(); int slice_id = SliceID(msg->trgt_val()); if (type == kPut) { response = HandlePut(&msg); - if(slice2group_[slice_id] == grp_id_) - master_params.push_back(shard_->at(slice_id)->shares.at(0)); } else { - if (shard_->find(slice_id) == shard_->end()) { - // delay the processing by re-queue the msg. + if (shard_.find(slice_id) == shard_.end()) { + // delay the processing by re-queue the msg. May sleep for a while? response = msg; - } else if (type == kSyncReminder) { - DeleteMsg(&msg); - if(syncEntry >= master_params.size()) - continue; - auto param = master_params.at(syncEntry); - // control the frequency of synchronization - // currently sync is triggerred only when the slice is updated - // by local worker or other workers for at least nserver_groups times. - // TODO may optimize the trigger condition. - if (abs(param->local_version() - param->version()) >= nserver_grps) { - for (auto msg : GenSyncMsgs(param)) - dealer->Send(&msg); - syncEntry = (syncEntry+1) % master_params.size(); - } - } else { + } else { switch (type) { case kGet: response = HandleGet(&msg); @@ -97,6 +86,9 @@ void Server::Run() { case kSyncRequest: response = HandleSyncRequest(&msg); break; + case kSyncResponse: + HandleSyncResponse(&msg); + break; default: LOG(ERROR)<<"Unknown message type "< Server::GenSyncMsgs(Param* param) { - vector ret; - // TODO replace the argument (0,0) to sync a chunk instead of a slice - auto msg = param->GenSyncMsg(0, 0); - auto cluster = Cluster::Get(); - for (int i = 0; i < cluster->nserver_groups(); i++) { - if (i != grp_id_) { - Msg* tmp = msg; - if (i < cluster->nserver_groups() - 1) - tmp = new Msg(*msg); - // assume only one server per group, TODO generalize it - tmp->set_dst(Addr(i, 0, kServer)); - tmp->set_src(Addr(grp_id_, id_, kServer)); - ret.push_back(tmp); - param->set_version(param->local_version()); - //LOG(ERROR)<<"sync slice="<id()<<" to procs "<trgt_version(); int slice_id = SliceID((*msg)->trgt_val()); - if (shard_->find(slice_id) != shard_->end()) + if (shard_.find(slice_id) != shard_.end()) LOG(FATAL) << "Param (" << slice_id << ") is put more than once"; // TODO(wangwei) replace hard coded param type 0 @@ -152,17 +123,15 @@ Msg* Server::HandlePut(Msg **msg) { if ((*msg)->NextFrame()) (*msg)->ParseFormatFrame("i", &num_shares); DeleteMsg(msg); - (*shard_)[slice_id] = new ParamEntry(num_shares, param); + shard_[slice_id] = new ParamEntry(num_shares, param); // must set version after HandlePutMsg which allocates the memory param->set_version(version); param->set_local_version(version); param->set_id(slice_id); - //LOG(ERROR)<<"put norm "<data().asum_data()<<", "<nserver_groups() > 1 && slice2group_[slice_id] != grp_id_) { - last_data_[slice_id] = std::make_shared>(); - last_data_[slice_id]->ReshapeLike(param->data()); - last_data_[slice_id]->CopyFrom(param->data()); + if (slice2group_[slice_id] != grp_id_) { + last_sync_[slice_id].ReshapeLike(param->data()); + last_sync_[slice_id].CopyFrom(param->data()); } LOG(INFO)<<"server (group = " << grp_id_ << ", id = " << id_ <<") put slice=" << slice_id << " size=" << param->size(); @@ -171,7 +140,7 @@ Msg* Server::HandlePut(Msg **msg) { Msg* Server::HandleGet(Msg **msg) { int val = (*msg)->trgt_val(); - auto param = shard_->at(SliceID(val))->shares.at(0); + auto param = shard_.at(SliceID(val))->shares.at(0); // re-queue the request if the param is not updated to the required version if(param->version()<(*msg)->trgt_version()) return *msg; @@ -186,15 +155,14 @@ Msg* Server::HandleGet(Msg **msg) { const vector Server::HandleUpdate(Msg **msg) { vector ret; int sliceid = SliceID((*msg)->trgt_val()); - auto entry = shard_->at(sliceid); + auto entry = shard_.at(sliceid); buffer_requests_[sliceid].push_back(*msg); int num_update; (*msg)->LastFrame(); (*msg)->ParseFormatFrame("i", &num_update); (*msg)->FirstFrame(); entry->num_update += num_update; - // LOG(ERROR) << "update "<src_second() - // << ", " << num_update << " total " << entry->num_total; + // LOG(ERROR) << "update "<< sliceid << " from " << AddrGrp((*msg)->src()) << ", " << num_update << " total " << entry->num_total; // do update until recv gradients from all shares of this param/slice if (entry->num_update >= entry->num_total) { CHECK_EQ(entry->num_update, entry->num_total); @@ -211,6 +179,26 @@ const vector Server::HandleUpdate(Msg **msg) { ret.push_back(response); } entry->num_update = 0; + nUpdates_[sliceid]++; + // sync with master group after at least sync_freq local updates + // the last check is to avoid sending msg to stopped servers + if (slice2group_[sliceid] != grp_id_ + && nUpdates_[sliceid] >= Cluster::Get()->sync_freq() + && nPendingSync_[sliceid] <= Cluster::Get()->sync_freq()) { + auto shape = Shape1(param->size()); + Tensor tmp(last_sync_[sliceid].mutable_cpu_data(), shape); + Tensor cur(param->mutable_cpu_data(), shape); + tmp = cur - tmp; + int addr = Addr(slice2group_[sliceid], slice2server_[sliceid], kServer); + Msg* sync = new Msg(Addr(grp_id_, id_, kServer), addr); + sync->set_type(kSyncRequest); + sync->set_trgt((*msg)->trgt_val(), param->local_version()); + sync->AddFrame(tmp.dptr, param->size() * sizeof(float)); + Copy(tmp, cur); + ret.push_back(sync); + nUpdates_[sliceid] = 0; + nPendingSync_[sliceid]++; + } } *msg = nullptr; return ret; @@ -219,38 +207,33 @@ const vector Server::HandleUpdate(Msg **msg) { Msg* Server::HandleSyncRequest(Msg **msg) { Msg* msgg = *msg; int slice = SliceID(msgg->trgt_val()); - auto param = shard_->at(slice)->shares.at(0); - Msg* response=nullptr; - auto shape=Shape1(param->size()); + auto param = shard_.at(slice)->shares.at(0); + auto shape = Shape1(param->size()); CHECK_EQ(msgg->FrameSize(), param->size()*sizeof(float)); - Tensor tmp(static_cast(msgg->FrameData()), shape); + Tensor inc(static_cast(msgg->FrameData()), shape); Tensor cur(param->mutable_cpu_data(), shape); - //LOG(ERROR)<<"Recv sync for "<id(); - if (slice2group_[slice] == grp_id_) { - // recv sync msg on slice I am mastering - cur+=tmp; - param->set_local_version(param->local_version()+1); - } else { // recv sync msg on slice mastered by others - TensorContainer diff(shape); - Tensor prev(last_data_[param->id()]->mutable_cpu_data(), shape); - diff=cur-prev; - msgg->NextFrame(); - int bandwidth; - msgg->ParseFormatFrame("i", &bandwidth); - if (bandwidth > 0) { - // send back my updates to the server group mastering this param - response=new Msg(msgg->dst(), msgg->src()); - response->set_type(kSyncRequest); - response->set_trgt(param->id(), param->version()); - response->AddFrame(diff.dptr, param->size()*sizeof(float)); - prev=diff+tmp; - Copy(cur, prev); - } else { // no bandwidth, aggregate my updates for next sync - Copy(prev, tmp); - cur=tmp+diff; - } - } + // recv sync msg on the slice I am maintaining + cur += inc; + msgg->SwapAddr(); + msgg->set_type(kSyncResponse); + // copy the fresh param value into the response msg + Copy(inc, cur); + return msgg; +} + +// recv sync msg on slice mastered by others +void Server::HandleSyncResponse(Msg **msg) { + Msg* msgg = *msg; + int slice = SliceID(msgg->trgt_val()); + auto param = shard_.at(slice)->shares.at(0); + auto shape=Shape1(param->size()); + Tensor prev(last_sync_[param->id()].mutable_cpu_data(), shape); + Tensor cur(param->mutable_cpu_data(), shape); + Tensor master(static_cast(msgg->FrameData()), shape); + cur += master - prev; // cur = master + (cur - prev); + Copy(prev, cur); DeleteMsg(msg); - return response; + nPendingSync_[slice]--; } + } /* singa */ diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc index b6dc729f4e..b02ef3ec1c 100644 --- a/src/trainer/trainer.cc +++ b/src/trainer/trainer.cc @@ -21,10 +21,6 @@ using std::make_shared; /***********************Trainer****************************/ Trainer::~Trainer() { - // free Params (i.e., slices) in server shard - for (auto entry : server_shard_) - for (auto param : entry.second->shares) - delete param; delete router_; } @@ -120,10 +116,11 @@ void Trainer::SetupWorkerServer( // partition among server groups, each group maintains one sub-set for sync auto slice2group = PartitionSlices(cluster->nserver_groups(), slices); - for (auto server : servers) - server->Setup(job_conf.updater(), &server_shard_, slice2group); // partition within one server group, each server updates for one sub-set slice2server_ = PartitionSlices(cluster->nservers_per_group(), slices); + + for (auto server : servers) + server->Setup(job_conf.updater(), slice2group, slice2server_); } vector Trainer::CreateServers(int nthreads, const JobProto& job) { @@ -132,46 +129,33 @@ vector Trainer::CreateServers(int nthreads, const JobProto& job) { if (!cluster->has_server()) return servers; - int pid = cluster->procs_id(); + int server_procs = cluster->procs_id(); // if true, server procs (logical) id starts after worker procs if (cluster->server_worker_separate()) - pid -= cluster->nworker_procs(); - int procs_size = cluster->nservers_per_procs(); - int grp_size = cluster->nservers_per_group(); - int gid = pid * procs_size / grp_size; - int start = pid * procs_size % grp_size; - int end = start + procs_size; - for (int sid = start; sid < end; sid++) { - auto server = new Server(nthreads++, gid, sid); - servers.push_back(server); + server_procs -= cluster->nworker_procs(); + const vector rng = cluster->ExecutorRng(server_procs, + cluster->nservers_per_group(), + cluster->nservers_per_procs()); + int gstart = rng[0], gend = rng[1], start = rng[2], end = rng[3]; + for (int gid = gstart; gid < gend; gid++) { + for (int sid = start; sid < end; sid++) { + auto server = new Server(nthreads++, gid, sid); + servers.push_back(server); + } } return servers; } + vector Trainer::CreateWorkers(int nthreads, const JobProto& job) { auto cluster=Cluster::Get(); vector workers; if(!cluster->has_worker()) return workers; - int pid = cluster->procs_id(); - int grp_size = cluster->nworkers_per_group(); - int procs_size = cluster->nworkers_per_procs(); - int gstart, gend, wstart, wend; - if (grp_size >= procs_size) { - // all workers in this procs are from the same group - gstart = pid * procs_size / grp_size; - gend = gstart + 1; - wstart = pid * procs_size % grp_size; - wend = wstart + procs_size; - } else { - // there are multiple (complete) groups in this procs. - CHECK_EQ(procs_size % grp_size, 0); - int groups_per_procs = procs_size / grp_size; - gstart = pid * groups_per_procs; - gend = (pid+1) * groups_per_procs; - wstart = 0; - wend = grp_size; - } + const vector rng = cluster->ExecutorRng(cluster->procs_id(), + cluster->nworkers_per_group(), + cluster->nworkers_per_procs()); + int gstart = rng[0], gend = rng[1], wstart = rng[2], wend = rng[3]; for (int gid = gstart; gid < gend; gid++) { for (int wid = wstart; wid < wend; wid++) { auto *worker = Worker::Create(job); @@ -260,12 +244,6 @@ void Trainer::Start(bool resume, const SingaProto& singaConf, JobProto* job) { delete worker; } -inline int bandwidth(int bytes, system_clock::time_point start) { - auto now=system_clock::now(); - auto duration=duration_cast (now - start); - return static_cast(bytes*1000.f/duration.count()); -} - void Trainer::Run( const vector& workers, const vector& servers) { @@ -274,42 +252,20 @@ void Trainer::Run( procs_id_ = cluster->procs_id(); LOG(INFO) << "Stub in process " << procs_id_ << " starts"; - // for sync among server groups - auto start = std::chrono::system_clock::now(); - float trans_size = 0.f; // total size of msg transferred since start time - int sync_server_id = 0; - int max_bandwidth = cluster->bandwidth(); - int nserver_grps = cluster->nserver_groups(); - map inter_dealers; // for sending msg to other procs std::queue msg_queue; - Poller poll(router_); - bool stop=false; - while (!stop || !msg_queue.empty()) { + while (true) { + Msg* msg = nullptr; if (msg_queue.empty()) { - // if the poll time is large, then the poller may not expire - // if it is small, then many reminder messages will be sent which may - // slow done the process of other request. TODO tune it. - auto *sock = poll.Wait(cluster->poll_time()); - if (poll.Terminated()) { - LOG(ERROR) << "Connection broken!"; - exit(0); - } else if (sock == nullptr) { - if (nserver_grps > 1 && bandwidth(trans_size, start) < max_bandwidth) { - Msg* msg = GenSyncReminderMsg(sync_server_id, servers); - router_->Send(&msg) ; - sync_server_id = (sync_server_id + 1) % nservers; - } - continue; - } - Msg* msg = router_->Receive(); - msg_queue.push(msg); + msg = router_->Receive(); + } else { + msg = msg_queue.front(); + msg_queue.pop(); } - Msg* msg = msg_queue.front(); - msg_queue.pop(); int type = msg->type(), dst = msg->dst(), flag = AddrType(dst); if (flag == kStub && (AddrProc(dst) == procs_id_ || AddrGrp(dst) == -1)) { + // the following statements are ordered! if (type == kConnect) { DeleteMsg(&msg); } else if (type == kMetric) { @@ -320,28 +276,18 @@ void Trainer::Run( else if (src_flag == kWorkerParam) nworkers--; DeleteMsg(&msg); if (nworkers == 0 && nservers == 0) break; - } else if (nserver_grps > 0) { - HandleLocalMsg(&msg_queue, &msg); } else { - DeleteMsg(&msg); + HandleLocalMsg(&msg_queue, &msg); } } else { int dst_procs = AddrProc(dst); if (flag != kStub) dst_procs = cluster->ProcsIDOf(AddrGrp(dst), AddrID(dst), flag); if (dst_procs != procs_id_) { - if (bandwidth(trans_size, start) <= cluster->bandwidth()) { - start = std::chrono::system_clock::now(); - trans_size = 0; - } - trans_size += msg->size(); - if (inter_dealers.find(dst_procs) == inter_dealers.end()) inter_dealers[dst_procs] = CreateInterProcsDealer(dst_procs); inter_dealers[dst_procs]->Send(&msg); } else { - if (type == kSyncRequest) - msg->AddFormatFrame("i", max_bandwidth - bandwidth(trans_size, start)); router_->Send(&msg); } } @@ -351,14 +297,6 @@ void Trainer::Run( delete entry.second; } -Msg* Trainer::GenSyncReminderMsg(int server, const vector& servers ) { - Msg* msg = new Msg(); - msg->set_src(Addr(-1,-1, kStub)); - msg->set_dst(Addr(servers[server]->grp_id(), servers[server]->id(), kServer)); - msg->set_type(kSyncReminder); - return msg; -} - void Trainer::DisplayMetric(Msg** msg) { Msg* msgg = *msg; // only display metrics from the first group @@ -436,16 +374,16 @@ void Trainer::GenMsgs(int type, int version, ParamEntry* entry, for (int idx = 0 ; idx < param->num_slices(); idx++) { int slice_id =param->slice_start() + idx; int server = slice2server_[slice_id]; - int procs = Cluster::Get()->ProcsIDOf(dst_grp, server, kServer); + int dst_procs = Cluster::Get()->ProcsIDOf(dst_grp, server, kServer); Msg* new_msg = nullptr; if (type == kPut) { CHECK_GT(entry->num_total, 0); - new_msg = param->GenPutMsg(procs != procs_id_, idx); + new_msg = param->GenPutMsg(dst_procs != procs_id_, idx); new_msg->AddFormatFrame("i", entry->num_total); } else if (type == kGet) { - new_msg = param->GenGetMsg(procs != procs_id_, idx); + new_msg = param->GenGetMsg(dst_procs != procs_id_, idx); } else if (type == kUpdate) { - new_msg = param->GenUpdateMsg(procs != procs_id_, idx); + new_msg = param->GenUpdateMsg(dst_procs != procs_id_, idx); new_msg->AddFormatFrame("i", entry->num_local); } else { LOG(FATAL) << "Wrong type"; diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc index 9664064fab..a1716a128b 100644 --- a/src/utils/cluster.cc +++ b/src/utils/cluster.cc @@ -6,6 +6,7 @@ #include namespace singa { +using std::vector; Cluster* Cluster::Setup(int job, const SingaProto& singaConf, const ClusterProto& clusterConf) { @@ -71,6 +72,26 @@ void Cluster::SetupFolders(const ClusterProto &cluster) { mkdir(checkpoint_folder().c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH); } +const vector Cluster::ExecutorRng(int pid, int grp_size, int procs_size) { + int gstart, gend, start, end; + if (grp_size >= procs_size) { + // all workers in this procs are from the same group + gstart = pid * procs_size / grp_size; + gend = gstart + 1; + start = pid * procs_size % grp_size; + end = start + procs_size; + } else { + // there are multiple (complete) groups in this procs. + CHECK_EQ(procs_size % grp_size, 0); + int groups_per_procs = procs_size / grp_size; + gstart = pid * groups_per_procs; + gend = (pid+1) * groups_per_procs; + start = 0; + end = grp_size; + } + return vector{gstart, gend, start, end}; +} + int Cluster::Hash(int gid, int id, int flag) { int ret = -1; if (flag == kServer) {