Skip to content

Commit

Permalink
Build a simple event loop for collective.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Sep 18, 2023
1 parent 0df1da2 commit b5ffa83
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 7 deletions.
7 changes: 3 additions & 4 deletions rabit/include/rabit/internal/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
#include <string>
#include <vector>

#include "rabit/internal/utils.h"
#include "rabit/serializable.h"
#include "dmlc/io.h"
#include "xgboost/logging.h"

namespace rabit::utils {
/*! \brief re-use definition of dmlc::SeekStream */
Expand Down Expand Up @@ -84,8 +84,7 @@ struct MemoryBufferStream : public SeekStream {
}
~MemoryBufferStream() override = default;
size_t Read(void *ptr, size_t size) override {
utils::Assert(curr_ptr_ <= p_buffer_->length(),
"read can not have position excceed buffer length");
CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length";
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
curr_ptr_ += nread;
Expand Down
5 changes: 2 additions & 3 deletions rabit/include/rabit/internal/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@
#include <chrono>
#include <cstring>
#include <string>
#include <system_error> // make_error_code, errc
#include <unordered_map>
#include <vector>

#include "utils.h"

#if !defined(_WIN32)

#include <sys/poll.h>
Expand Down Expand Up @@ -168,7 +167,7 @@ struct PollHelper {
}
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == 0) {
return xgboost::collective::Fail("Poll timeout.");
return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out));
} else if (ret < 0) {
return xgboost::system::FailWithCode("Poll failed.");
} else {
Expand Down
166 changes: 166 additions & 0 deletions src/collective/loop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "loop.h"

#include <cinttypes> // for int32_t
#include <queue> // for queue

#include "rabit/internal/socket.h" // for PullHelper

namespace xgboost::collective {
/**
 * \brief Drain the operation queue, performing partial reads/writes until every
 *        queued operation has transferred all of its bytes.
 *
 * Each round:
 *   1. moves every pending op out of `queue_` and registers its socket with a poll
 *      helper (read or write interest according to `op.code`),
 *   2. polls once using `timeout_`,
 *   3. attempts one Recv/Send for each op whose socket is reported ready,
 *   4. pushes ops that are still incomplete back to `queue_` for the next round.
 *
 * The loop ends when `queue_` is empty or `stop_` is set.
 *
 * \return Success() when the loop ends normally. A failed Result when an op carries
 *         an unknown code, when the poll does not succeed, or when a socket call
 *         returns -1 with an error other than would-block. On every failure path
 *         `stop_` is set to true.
 */
Result Loop::EmptyQueue() {
  // `__func__` evaluated inside a lambda body names the closure's operator(), not
  // this function. Capture the label once so that Start/Stop always use the same key.
  char const* const fn_name = __func__;
  timer_.Start(fn_name);
  // Shared failure epilogue: request shutdown and close the timing section.
  auto error = [this, fn_name] {
    this->stop_ = true;
    timer_.Stop(fn_name);
  };

  while (!queue_.empty() && !stop_) {
    std::queue<Op> qcopy;
    rabit::utils::PollHelper poll;

    // watch all ops
    while (!queue_.empty()) {
      auto op = queue_.front();
      queue_.pop();

      switch (op.code) {
        case Op::kRead: {
          poll.WatchRead(*op.sock);
          break;
        }
        case Op::kWrite: {
          poll.WatchWrite(*op.sock);
          break;
        }
        default: {
          error();
          return Fail("Invalid socket operation.");
        }
      }
      qcopy.push(op);
    }

    // poll, work on fds that are ready.
    timer_.Start("poll");
    auto rc = poll.Poll(timeout_);
    timer_.Stop("poll");
    if (!rc.OK()) {
      error();
      return rc;
    }
    // we wouldn't be here if the queue is empty.
    CHECK(!qcopy.empty());

    while (!qcopy.empty() && !stop_) {
      auto op = qcopy.front();
      qcopy.pop();

      std::int32_t n_bytes_done{0};
      CHECK(op.sock->NonBlocking());

      switch (op.code) {
        case Op::kRead: {
          if (poll.CheckRead(*op.sock)) {
            n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off);
          }
          break;
        }
        case Op::kWrite: {
          if (poll.CheckWrite(*op.sock)) {
            n_bytes_done = op.sock->Send(op.ptr + op.off, op.n - op.off);
          }
          break;
        }
        default: {
          error();
          return Fail("Invalid socket operation.");
        }
      }

      if (n_bytes_done == -1) {
        if (!system::LastErrorWouldBlock()) {
          // Build the result before calling error(); presumably FailWithCode reads
          // the last system error, which later calls could overwrite.
          auto rc = system::FailWithCode("Invalid socket output.");
          error();
          return rc;
        }
        // Would-block means no bytes were transferred. The -1 sentinel must not be
        // added to the unsigned offset: that would rewind it, or wrap it around
        // when the offset is still 0.
        n_bytes_done = 0;
      }
      op.off += n_bytes_done;
      CHECK_LE(op.off, op.n);

      if (op.off != op.n) {
        // not yet finished, push back to queue for next round.
        queue_.push(op);
      }
    }
  }
  timer_.Stop(fn_name);
  return Success();
}

/**
 * \brief Body of the worker thread (the consumer side of the queue).
 *
 * Repeatedly waits on `cv_` until there is queued work or `stop_` is set. When work
 * is available it calls EmptyQueue() while holding `mu_` and stores the outcome in
 * `rc_`. The function returns when `stop_` is observed after a wait, or right after
 * EmptyQueue() reports a failure (in which case `stop_` is set here as well).
 *
 * A waiter on `cv_` is notified after every EmptyQueue() call, successful or not.
 */
void Loop::Process() {
  // consumer
  while (true) {
    std::unique_lock lock{mu_};
    cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; });
    if (stop_) {
      break;
    }
    // The mutex is held by this thread once wait() returns. Verify through the
    // unique_lock: calling try_lock() on a std::mutex that the calling thread
    // already owns is undefined behaviour.
    CHECK(lock.owns_lock());

    this->rc_ = this->EmptyQueue();
    if (!rc_.OK()) {
      stop_ = true;
      cv_.notify_one();
      break;
    }

    CHECK(queue_.empty());
    CHECK(lock.owns_lock());
    cv_.notify_one();
  }

  if (rc_.OK()) {
    CHECK(queue_.empty());
  }
}

