From b5ffa837b597e33c882e372555e8378a9c8462fb Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 19 Sep 2023 05:17:16 +0800 Subject: [PATCH] Build a simple event loop for collective. --- rabit/include/rabit/internal/io.h | 7 +- rabit/include/rabit/internal/socket.h | 5 +- src/collective/loop.cc | 166 ++++++++++++++++++++++++++ src/collective/loop.h | 80 +++++++++++++ tests/cpp/collective/test_loop.cc | 67 +++++++++++ 5 files changed, 318 insertions(+), 7 deletions(-) create mode 100644 src/collective/loop.cc create mode 100644 src/collective/loop.h create mode 100644 tests/cpp/collective/test_loop.cc diff --git a/rabit/include/rabit/internal/io.h b/rabit/include/rabit/internal/io.h index d93f32ff9c07..d5d0fee4d79c 100644 --- a/rabit/include/rabit/internal/io.h +++ b/rabit/include/rabit/internal/io.h @@ -16,8 +16,8 @@ #include #include -#include "rabit/internal/utils.h" -#include "rabit/serializable.h" +#include "dmlc/io.h" +#include "xgboost/logging.h" namespace rabit::utils { /*! \brief re-use definition of dmlc::SeekStream */ @@ -84,8 +84,7 @@ struct MemoryBufferStream : public SeekStream { } ~MemoryBufferStream() override = default; size_t Read(void *ptr, size_t size) override { - utils::Assert(curr_ptr_ <= p_buffer_->length(), - "read can not have position excceed buffer length"); + CHECK_LE(curr_ptr_, p_buffer_->length()) << "read can not have position excceed buffer length"; size_t nread = std::min(p_buffer_->length() - curr_ptr_, size); if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread); curr_ptr_ += nread; diff --git a/rabit/include/rabit/internal/socket.h b/rabit/include/rabit/internal/socket.h index 6fb7fe725722..535b5708b7f6 100644 --- a/rabit/include/rabit/internal/socket.h +++ b/rabit/include/rabit/internal/socket.h @@ -29,11 +29,10 @@ #include #include #include +#include // make_error_code, errc #include #include -#include "utils.h" - #if !defined(_WIN32) #include @@ -168,7 +167,7 @@ struct PollHelper { } int ret = PollImpl(fdset.data(), fdset.size(), timeout); if (ret == 0) { - return xgboost::collective::Fail("Poll timeout."); + return xgboost::collective::Fail("Poll timeout.", std::make_error_code(std::errc::timed_out)); } else if (ret < 0) { return xgboost::system::FailWithCode("Poll failed."); } else { diff --git a/src/collective/loop.cc b/src/collective/loop.cc new file mode 100644 index 000000000000..a5d1f418d02a --- /dev/null +++ b/src/collective/loop.cc @@ -0,0 +1,166 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include "loop.h" + +#include // for int32_t +#include // for queue + +#include "rabit/internal/socket.h" // for PullHelper + +namespace xgboost::collective { +Result Loop::EmptyQueue() { + timer_.Start(__func__); + auto error = [this] { + this->stop_ = true; + timer_.Stop(__func__); + }; + + while (!queue_.empty() && !stop_) { + std::queue qcopy; + rabit::utils::PollHelper poll; + + // watch all ops + while (!queue_.empty()) { + auto op = queue_.front(); + queue_.pop(); + + switch (op.code) { + case Op::kRead: { + poll.WatchRead(*op.sock); + break; + } + case Op::kWrite: { + poll.WatchWrite(*op.sock); + break; + } + default: { + error(); + return Fail("Invalid socket operation."); + } + } + qcopy.push(op); + } + + // poll, work on fds that are ready. + timer_.Start("poll"); + auto rc = poll.Poll(timeout_); + timer_.Stop("poll"); + if (!rc.OK()) { + error(); + return rc; + } + // we wonldn't be here if the queue is empty. + CHECK(!qcopy.empty()); + + while (!qcopy.empty() && !stop_) { + auto op = qcopy.front(); + qcopy.pop(); + + std::int32_t n_bytes_done{0}; + CHECK(op.sock->NonBlocking()); + + switch (op.code) { + case Op::kRead: { + if (poll.CheckRead(*op.sock)) { + n_bytes_done = op.sock->Recv(op.ptr + op.off, op.n - op.off); + } + break; + } + case Op::kWrite: { + if (poll.CheckWrite(*op.sock)) { + n_bytes_done = op.sock->Send(op.ptr + op.off, op.n - op.off); + } + break; + } + default: { + error(); + return Fail("Invalid socket operation."); + } + } + + if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) { + stop_ = true; + auto rc = system::FailWithCode("Invalid socket output."); + error(); + return rc; + } + op.off += n_bytes_done; + CHECK_LE(op.off, op.n); + + if (op.off != op.n) { + // not yet finished, push back to queue for next round. + queue_.push(op); + } + } + } + timer_.Stop(__func__); + return Success(); +} + +void Loop::Process() { + // consumer + while (true) { + std::unique_lock lock{mu_}; + cv_.wait(lock, [this] { return !this->queue_.empty() || stop_; }); + if (stop_) { + break; + } + CHECK(!mu_.try_lock()); + + this->rc_ = this->EmptyQueue(); + if (!rc_.OK()) { + stop_ = true; + cv_.notify_one(); + break; + } + + CHECK(queue_.empty()); + CHECK(!mu_.try_lock()); + cv_.notify_one(); + } + + if (rc_.OK()) { + CHECK(queue_.empty()); + } +} + +Result Loop::Stop() { + std::unique_lock lock{mu_}; + stop_ = true; + lock.unlock(); + + CHECK_EQ(this->Block().OK(), this->rc_.OK()); + + if (curr_exce_) { + std::rethrow_exception(curr_exce_); + } + + return Success(); +} + +Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} { + timer_.Init(__func__); + worker_ = std::thread{[this] { + try { + this->Process(); + } catch (std::exception const& e) { + std::lock_guard guard{mu_}; + if (!curr_exce_) { + curr_exce_ = std::current_exception(); + rc_ = Fail("Exception was thrown"); + } + stop_ = true; + cv_.notify_all(); + } catch (...) { + std::lock_guard guard{mu_}; + if (!curr_exce_) { + curr_exce_ = std::current_exception(); + rc_ = Fail("Exception was thrown"); + } + stop_ = true; + cv_.notify_all(); + } + }}; +} +} // namespace xgboost::collective diff --git a/src/collective/loop.h b/src/collective/loop.h new file mode 100644 index 000000000000..6768cad9d009 --- /dev/null +++ b/src/collective/loop.h @@ -0,0 +1,80 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once +#include // for seconds +#include // for condition_variable +#include // for int8_t, int32_t +#include // for mutex, lock_guard, unique_lock +#include // for queue +#include // for thread +#include // for move + +#include "../common/timer.h" // for Monitor +#include "xgboost/collective/socket.h" // for TCPSocket + +namespace xgboost::collective { +class Loop { + public: + struct Op { + enum Code : std::int8_t { kRead = 0, kWrite = 1 } code; + std::int32_t rank{-1}; + std::int8_t* ptr{nullptr}; + std::size_t n{0}; + TCPSocket* sock{nullptr}; + std::size_t off{0}; + + Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off) + : code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {} + Op(Op const&) = default; + Op& operator=(Op const&) = default; + Op(Op&&) = default; + Op& operator=(Op&&) = default; + }; + + private: + std::thread worker_; + std::condition_variable cv_; + std::mutex mu_; + std::queue queue_; + std::chrono::seconds timeout_; + Result rc_; + bool stop_{false}; + std::exception_ptr curr_exce_{nullptr}; + common::Monitor timer_; + + Result EmptyQueue(); + void Process(); + + public: + Result Stop(); + + void Submit(Op op) { + // producer + std::unique_lock lock{mu_}; + queue_.push(op); + lock.unlock(); + cv_.notify_one(); + } + + [[nodiscard]] Result Block() { + { + std::unique_lock lock{mu_}; + cv_.notify_all(); + } + std::unique_lock lock{mu_}; + cv_.wait(lock, [this] { return this->queue_.empty() || stop_; }); + return std::move(rc_); + } + + explicit Loop(std::chrono::seconds timeout); + + ~Loop() noexcept(false) { + this->Stop(); + + if (worker_.joinable()) { + worker_.join(); + } + } +}; +} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_loop.cc b/tests/cpp/collective/test_loop.cc new file mode 100644 index 000000000000..1ed6900be3bc --- /dev/null +++ b/tests/cpp/collective/test_loop.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include + +#include // for seconds +#include // for int8_t +#include // for make_shared +#include // make_error_code, errc + +#include "../../../src/collective/loop.h" + +namespace xgboost::collective { +TEST(Loop, Timeout) { + system::SocketStartup(); + std::chrono::seconds timeout{1}; + auto loop = std::make_shared(timeout); + + TCPSocket sock; + std::vector data(1); + Loop::Op op{Loop::Op::kRead, 0, data.data(), data.size(), &sock, 0}; + loop->Submit(op); + auto rc = loop->Block(); + ASSERT_FALSE(rc.OK()); + ASSERT_EQ(rc.Code(), std::make_error_code(std::errc::timed_out)); + system::SocketFinalize(); +} + +TEST(Loop, Op) { + system::SocketStartup(); + + auto domain = SockDomain::kV4; + auto server = TCPSocket::Create(domain); + auto port = server.BindHost(); + server.Listen(); + + TCPSocket send; + auto const& addr = SockAddrV4::Loopback().Addr(); + auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &send); + ASSERT_TRUE(rc.OK()) << rc.Report(); + rc = send.NonBlocking(true); + ASSERT_TRUE(rc.OK()) << rc.Report(); + + auto recv = server.Accept(); + rc = recv.NonBlocking(true); + ASSERT_TRUE(rc.OK()) << rc.Report(); + + std::vector wbuf(1, 1); + std::vector rbuf(1, 0); + + std::chrono::seconds timeout{1}; + auto loop = std::make_shared(timeout); + + Loop::Op wop{Loop::Op::kWrite, 0, wbuf.data(), wbuf.size(), &send, 0}; + Loop::Op rop{Loop::Op::kRead, 0, rbuf.data(), rbuf.size(), &recv, 0}; + + loop->Submit(wop); + loop->Submit(rop); + + rc = loop->Block(); + ASSERT_TRUE(rc.OK()) << rc.Report(); + + ASSERT_EQ(rbuf[0], wbuf[0]); + + system::SocketFinalize(); +} +} // namespace xgboost::collective