diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 80c2b8404382..fb7c8dbe69e7 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -32,4 +32,3 @@ formats: python: install: - requirements: doc/requirements.txt - system_packages: true diff --git a/include/xgboost/collective/result.h b/include/xgboost/collective/result.h new file mode 100644 index 000000000000..209362505fc5 --- /dev/null +++ b/include/xgboost/collective/result.h @@ -0,0 +1,160 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once + +#include // for unique_ptr +#include // for stringstream +#include // for stack +#include // for string +#include // for move + +namespace xgboost::collective { +namespace detail { +struct ResultImpl { + std::string message; + std::error_code errc{}; // optional for system error. + + std::unique_ptr prev{nullptr}; + + ResultImpl() = delete; // must initialize. + ResultImpl(ResultImpl const& that) = delete; + ResultImpl(ResultImpl&& that) = default; + ResultImpl& operator=(ResultImpl const& that) = delete; + ResultImpl& operator=(ResultImpl&& that) = default; + + explicit ResultImpl(std::string msg) : message{std::move(msg)} {} + explicit ResultImpl(std::string msg, std::error_code errc) + : message{std::move(msg)}, errc{std::move(errc)} {} + explicit ResultImpl(std::string msg, std::unique_ptr prev) + : message{std::move(msg)}, prev{std::move(prev)} {} + explicit ResultImpl(std::string msg, std::error_code errc, std::unique_ptr prev) + : message{std::move(msg)}, errc{std::move(errc)}, prev{std::move(prev)} {} + + [[nodiscard]] bool operator==(ResultImpl const& that) const noexcept(true) { + if ((prev && !that.prev) || (!prev && that.prev)) { + // one of them doesn't have prev + return false; + } + + auto cur_eq = message == that.message && errc == that.errc; + if (prev && that.prev) { + // recursive comparison + auto prev_eq = *prev == *that.prev; + return cur_eq && prev_eq; + } + return cur_eq; + } + + [[nodiscard]] std::string Report() { + std::stringstream ss; + ss << "\n- " << this->message; + if (this->errc != std::error_code{}) { + ss << " system error:" << this->errc.message(); + } + + auto ptr = prev.get(); + while (ptr) { + ss << "\n- "; + ss << ptr->message; + + if (ptr->errc != std::error_code{}) { + ss << " " << ptr->errc.message(); + } + ptr = ptr->prev.get(); + } + + return ss.str(); + } + [[nodiscard]] auto Code() const { + // Find the root error. + std::stack stack; + auto ptr = this; + while (ptr) { + stack.push(ptr); + if (ptr->prev) { + ptr = ptr->prev.get(); + } else { + break; + } + } + while (!stack.empty()) { + auto frame = stack.top(); + stack.pop(); + if (frame->errc != std::error_code{}) { + return frame->errc; + } + } + return std::error_code{}; + } +}; +} // namespace detail + +/** + * @brief An error type that's easier to handle than throwing dmlc exception. We can + * record and propagate the system error code. + */ +struct Result { + private: + std::unique_ptr impl_{nullptr}; + + public: + Result() noexcept(true) = default; + explicit Result(std::string msg) : impl_{std::make_unique(std::move(msg))} {} + explicit Result(std::string msg, std::error_code errc) + : impl_{std::make_unique(std::move(msg), std::move(errc))} {} + Result(std::string msg, Result&& prev) + : impl_{std::make_unique(std::move(msg), std::move(prev.impl_))} {} + Result(std::string msg, std::error_code errc, Result&& prev) + : impl_{std::make_unique(std::move(msg), std::move(errc), + std::move(prev.impl_))} {} + + Result(Result const& that) = delete; + Result& operator=(Result const& that) = delete; + Result(Result&& that) = default; + Result& operator=(Result&& that) = default; + + [[nodiscard]] bool OK() const noexcept(true) { return !impl_; } + [[nodiscard]] std::string Report() const { return OK() ? "" : impl_->Report(); } + /** + * @brief Return the root system error. This might return success if there's no system error. + */ + [[nodiscard]] auto Code() const { return OK() ? std::error_code{} : impl_->Code(); } + [[nodiscard]] bool operator==(Result const& that) const noexcept(true) { + if (OK() && that.OK()) { + return true; + } + if ((OK() && !that.OK()) || (!OK() && that.OK())) { + return false; + } + return *impl_ == *that.impl_; + } +}; + +/** + * @brief Return success. + */ +[[nodiscard]] inline auto Success() noexcept(true) { return Result{}; } +/** + * @brief Return failure. + */ +[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; } +/** + * @brief Return failure with `errno`. + */ +[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) { + return Result{std::move(msg), std::move(errc)}; +} +/** + * @brief Return failure with a previous error. + */ +[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) { + return Result{std::move(msg), std::forward(prev)}; +} +/** + * @brief Return failure with a previous error and a new `errno`. + */ +[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) { + return Result{std::move(msg), std::move(errc), std::forward(prev)}; +} +} // namespace xgboost::collective diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h index b5fa7cd70e30..5bff2204eb9a 100644 --- a/include/xgboost/collective/socket.h +++ b/include/xgboost/collective/socket.h @@ -56,9 +56,10 @@ using ssize_t = int; #endif // defined(_WIN32) -#include "xgboost/base.h" // XGBOOST_EXPECT -#include "xgboost/logging.h" // LOG -#include "xgboost/string_view.h" // StringView +#include "xgboost/base.h" // XGBOOST_EXPECT +#include "xgboost/collective/result.h" // for Result +#include "xgboost/logging.h" // LOG +#include "xgboost/string_view.h" // StringView #if !defined(HOST_NAME_MAX) #define HOST_NAME_MAX 256 // macos @@ -81,6 +82,10 @@ inline std::int32_t LastError() { #endif } +[[nodiscard]] inline collective::Result FailWithCode(std::string msg) { + return collective::Fail(std::move(msg), std::error_code{LastError(), std::system_category()}); +} + #if defined(__GLIBC__) inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(), std::int32_t line = __builtin_LINE(), @@ -120,15 +125,19 @@ inline std::int32_t CloseSocket(SocketT fd) { #endif } -inline bool LastErrorWouldBlock() { - int errsv = LastError(); +inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) { #ifdef _WIN32 return errsv == WSAEWOULDBLOCK; #else - return errsv == EAGAIN || errsv == EWOULDBLOCK; + return errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == EINPROGRESS; #endif // _WIN32 } +inline bool LastErrorWouldBlock() { + int errsv = LastError(); + return ErrorWouldBlock(errsv); +} + inline void SocketStartup() { #if defined(_WIN32) WSADATA wsa_data; @@ -315,23 +324,35 @@ class TCPSocket { bool IsClosed() const { return handle_ == InvalidSocket(); } /** \brief get last error code if any */ - std::int32_t GetSockError() const { - std::int32_t error = 0; - socklen_t len = sizeof(error); - xgboost_CHECK_SYS_CALL( - getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len), 0); - return error; + Result GetSockError() const { + std::int32_t optval = 0; + socklen_t len = sizeof(optval); + auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast(&optval), &len); + if (ret != 0) { + auto errc = std::error_code{system::LastError(), std::system_category()}; + return Fail("Failed to retrieve socket error.", std::move(errc)); + } + if (optval != 0) { + auto errc = std::error_code{optval, std::system_category()}; + return Fail("Socket error.", std::move(errc)); + } + return Success(); } + /** \brief check if anything bad happens */ bool BadSocket() const { - if (IsClosed()) return true; - std::int32_t err = GetSockError(); - if (err == EBADF || err == EINTR) return true; + if (IsClosed()) { + return true; + } + auto err = GetSockError(); + if (err.Code() == std::error_code{EBADF, std::system_category()} || // NOLINT + err.Code() == std::error_code{EINTR, std::system_category()}) { // NOLINT + return true; + } return false; } - void SetNonBlock() { - bool non_block{true}; + void SetNonBlock(bool non_block) { #if defined(_WIN32) u_long mode = non_block ? 1 : 0; xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR); @@ -530,10 +551,20 @@ class TCPSocket { }; /** - * \brief Connect to remote address, returns the error code if failed (no exception is - * raised so that we can retry). + * @brief Connect to remote address, returns the error code if failed. + * + * @param host Host IP address. + * @param port Connection port. + * @param retry Number of retries to attempt. + * @param timeout Timeout of each connection attempt. + * @param out_conn Output socket if the connection is successful. Value is invalid and undefined if + * the connection failed. + * + * @return Connection status. */ -std::error_code Connect(SockAddress const &addr, TCPSocket *out); +[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, + std::chrono::seconds timeout, + xgboost::collective::TCPSocket *out_conn); /** * \brief Get the local host name. diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 41fd6405ab24..2e0933a43697 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -94,6 +94,10 @@ def no_ipv6() -> PytestSkip: return {"condition": not has_ipv6(), "reason": "IPv6 is required to be enabled."} +def not_linux() -> PytestSkip: + return {"condition": system() != "Linux", "reason": "Linux is required."} + + def no_ubjson() -> PytestSkip: return no_mod("ubjson") diff --git a/rabit/include/rabit/internal/socket.h b/rabit/include/rabit/internal/socket.h index cb7d4a0784c1..6fb7fe725722 100644 --- a/rabit/include/rabit/internal/socket.h +++ b/rabit/include/rabit/internal/socket.h @@ -1,10 +1,11 @@ -/*! - * Copyright (c) 2014-2022 by XGBoost Contributors +/** + * Copyright 2014-2023, XGBoost Contributors * \file socket.h * \author Tianqi Chen */ #ifndef RABIT_INTERNAL_SOCKET_H_ #define RABIT_INTERNAL_SOCKET_H_ +#include "xgboost/collective/result.h" #include "xgboost/collective/socket.h" #if defined(_WIN32) @@ -77,7 +78,7 @@ namespace rabit { namespace utils { template -int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) { +int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true) { #if defined(_WIN32) #if IS_MINGW() @@ -135,11 +136,11 @@ struct PollHelper { * \brief Check if the descriptor is ready for read * \param fd file descriptor to check status */ - inline bool CheckRead(SOCKET fd) const { + [[nodiscard]] bool CheckRead(SOCKET fd) const { const auto& pfd = fds.find(fd); return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0); } - bool CheckRead(xgboost::collective::TCPSocket const &socket) const { + [[nodiscard]] bool CheckRead(xgboost::collective::TCPSocket const& socket) const { return this->CheckRead(socket.Handle()); } @@ -147,19 +148,19 @@ struct PollHelper { * \brief Check if the descriptor is ready for write * \param fd file descriptor to check status */ - inline bool CheckWrite(SOCKET fd) const { + [[nodiscard]] bool CheckWrite(SOCKET fd) const { const auto& pfd = fds.find(fd); return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0); } - bool CheckWrite(xgboost::collective::TCPSocket const &socket) const { + [[nodiscard]] bool CheckWrite(xgboost::collective::TCPSocket const& socket) const { return this->CheckWrite(socket.Handle()); } - /*! - * \brief perform poll on the set defined, read, write, exception - * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block - * \return + /** + * @brief perform poll on the set defined, read, write, exception + * + * @param timeout specify timeout in seconds. Block if negative. */ - inline void Poll(std::chrono::seconds timeout) { // NOLINT(*) + [[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout) { std::vector fdset; fdset.reserve(fds.size()); for (auto kv : fds) { @@ -167,9 +168,9 @@ struct PollHelper { } int ret = PollImpl(fdset.data(), fdset.size(), timeout); if (ret == 0) { - LOG(FATAL) << "Poll timeout"; + return xgboost::collective::Fail("Poll timeout."); } else if (ret < 0) { - LOG(FATAL) << "Failed to poll."; + return xgboost::system::FailWithCode("Poll failed."); } else { for (auto& pfd : fdset) { auto revents = pfd.revents & pfd.events; @@ -180,6 +181,7 @@ struct PollHelper { } } } + return xgboost::collective::Success(); } std::unordered_map fds; diff --git a/rabit/src/allreduce_base.cc b/rabit/src/allreduce_base.cc index ac08ac12a2cf..bd48d3599109 100644 --- a/rabit/src/allreduce_base.cc +++ b/rabit/src/allreduce_base.cc @@ -1,5 +1,5 @@ -/*! - * Copyright (c) 2014 by Contributors +/** + * Copyright 2014-2023, XGBoost Contributors * \file allreduce_base.cc * \brief Basic implementation of AllReduce * @@ -9,9 +9,11 @@ #define NOMINMAX #endif // !defined(NOMINMAX) +#include "allreduce_base.h" + #include "rabit/base.h" #include "rabit/internal/rabit-inl.h" -#include "allreduce_base.h" +#include "xgboost/collective/result.h" #ifndef _WIN32 #include @@ -20,8 +22,7 @@ #include #include -namespace rabit { -namespace engine { +namespace rabit::engine { // constructor AllreduceBase::AllreduceBase() { tracker_uri = "NULL"; @@ -116,7 +117,12 @@ bool AllreduceBase::Init(int argc, char* argv[]) { utils::Assert(all_links.size() == 0, "can only call Init once"); this->host_uri = xgboost::collective::GetHostName(); // get information from tracker - return this->ReConnectLinks(); + auto rc = this->ReConnectLinks(); + if (rc.OK()) { + return true; + } + LOG(FATAL) << rc.Report(); + return false; } bool AllreduceBase::Shutdown() { @@ -131,7 +137,11 @@ bool AllreduceBase::Shutdown() { if (tracker_uri == "NULL") return true; // notify tracker rank i have shutdown - xgboost::collective::TCPSocket tracker = this->ConnectTracker(); + xgboost::collective::TCPSocket tracker; + auto rc = this->ConnectTracker(&tracker); + if (!rc.OK()) { + LOG(FATAL) << rc.Report(); + } tracker.Send(xgboost::StringView{"shutdown"}); tracker.Close(); xgboost::system::SocketFinalize(); @@ -146,7 +156,12 @@ void AllreduceBase::TrackerPrint(const std::string &msg) { if (tracker_uri == "NULL") { utils::Printf("%s", msg.c_str()); return; } - xgboost::collective::TCPSocket tracker = this->ConnectTracker(); + xgboost::collective::TCPSocket tracker; + auto rc = this->ConnectTracker(&tracker); + if (!rc.OK()) { + LOG(FATAL) << rc.Report(); + } + tracker.Send(xgboost::StringView{"print"}); tracker.Send(xgboost::StringView{msg}); tracker.Close(); @@ -215,64 +230,67 @@ void AllreduceBase::SetParam(const char *name, const char *val) { } } } + /*! * \brief initialize connection to the tracker * \return a socket that initializes the connection */ -xgboost::collective::TCPSocket AllreduceBase::ConnectTracker() const { +[[nodiscard]] xgboost::collective::Result AllreduceBase::ConnectTracker( + xgboost::collective::TCPSocket *out) const { int magic = kMagic; // get information from tracker - xgboost::collective::TCPSocket tracker; + xgboost::collective::TCPSocket &tracker = *out; - int retry = 0; - do { - auto rc = xgboost::collective::Connect( - xgboost::collective::MakeSockAddress(xgboost::StringView{tracker_uri}, tracker_port), - &tracker); - if (rc != std::errc()) { - if (++retry >= connect_retry) { - LOG(FATAL) << "Connecting to (failed): [" << tracker_uri << "]\n" << rc.message(); - } else { - LOG(WARNING) << rc.message() << "\nRetry connecting to IP(retry time: " << retry << "): [" - << tracker_uri << "]"; -#if defined(_MSC_VER) || defined(__MINGW32__) - Sleep(retry << 1); -#else - sleep(retry << 1); -#endif - continue; - } - } - break; - } while (true); + auto rc = + Connect(xgboost::StringView{tracker_uri}, tracker_port, connect_retry, timeout_sec, &tracker); + if (!rc.OK()) { + return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc)); + } using utils::Assert; - CHECK_EQ(tracker.SendAll(&magic, sizeof(magic)), sizeof(magic)); - CHECK_EQ(tracker.RecvAll(&magic, sizeof(magic)), sizeof(magic)); - utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure"); - Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3"); - Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), - "ReConnectLink failure 3"); - CHECK_EQ(tracker.Send(xgboost::StringView{task_id}), task_id.size()); - return tracker; + if (tracker.SendAll(&magic, sizeof(magic)) != sizeof(magic)) { + return xgboost::collective::Fail("Failed to send the verification number."); + } + if (tracker.RecvAll(&magic, sizeof(magic)) != sizeof(magic)) { + return xgboost::collective::Fail("Failed to recieve the verification number."); + } + if (magic != kMagic) { + return xgboost::collective::Fail("Invalid verification number."); + } + if (tracker.SendAll(&rank, sizeof(rank)) != sizeof(rank)) { + return xgboost::collective::Fail("Failed to send the local rank back to the tracker."); + } + if (tracker.SendAll(&world_size, sizeof(world_size)) != sizeof(world_size)) { + return xgboost::collective::Fail("Failed to send the world size back to the tracker."); + } + if (tracker.Send(xgboost::StringView{task_id}) != task_id.size()) { + return xgboost::collective::Fail("Failed to send the task ID back to the tracker."); + } + + return xgboost::collective::Success(); } /*! * \brief connect to the tracker to fix the the missing links * this function is also used when the engine start up */ -bool AllreduceBase::ReConnectLinks(const char *cmd) { +[[nodiscard]] xgboost::collective::Result AllreduceBase::ReConnectLinks(const char *cmd) { // single node mode if (tracker_uri == "NULL") { rank = 0; world_size = 1; - return true; + return xgboost::collective::Success(); } - try { - xgboost::collective::TCPSocket tracker = this->ConnectTracker(); - LOG(INFO) << "task " << task_id << " connected to the tracker"; - tracker.Send(xgboost::StringView{cmd}); + xgboost::collective::TCPSocket tracker; + auto rc = this->ConnectTracker(&tracker); + if (!rc.OK()) { + return xgboost::collective::Fail("Failed to connect to the tracker.", std::move(rc)); + } + + LOG(INFO) << "task " << task_id << " connected to the tracker"; + tracker.Send(xgboost::StringView{cmd}); + try { // the rank of previous link, next link in ring int prev_rank, next_rank; // the rank of neighbors @@ -334,10 +352,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { tracker.Recv(&hname); Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9"); Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10"); - - if (xgboost::collective::Connect( - xgboost::collective::MakeSockAddress(xgboost::StringView{hname}, hport), &r.sock) != - std::errc{}) { + // connect to peer + if (!xgboost::collective::Connect(xgboost::StringView{hname}, hport, connect_retry, + timeout_sec, &r.sock) + .OK()) { num_error += 1; r.sock.Close(); continue; @@ -351,8 +369,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { bool match = false; for (auto & all_link : all_links) { if (all_link.rank == hrank) { - Assert(all_link.sock.IsClosed(), - "Override a link that is active"); + Assert(all_link.sock.IsClosed(), "Override a link that is active"); all_link.sock = std::move(r.sock); match = true; break; @@ -364,10 +381,10 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { "ReConnectLink failure 14"); } while (num_error != 0); // send back socket listening port to tracker - Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), - "ReConnectLink failure 14"); + Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), "ReConnectLink failure 14"); // close connection to tracker tracker.Close(); + // listen to incoming links for (int i = 0; i < num_accept; ++i) { LinkRecord r; @@ -395,7 +412,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { for (auto &all_link : all_links) { utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket"); // set the socket to non-blocking mode, enable TCP keepalive - all_link.sock.SetNonBlock(); + all_link.sock.SetNonBlock(true); all_link.sock.SetKeepAlive(); if (rabit_enable_tcp_no_delay) { all_link.sock.SetNoDelay(); @@ -415,10 +432,11 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { "cannot find prev ring in the link"); Assert(next_rank == -1 || ring_next != nullptr, "cannot find next ring in the link"); - return true; + return xgboost::collective::Success(); } catch (const std::exception& e) { - LOG(WARNING) << "failed in ReconnectLink " << e.what(); - return false; + std::stringstream ss; + ss << "Failed in ReconnectLink " << e.what(); + return xgboost::collective::Fail(ss.str()); } } /*! @@ -523,9 +541,15 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, } } // finish running allreduce - if (finished) break; + if (finished) { + break; + } // select must return - watcher.Poll(timeout_sec); + auto poll_res = watcher.Poll(timeout_sec); + if (!poll_res.OK()) { + LOG(FATAL) << poll_res.Report(); + } + // read data from childs for (int i = 0; i < nlink; ++i) { if (i != parent_index && watcher.CheckRead(links[i].sock)) { @@ -698,7 +722,10 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { // finish running if (finished) break; // select - watcher.Poll(timeout_sec); + auto poll_res = watcher.Poll(timeout_sec); + if (!poll_res.OK()) { + LOG(FATAL) << poll_res.Report(); + } if (in_link == -2) { // probe in-link for (int i = 0; i < nlink; ++i) { @@ -780,8 +807,14 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size, } finished = false; } - if (finished) break; - watcher.Poll(timeout_sec); + if (finished) { + break; + } + + auto poll_res = watcher.Poll(timeout_sec); + if (!poll_res.OK()) { + LOG(FATAL) << poll_res.Report(); + } if (read_ptr != stop_read && watcher.CheckRead(next.sock)) { size_t size = stop_read - read_ptr; size_t start = read_ptr % total_size; @@ -880,8 +913,13 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_, } finished = false; } - if (finished) break; - watcher.Poll(timeout_sec); + if (finished) { + break; + } + auto poll_res = watcher.Poll(timeout_sec); + if (!poll_res.OK()) { + LOG(FATAL) << poll_res.Report(); + } if (read_ptr != stop_read && watcher.CheckRead(next.sock)) { ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read); if (ret != kSuccess) { @@ -953,5 +991,4 @@ AllreduceBase::TryAllreduceRing(void *sendrecvbuf_, (std::min((prank + 1) * step, count) - std::min(prank * step, count)) * type_nbytes); } -} // namespace engine -} // namespace rabit +} // namespace rabit::engine diff --git a/rabit/src/allreduce_base.h b/rabit/src/allreduce_base.h index 67fef0ba6347..f40754273634 100644 --- a/rabit/src/allreduce_base.h +++ b/rabit/src/allreduce_base.h @@ -12,14 +12,16 @@ #ifndef RABIT_ALLREDUCE_BASE_H_ #define RABIT_ALLREDUCE_BASE_H_ +#include #include #include -#include #include -#include -#include "rabit/internal/utils.h" +#include + #include "rabit/internal/engine.h" #include "rabit/internal/socket.h" +#include "rabit/internal/utils.h" +#include "xgboost/collective/result.h" #ifdef RABIT_CXXTESTDEFS_H #define private public @@ -329,13 +331,13 @@ class AllreduceBase : public IEngine { * \brief initialize connection to the tracker * \return a socket that initializes the connection */ - xgboost::collective::TCPSocket ConnectTracker() const; + [[nodiscard]] xgboost::collective::Result ConnectTracker(xgboost::collective::TCPSocket *out) const; /*! * \brief connect to the tracker to fix the the missing links * this function is also used when the engine start up * \param cmd possible command to sent to tracker */ - bool ReConnectLinks(const char *cmd = "start"); + [[nodiscard]] xgboost::collective::Result ReConnectLinks(const char *cmd = "start"); /*! * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure * diff --git a/src/collective/socket.cc b/src/collective/socket.cc index 1ab84cef35d8..78dc3d79be71 100644 --- a/src/collective/socket.cc +++ b/src/collective/socket.cc @@ -1,19 +1,22 @@ -/*! - * Copyright (c) 2022 by XGBoost Contributors +/** + * Copyright 2022-2023 by XGBoost Contributors */ #include "xgboost/collective/socket.h" #include // std::size_t #include // std::int32_t #include // std::memcpy, std::memset +#include // for path #include // std::error_code, std::system_category +#include "rabit/internal/socket.h" // for PollHelper +#include "xgboost/collective/result.h" // for Result + #if defined(__unix__) || defined(__APPLE__) #include // getaddrinfo, freeaddrinfo #endif // defined(__unix__) || defined(__APPLE__) -namespace xgboost { -namespace collective { +namespace xgboost::collective { SockAddress MakeSockAddress(StringView host, in_port_t port) { struct addrinfo hints; std::memset(&hints, 0, sizeof(hints)); @@ -71,7 +74,12 @@ std::size_t TCPSocket::Recv(std::string *p_str) { return bytes; } -std::error_code Connect(SockAddress const &addr, TCPSocket *out) { +[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, + std::chrono::seconds timeout, + xgboost::collective::TCPSocket *out_conn) { + auto addr = MakeSockAddress(xgboost::StringView{host}, port); + auto &conn = *out_conn; + sockaddr const *addr_handle{nullptr}; socklen_t addr_len{0}; if (addr.IsV4()) { @@ -81,14 +89,67 @@ std::error_code Connect(SockAddress const &addr, TCPSocket *out) { addr_handle = reinterpret_cast(&addr.V6().Handle()); addr_len = sizeof(addr.V6().Handle()); } - auto socket = TCPSocket::Create(addr.Domain()); - CHECK_EQ(static_cast(socket.Domain()), static_cast(addr.Domain())); - auto rc = connect(socket.Handle(), addr_handle, addr_len); - if (rc != 0) { - return std::error_code{errno, std::system_category()}; + + conn = TCPSocket::Create(addr.Domain()); + CHECK_EQ(static_cast(conn.Domain()), static_cast(addr.Domain())); + conn.SetNonBlock(true); + + Result last_error; + auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) { + last_error = std::move(err); + LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line + << "): Failed to connect to:" << host << " Error:" << last_error.Report(); + }; + + for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) { + if (attempt > 0) { + LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time."; +#if defined(_MSC_VER) || defined(__MINGW32__) + Sleep(attempt << 1); +#else + sleep(attempt << 1); +#endif + } + + auto rc = connect(conn.Handle(), addr_handle, addr_len); + if (rc != 0) { + auto errcode = system::LastError(); + if (!system::ErrorWouldBlock(errcode)) { + log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}), + __FILE__, __LINE__); + continue; + } + + rabit::utils::PollHelper poll; + poll.WatchWrite(conn); + auto result = poll.Poll(timeout); + if (!result.OK()) { + log_failure(std::move(result), __FILE__, __LINE__); + continue; + } + if (!poll.CheckWrite(conn)) { + log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}), + __FILE__, __LINE__); + continue; + } + result = conn.GetSockError(); + if (!result.OK()) { + log_failure(std::move(result), __FILE__, __LINE__); + continue; + } + + conn.SetNonBlock(false); + return Success(); + + } else { + conn.SetNonBlock(false); + return Success(); + } } - *out = std::move(socket); - return std::make_error_code(std::errc{}); + + std::stringstream ss; + ss << "Failed to connect to " << host << ":" << port; + conn.Close(); + return Fail(ss.str(), std::move(last_error)); } -} // namespace collective -} // namespace xgboost +} // namespace xgboost::collective diff --git a/tests/cpp/collective/test_socket.cc b/tests/cpp/collective/test_socket.cc index 571e95f4deb8..ddc73d1f2067 100644 --- a/tests/cpp/collective/test_socket.cc +++ b/tests/cpp/collective/test_socket.cc @@ -1,5 +1,5 @@ -/*! - * Copyright (c) 2022 by XGBoost Contributors +/** + * Copyright 2022-2023 by XGBoost Contributors */ #include #include @@ -10,8 +10,7 @@ #include "../helpers.h" -namespace xgboost { -namespace collective { +namespace xgboost::collective { TEST(Socket, Basic) { system::SocketStartup(); @@ -31,15 +30,16 @@ TEST(Socket, Basic) { TCPSocket client; if (domain == SockDomain::kV4) { auto const& addr = SockAddrV4::Loopback().Addr(); - ASSERT_EQ(Connect(MakeSockAddress(StringView{addr}, port), &client), std::errc{}); + auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client); + ASSERT_TRUE(rc.OK()) << rc.Report(); } else { auto const& addr = SockAddrV6::Loopback().Addr(); - auto rc = Connect(MakeSockAddress(StringView{addr}, port), &client); + auto rc = Connect(StringView{addr}, port, 1, std::chrono::seconds{3}, &client); // some environment (docker) has restricted network configuration. - if (rc == std::error_code{EADDRNOTAVAIL, std::system_category()}) { + if (!rc.OK() && rc.Code() == std::error_code{EADDRNOTAVAIL, std::system_category()}) { GTEST_SKIP_(msg.c_str()); } - ASSERT_EQ(rc, std::errc{}); + ASSERT_EQ(rc, Success()) << rc.Report(); } ASSERT_EQ(client.Domain(), domain); @@ -73,5 +73,4 @@ TEST(Socket, Basic) { system::SocketFinalize(); } -} // namespace collective -} // namespace xgboost +} // namespace xgboost::collective diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index 8709589dd1c3..1f42711a20d9 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -20,6 +20,18 @@ def test_rabit_tracker(): assert str(ret) == "test1234" +@pytest.mark.skipif(**tm.not_linux()) +def test_socket_error(): + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1) + tracker.start(1) + env = tracker.worker_envs() + env["DMLC_TRACKER_PORT"] = 0 + env["DMLC_WORKER_CONNECT_RETRY"] = 1 + with pytest.raises(ValueError, match="127.0.0.1:0\n.*refused"): + with xgb.collective.CommunicatorContext(**env): + pass + + def run_rabit_ops(client, n_workers): from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args @@ -58,6 +70,32 @@ def test_rabit_ops(): run_rabit_ops(client, n_workers) +def run_broadcast(client): + from xgboost.dask import _get_dask_config, _get_rabit_args + + workers = tm.get_client_workers(client) + rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client) + + def local_test(worker_id): + with collective.CommunicatorContext(**rabit_args): + res = collective.broadcast(17, 0) + return res + + futures = client.map(local_test, range(len(workers)), workers=workers) + results = client.gather(futures) + np.testing.assert_allclose(np.array(results), 17) + + +@pytest.mark.skipif(**tm.no_dask()) +def test_broadcast(): + from distributed import Client, LocalCluster + + n_workers = 3 + with LocalCluster(n_workers=n_workers) as cluster: + with Client(cluster) as client: + run_broadcast(client) + + @pytest.mark.skipif(**tm.no_ipv6()) @pytest.mark.skipif(**tm.no_dask()) def test_rabit_ops_ipv6():