/**
 * \brief Request the loop to stop and surface any exception captured by the worker.
 *
 * Sets `stop_` under `mu_`, then calls Block() and checks that its status agrees
 * with `rc_`. If an exception was stored in `curr_exce_` it is rethrown here.
 *
 * \return Success() when no stored exception is rethrown.
 */
Result Loop::Stop() {
  {
    // Raise the stop flag under the mutex; the guard is released before Block()
    // acquires the same mutex again.
    std::lock_guard<std::mutex> guard{mu_};
    stop_ = true;
  }

  CHECK_EQ(this->Block().OK(), this->rc_.OK());

  if (curr_exce_ != nullptr) {
    std::rethrow_exception(curr_exce_);
  }
  return Success();
}

/**
 * \brief Construct the loop and launch its worker thread.
 *
 * The worker runs Process(). Anything thrown out of Process() is captured rather
 * than allowed to escape the thread: the first exception is stored in `curr_exce_`,
 * `rc_` is set to a failure, `stop_` is raised and all waiters on `cv_` are woken.
 *
 * \param timeout Stored in `timeout_` and used for each poll in EmptyQueue().
 */
Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
  timer_.Init(__func__);
  worker_ = std::thread{[this] {
    try {
      this->Process();
    } catch (...) {
      // A single catch-all covers std::exception and everything else; the stored
      // exception_ptr keeps the original exception object either way.
      std::lock_guard<std::mutex> guard{mu_};
      if (!curr_exce_) {
        curr_exce_ = std::current_exception();
        rc_ = Fail("Exception was thrown");
      }
      stop_ = true;
      cv_.notify_all();
    }
  }};
}
} // namespace xgboost::collective
80 changes: 80 additions & 0 deletions src/collective/loop.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <chrono> // for seconds
#include <condition_variable> // for condition_variable
#include <cstdint> // for int8_t, int32_t
#include <mutex> // for mutex, lock_guard, unique_lock
#include <queue> // for queue
#include <thread> // for thread
#include <utility> // for move

#include "../common/timer.h" // for Monitor
#include "xgboost/collective/socket.h" // for TCPSocket

namespace xgboost::collective {
/**
 * \brief A small event loop that performs queued socket reads/writes on a worker
 *        thread.
 *
 * Callers Submit() operations and then Block() until the queue has been drained or
 * the loop has stopped. The worker thread is started by the constructor and joined
 * by the destructor.
 */
class Loop {
 public:
  /**
   * \brief A single pending socket transfer.
   */
  struct Op {
    /** \brief Direction of the transfer. */
    enum Code : std::int8_t { kRead = 0, kWrite = 1 } code;
    /** \brief Rank associated with this op; not read by the code in this header. */
    std::int32_t rank{-1};
    /** \brief Start of the buffer to read into or write from. */
    std::int8_t* ptr{nullptr};
    /** \brief Total number of bytes to transfer. */
    std::size_t n{0};
    /** \brief Socket used for the transfer; not owned by the op. */
    TCPSocket* sock{nullptr};
    /** \brief Number of bytes already transferred. */
    std::size_t off{0};

    Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
        : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
    Op(Op const&) = default;
    Op& operator=(Op const&) = default;
    Op(Op&&) = default;
    Op& operator=(Op&&) = default;
  };

