-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Build a simple event loop for collective.
- Loading branch information
1 parent
0df1da2
commit b5ffa83
Showing
5 changed files
with
318 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
/** | ||
* Copyright 2023, XGBoost Contributors | ||
*/ | ||
#include "loop.h" | ||
|
||
#include <cinttypes> // for int32_t | ||
#include <queue> // for queue | ||
|
||
#include "rabit/internal/socket.h" // for PullHelper | ||
|
||
namespace xgboost::collective { | ||
/**
 * \brief Run every queued operation until the queue is empty, the loop is stopped, or
 *        an error occurs.
 *
 * Each round moves all pending ops into a local queue, registers their sockets with a
 * poll helper, polls once using `timeout_`, and then attempts one `Recv`/`Send` for
 * every op whose socket was reported ready. Ops that are not complete afterwards are
 * pushed back onto `queue_` for the next round.
 *
 * Process() calls this function while holding `mu_`.
 *
 * \return Success when the queue has been drained (or the loop was stopped); the
 *         failing result when polling or a socket call fails, or when an op carries an
 *         unknown code. On failure `stop_` is set.
 */
Result Loop::EmptyQueue() {
  // Inside a lambda body `__func__` names the closure's call operator, not this
  // function, so the timer label is captured once here and shared with the error path.
  char const* const timer_label = __func__;
  timer_.Start(timer_label);
  // Common failure path: mark the loop as stopped and close the timer opened above.
  auto error = [this, timer_label] {
    this->stop_ = true;
    timer_.Stop(timer_label);
  };

  while (!queue_.empty() && !stop_) {
    std::queue<Op> qcopy;
    rabit::utils::PollHelper poll;

    // Register every pending op with the poll helper and move it to the local queue.
    while (!queue_.empty()) {
      auto op = queue_.front();
      queue_.pop();

      switch (op.code) {
        case Op::kRead: {
          poll.WatchRead(*op.sock);
          break;
        }
        case Op::kWrite: {
          poll.WatchWrite(*op.sock);
          break;
        }
        default: {
          error();
          return Fail("Invalid socket operation.");
        }
      }
      qcopy.push(op);
    }

    // Poll, then work on the sockets that are ready.
    timer_.Start("poll");
    auto rc = poll.Poll(timeout_);
    timer_.Stop("poll");
    if (!rc.OK()) {
      error();
      return rc;
    }
    // We wouldn't be here if the queue was empty.
    CHECK(!qcopy.empty());

    while (!qcopy.empty() && !stop_) {
      auto op = qcopy.front();
      qcopy.pop();

      std::int32_t n_bytes_done{0};
      CHECK(op.sock->NonBlocking());

      switch (op.code) {
        case Op::kRead: {
          if (poll.CheckRead(*op.sock)) {
            n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
          }
          break;
        }
        case Op::kWrite: {
          if (poll.CheckWrite(*op.sock)) {
            n_bytes_done = op.sock->Send(op.ptr + op.off, op.n - op.off);
          }
          break;
        }
        default: {
          error();
          return Fail("Invalid socket operation.");
        }
      }

      if (n_bytes_done == -1) {
        if (!system::LastErrorWouldBlock()) {
          // Build the result before calling error() so the system error code is read
          // before any other call has a chance to overwrite it.
          auto failure = system::FailWithCode("Invalid socket output.");
          error();
          return failure;
        }
        // The call would have blocked: nothing was transferred. Without this reset the
        // -1 would be added to the unsigned offset below and wrap it around.
        n_bytes_done = 0;
      }
      // NOTE(review): a return value of 0 (e.g. the peer closing the connection during
      // a read) leaves the op unfinished and re-queued; confirm whether that case
      // should be reported as an error instead.
      op.off += static_cast<std::size_t>(n_bytes_done);
      CHECK_LE(op.off, op.n);

      if (op.off != op.n) {
        // Not yet finished, push back to the queue for the next round.
        queue_.push(op);
      }
    }
  }
  timer_.Stop(timer_label);
  return Success();
}
|
||
void Loop::Process() { | ||
// consumer | ||
while (true) { | ||
std::unique_lock lock{mu_}; | ||
cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; }); | ||
if (stop_) { | ||
break; | ||
} | ||
CHECK(!mu_.try_lock()); | ||
|
||
this->rc_ = this->EmptyQueue(); | ||
if (!rc_.OK()) { | ||
stop_ = true; | ||
cv_.notify_one(); | ||
break; | ||
} | ||
|
||
CHECK(queue_.empty()); | ||
CHECK(!mu_.try_lock()); | ||
cv_.notify_one(); | ||
} | ||
|
||
if (rc_.OK()) { | ||
CHECK(queue_.empty()); | ||
} | ||
} | ||
|
||
/**
 * \brief Ask the worker to stop and surface any exception it recorded.
 *
 * Sets `stop_` under the mutex, goes through Block() once, and rethrows the exception
 * stored in `curr_exce_` if there is one.
 *
 * \return Always Success() when no exception is rethrown.
 */