 private:
  std::thread worker_;               // runs Process()
  std::condition_variable cv_;       // signals queue/stop changes in both directions
  std::mutex mu_;                    // locked around queue_ and stop_ accesses below
  std::queue<Op> queue_;             // pending operations
  std::chrono::seconds timeout_;     // set by the constructor
  Result rc_;                        // handed out (moved) by Block()
  bool stop_{false};                 // part of the wait conditions on cv_
  std::exception_ptr curr_exce_{nullptr};
  common::Monitor timer_;

  Result EmptyQueue();
  void Process();

 public:
  /**
   * \brief Request the loop to stop. Defined out of line.
   */
  Result Stop();

  /**
   * \brief Enqueue an operation and wake one waiter on the condition variable.
   *
   * \param op The operation to enqueue; it is copied into the queue.
   */
  void Submit(Op op) {
    // producer
    std::unique_lock lock{mu_};
    queue_.push(op);
    lock.unlock();
    cv_.notify_one();
  }

  /**
   * \brief Wait until the queue is empty or the loop has been stopped.
   *
   * \return The stored result, moved out of `rc_`.
   */
  [[nodiscard]] Result Block() {
    {
      std::unique_lock lock{mu_};
      cv_.notify_all();
    }
    std::unique_lock lock{mu_};
    cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
    return std::move(rc_);
  }

  explicit Loop(std::chrono::seconds timeout);

  /**
   * \brief Stop the loop and join the worker thread.
   *
   * May throw: anything thrown by Stop() is propagated after the worker has been
   * joined.
   */
  ~Loop() noexcept(false) {
    try {
      this->Stop();
    } catch (...) {
      // Join before propagating. If the exception left this destructor while
      // `worker_` was still joinable, destroying the std::thread member would call
      // std::terminate.
      if (worker_.joinable()) {
        worker_.join();
      }
      throw;
    }

    if (worker_.joinable()) {
      worker_.join();
    }
  }
};
} // namespace xgboost::collective
67 changes: 67 additions & 0 deletions tests/cpp/collective/test_loop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <gtest/gtest.h>

#include <chrono> // for seconds
#include <cinttypes> // for int8_t
#include <memory> // for make_shared
#include <system_error> // make_error_code, errc

#include "../../../src/collective/loop.h"

namespace xgboost::collective {
// Submitting a read on a default-constructed socket is expected to fail, and the
// result returned by Block() must carry the timed_out error code.
TEST(Loop, Timeout) {
  system::SocketStartup();
  auto loop = std::make_shared<Loop>(std::chrono::seconds{1});

  // Declared after the loop so that they are destroyed before it.
  TCPSocket sock;
  std::vector<std::int8_t> buffer(1);
  loop->Submit(Loop::Op{Loop::Op::kRead, 0, buffer.data(), buffer.size(), &sock, 0});

  auto status = loop->Block();
  ASSERT_FALSE(status.OK());
  ASSERT_EQ(status.Code(), std::make_error_code(std::errc::timed_out));
  system::SocketFinalize();
}

// Round trip over a loopback connection: one byte is written through the loop on
// the connecting socket and read through the loop on the accepted socket.
TEST(Loop, Op) {
  system::SocketStartup();

  // Listening side.
  auto domain = SockDomain::kV4;
  auto server = TCPSocket::Create(domain);
  auto port = server.BindHost();
  server.Listen();

  // Connecting side, switched to non-blocking mode once connected.
  TCPSocket sender;
  auto const& addr = SockAddrV4::Loopback().Addr();
  auto status = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &sender);
  ASSERT_TRUE(status.OK()) << status.Report();
  status = sender.NonBlocking(true);
  ASSERT_TRUE(status.OK()) << status.Report();

  // Accepted side, also non-blocking.
  auto receiver = server.Accept();
  status = receiver.NonBlocking(true);
  ASSERT_TRUE(status.OK()) << status.Report();

  std::vector<std::int8_t> outgoing(1, 1);
  std::vector<std::int8_t> incoming(1, 0);

  // The loop is declared last, so it is destroyed before the sockets and buffers
  // that its operations point to.
  auto loop = std::make_shared<Loop>(std::chrono::seconds{1});

  loop->Submit(Loop::Op{Loop::Op::kWrite, 0, outgoing.data(), outgoing.size(), &sender, 0});
  loop->Submit(Loop::Op{Loop::Op::kRead, 0, incoming.data(), incoming.size(), &receiver, 0});

  status = loop->Block();
  ASSERT_TRUE(status.OK()) << status.Report();

  ASSERT_EQ(incoming[0], outgoing[0]);

  system::SocketFinalize();
}
} // namespace xgboost::collective

0 comments on commit b5ffa83

Please sign in to comment.