Result Loop::Stop() {
  {
    // Scope the lock so it is released before Block() acquires the mutex itself.
    std::unique_lock guard{mu_};
    stop_ = true;
  }

  // NOTE(review): Block() returns `rc_` by moving out of it and the two operands here
  // are not sequenced relative to each other, so what this check compares may depend on
  // the compiler -- confirm the intent.
  CHECK_EQ(this->Block().OK(), this->rc_.OK());

  if (curr_exce_ != nullptr) {
    // Propagate whatever was captured on the worker thread to the caller.
    std::rethrow_exception(curr_exce_);
  }
  return Success();
}
|
||
/**
 * \brief Create the loop and start its worker thread.
 *
 * \param timeout Stored in `timeout_` and passed to the poll call in EmptyQueue().
 *
 * The worker runs Process(). Anything thrown from it is captured instead of escaping
 * the thread: the first exception is kept in `curr_exce_`, `rc_` is set to a failure,
 * `stop_` is raised and every waiter on `cv_` is woken.
 */
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
  timer_.Init(__func__);
  worker_ = std::thread{[this] {
    try {
      this->Process();
    } catch (...) {
      // One handler is enough: `catch (...)` also receives std::exception and the
      // handling does not depend on the exception type.
      std::lock_guard<std::mutex> guard{mu_};
      if (!curr_exce_) {
        // Keep only the first failure.
        curr_exce_ = std::current_exception();
        rc_ = Fail("Exception was thrown");
      }
      stop_ = true;
      cv_.notify_all();
    }
  }};
}
} // namespace xgboost::collective |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
/** | ||
* Copyright 2023, XGBoost Contributors | ||
*/ | ||
#pragma once | ||
#include <chrono> // for seconds | ||
#include <condition_variable> // for condition_variable | ||
#include <cstdint> // for int8_t, int32_t | ||
#include <mutex> // for mutex, lock_guard, unique_lock | ||
#include <queue> // for queue | ||
#include <thread> // for thread | ||
#include <utility> // for move | ||
|
||
#include "../common/timer.h" // for Monitor | ||
#include "xgboost/collective/socket.h" // for TCPSocket | ||
|
||
namespace xgboost::collective { | ||
class Loop { | ||
public: | ||
struct Op { | ||
enum Code : std::int8_t { kRead = 0, kWrite = 1 } code; | ||
std::int32_t rank{-1}; | ||
std::int8_t* ptr{nullptr}; | ||
std::size_t n{0}; | ||
TCPSocket* sock{nullptr}; | ||
std::size_t off{0}; | ||
|
||
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off) | ||
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {} | ||
Op(Op const&) = default; | ||
Op& operator=(Op const&) = default; | ||
Op(Op&&) = default; | ||
Op& operator=(Op&&) = default; | ||
}; | ||
|
||
private: | ||
std::thread worker_; | ||
std::condition_variable cv_; | ||
std::mutex mu_; | ||
std::queue<Op> queue_; | ||
std::chrono::seconds timeout_; | ||
Result rc_; | ||
bool stop_{false}; | ||
std::exception_ptr curr_exce_{nullptr}; | ||
common::Monitor timer_; | ||
|
||
Result EmptyQueue(); | ||
void Process(); | ||
|
||
public: | ||
Result Stop(); | ||
|
||
void Submit(Op op) { | ||
// producer | ||
std::unique_lock lock{mu_}; | ||
queue_.push(op); | ||
lock.unlock(); | ||
cv_.notify_one(); | ||
} | ||
|
||
[[nodiscard]] Result Block() { | ||
{ | ||
std::unique_lock lock{mu_}; | ||
cv_.notify_all(); | ||
} | ||
std::unique_lock lock{mu_}; | ||
cv_.wait(lock, [this] { return this->queue_.empty() || stop_; }); | ||
return std::move(rc_); | ||
} | ||
|
||
explicit Loop(std::chrono::seconds timeout); | ||
|
||
~Loop() noexcept(false) { | ||
this->Stop(); | ||
|
||
if (worker_.joinable()) { | ||
worker_.join(); | ||
} | ||
} | ||
}; | ||
} // namespace xgboost::collective |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/** | ||
* Copyright 2023, XGBoost Contributors | ||
*/ | ||
#include <gtest/gtest.h> | ||
|
||
#include <chrono> // for seconds | ||
#include <cinttypes> // for int8_t | ||
#include <memory> // for make_shared | ||
#include <system_error> // make_error_code, errc | ||
|
||
#include "../../../src/collective/loop.h" | ||
|
||
namespace xgboost::collective { | ||
// A read submitted on a default-constructed socket is expected to end with a timed-out
// result once the loop's one-second timeout is used for polling.
TEST(Loop, Timeout) {
  system::SocketStartup();
  auto loop = std::make_shared<Loop>(std::chrono::seconds{1});

  TCPSocket sock;
  std::vector<std::int8_t> buffer(1);
  // Single-byte read starting at offset 0.
  loop->Submit(Loop::Op{Loop::Op::kRead, 0, buffer.data(), buffer.size(), &sock, 0});

  auto result = loop->Block();
  ASSERT_FALSE(result.OK());
  ASSERT_EQ(result.Code(), std::make_error_code(std::errc::timed_out));
  system::SocketFinalize();
}
|
||
// Sends one byte over a loopback connection through the loop and checks that the
// receiving side reads the same byte.
TEST(Loop, Op) {
  system::SocketStartup();

  // Listening socket bound to a host-chosen port.
  auto server = TCPSocket::Create(SockDomain::kV4);
  auto port = server.BindHost();
  server.Listen();

  // Client side of the connection, switched to non-blocking mode.
  TCPSocket sender;
  auto const& addr = SockAddrV4::Loopback().Addr();
  auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &sender);
  ASSERT_TRUE(rc.OK()) << rc.Report();
  rc = sender.NonBlocking(true);
  ASSERT_TRUE(rc.OK()) << rc.Report();

  // Server side of the connection, also non-blocking.
  auto receiver = server.Accept();
  rc = receiver.NonBlocking(true);
  ASSERT_TRUE(rc.OK()) << rc.Report();

  std::vector<std::int8_t> out_buf(1, 1);
  std::vector<std::int8_t> in_buf(1, 0);

  auto loop = std::make_shared<Loop>(std::chrono::seconds{1});

  // Queue the write first, then the matching read.
  loop->Submit(Loop::Op{Loop::Op::kWrite, 0, out_buf.data(), out_buf.size(), &sender, 0});
  loop->Submit(Loop::Op{Loop::Op::kRead, 0, in_buf.data(), in_buf.size(), &receiver, 0});

  rc = loop->Block();
  ASSERT_TRUE(rc.OK()) << rc.Report();

  ASSERT_EQ(in_buf[0], out_buf[0]);

  system::SocketFinalize();
}
} // namespace xgboost::collective |