Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rabit] Improved connection handling. #9531

Merged
merged 11 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,3 @@ formats:
python:
install:
- requirements: doc/requirements.txt
system_packages: true
160 changes: 160 additions & 0 deletions include/xgboost/collective/result.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once

#include <memory>        // for unique_ptr
#include <sstream>       // for stringstream
#include <stack>         // for stack
#include <string>        // for string
#include <system_error>  // for error_code
#include <utility>       // for move

namespace xgboost::collective {
namespace detail {
/**
 * @brief One link in a singly linked chain of errors.
 *
 * Every link stores a message, an optional system error code and an owning pointer to
 * the error that was recorded before it (`prev`). The link without a `prev` is the
 * root error.
 */
struct ResultImpl {
  std::string message;
  std::error_code errc{};  // optional for system error.

  std::unique_ptr<ResultImpl> prev{nullptr};

  ResultImpl() = delete;  // must initialize.
  ResultImpl(ResultImpl const& that) = delete;
  ResultImpl(ResultImpl&& that) = default;
  ResultImpl& operator=(ResultImpl const& that) = delete;
  ResultImpl& operator=(ResultImpl&& that) = default;

  explicit ResultImpl(std::string msg) : message{std::move(msg)} {}
  // std::error_code is cheap to copy, no need to move it.
  explicit ResultImpl(std::string msg, std::error_code errc)
      : message{std::move(msg)}, errc{errc} {}
  explicit ResultImpl(std::string msg, std::unique_ptr<ResultImpl> prev)
      : message{std::move(msg)}, prev{std::move(prev)} {}
  explicit ResultImpl(std::string msg, std::error_code errc, std::unique_ptr<ResultImpl> prev)
      : message{std::move(msg)}, errc{errc}, prev{std::move(prev)} {}

  /**
   * @brief Two chains are equal when they have the same length and every pair of
   *        corresponding links has the same message and error code.
   */
  [[nodiscard]] bool operator==(ResultImpl const& that) const noexcept(true) {
    // Walk both chains in lock step instead of recursing once per link.
    ResultImpl const* lhs = this;
    ResultImpl const* rhs = &that;
    while (lhs != nullptr && rhs != nullptr) {
      if (lhs->message != rhs->message || lhs->errc != rhs->errc) {
        return false;
      }
      lhs = lhs->prev.get();
      rhs = rhs->prev.get();
    }
    // Equal only if both chains ended at the same time.
    return lhs == nullptr && rhs == nullptr;
  }

  /**
   * @brief Format the whole chain, newest error first.
   *
   * Each link starts on a new line prefixed with "- ". The error code message is
   * appended when the link carries a non-default error code.
   */
  [[nodiscard]] std::string Report() const {
    std::stringstream ss;
    ss << "\n- " << this->message;
    if (this->errc != std::error_code{}) {
      ss << " system error:" << this->errc.message();
    }

    for (auto ptr = prev.get(); ptr != nullptr; ptr = ptr->prev.get()) {
      ss << "\n- " << ptr->message;
      if (ptr->errc != std::error_code{}) {
        ss << " " << ptr->errc.message();
      }
    }

    return ss.str();
  }

  /**
   * @brief Find the root error code.
   *
   * @return The error code of the link closest to the root that carries a non-default
   *         code, or a default constructed `std::error_code` if no link has one.
   */
  [[nodiscard]] auto Code() const {
    // The last non-default code seen while walking towards the root is the one closest
    // to the root, no auxiliary stack is needed.
    std::error_code root{};
    for (auto ptr = this; ptr != nullptr; ptr = ptr->prev.get()) {
      if (ptr->errc != std::error_code{}) {
        root = ptr->errc;
      }
    }
    return root;
  }
};
} // namespace detail

/**
* @brief An error type that's easier to handle than throwing a dmlc exception. It can
* record and propagate the system error code.
*/
struct Result {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point in the future, we will need to propagate errors to Python or other language bindings so that error handling can be delegated to higher-level frameworks like Dask. For now, a functional style of error handling is easier to work with than exceptions.

private:
std::unique_ptr<detail::ResultImpl> impl_{nullptr};

public:
Result() noexcept(true) = default;
explicit Result(std::string msg) : impl_{std::make_unique<detail::ResultImpl>(std::move(msg))} {}
explicit Result(std::string msg, std::error_code errc)
: impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(errc))} {}
Result(std::string msg, Result&& prev)
: impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(prev.impl_))} {}
Result(std::string msg, std::error_code errc, Result&& prev)
: impl_{std::make_unique<detail::ResultImpl>(std::move(msg), std::move(errc),
std::move(prev.impl_))} {}

Result(Result const& that) = delete;
Result& operator=(Result const& that) = delete;
Result(Result&& that) = default;
Result& operator=(Result&& that) = default;

/** @brief Whether this result holds no error. */
[[nodiscard]] bool OK() const noexcept(true) { return impl_ == nullptr; }
/** @brief Formatted description of the error chain, empty when the result is OK. */
[[nodiscard]] std::string Report() const {
  if (OK()) {
    return "";
  }
  return impl_->Report();
}
/**
 * @brief Return the root system error code.
 *
 * A default constructed `std::error_code` is returned when the result is OK; the
 * error chain may also yield a default code if no system error was recorded.
 */
[[nodiscard]] auto Code() const {
  if (OK()) {
    return std::error_code{};
  }
  return impl_->Code();
}
/**
 * @brief Compare two results. Two successes are equal, a success never equals a
 *        failure, and two failures are equal when their error chains compare equal.
 */
[[nodiscard]] bool operator==(Result const& that) const noexcept(true) {
  bool const lhs_ok = OK();
  bool const rhs_ok = that.OK();
  if (lhs_ok || rhs_ok) {
    // At least one side has no error chain to dereference.
    return lhs_ok && rhs_ok;
  }
  return *impl_ == *that.impl_;
}
};

/**
* @brief Return success.
*/
[[nodiscard]] inline auto Success() noexcept(true) { return Result{}; }
/**
 * @brief Return failure.
 *
 * @param msg Description of the failure.
 */
[[nodiscard]] inline auto Fail(std::string msg) { return Result{std::move(msg)}; }
/**
 * @brief Return failure with a system error code (for instance one built from `errno`).
 *
 * @param msg  Description of the failure.
 * @param errc Error code attached to the failure.
 */
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc) {
  return Result{std::move(msg), errc};
}
/**
 * @brief Return failure with a previous error.
 *
 * @param msg  Description of the failure.
 * @param prev Earlier result to chain; its error chain is moved into the returned value.
 */
[[nodiscard]] inline auto Fail(std::string msg, Result&& prev) {
  // `prev` is a plain rvalue reference, not a forwarding reference: move, don't forward.
  return Result{std::move(msg), std::move(prev)};
}
/**
 * @brief Return failure with a previous error and a new system error code.
 *
 * @param msg  Description of the failure.
 * @param errc Error code attached to the new failure.
 * @param prev Earlier result to chain; its error chain is moved into the returned value.
 */
[[nodiscard]] inline auto Fail(std::string msg, std::error_code errc, Result&& prev) {
  return Result{std::move(msg), errc, std::move(prev)};
}
} // namespace xgboost::collective
71 changes: 51 additions & 20 deletions include/xgboost/collective/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ using ssize_t = int;

#endif // defined(_WIN32)

#include "xgboost/base.h" // XGBOOST_EXPECT
#include "xgboost/logging.h" // LOG
#include "xgboost/string_view.h" // StringView
#include "xgboost/base.h" // XGBOOST_EXPECT
#include "xgboost/collective/result.h" // for Result
#include "xgboost/logging.h" // LOG
#include "xgboost/string_view.h" // StringView

#if !defined(HOST_NAME_MAX)
#define HOST_NAME_MAX 256 // macos
Expand All @@ -81,6 +82,10 @@ inline std::int32_t LastError() {
#endif
}

/**
 * @brief Build a failed result from `msg` and the value returned by `LastError()`,
 *        interpreted in the system error category.
 */
[[nodiscard]] inline collective::Result FailWithCode(std::string msg) {
  std::error_code const last_errc{LastError(), std::system_category()};
  return collective::Fail(std::move(msg), last_errc);
}

#if defined(__GLIBC__)
inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
std::int32_t line = __builtin_LINE(),
Expand Down Expand Up @@ -120,15 +125,19 @@ inline std::int32_t CloseSocket(SocketT fd) {
#endif
}

inline bool LastErrorWouldBlock() {
int errsv = LastError();
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#ifdef _WIN32
return errsv == WSAEWOULDBLOCK;
#else
return errsv == EAGAIN || errsv == EWOULDBLOCK;
return errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == EINPROGRESS;
#endif // _WIN32
}

/** @brief Apply `ErrorWouldBlock` to the value returned by `LastError()`. */
inline bool LastErrorWouldBlock() { return ErrorWouldBlock(LastError()); }

inline void SocketStartup() {
#if defined(_WIN32)
WSADATA wsa_data;
Expand Down Expand Up @@ -315,23 +324,35 @@ class TCPSocket {
bool IsClosed() const { return handle_ == InvalidSocket(); }

/** \brief get last error code if any */
std::int32_t GetSockError() const {
std::int32_t error = 0;
socklen_t len = sizeof(error);
xgboost_CHECK_SYS_CALL(
getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&error), &len), 0);
return error;
Result GetSockError() const {
std::int32_t optval = 0;
socklen_t len = sizeof(optval);
auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&optval), &len);
if (ret != 0) {
auto errc = std::error_code{system::LastError(), std::system_category()};
return Fail("Failed to retrieve socket error.", std::move(errc));
}
if (optval != 0) {
auto errc = std::error_code{optval, std::system_category()};
return Fail("Socket error.", std::move(errc));
}
return Success();
}

/** \brief check if anything bad happens */
bool BadSocket() const {
if (IsClosed()) return true;
std::int32_t err = GetSockError();
if (err == EBADF || err == EINTR) return true;
if (IsClosed()) {
return true;
}
auto err = GetSockError();
if (err.Code() == std::error_code{EBADF, std::system_category()} || // NOLINT
err.Code() == std::error_code{EINTR, std::system_category()}) { // NOLINT
return true;
}
return false;
}

void SetNonBlock() {
bool non_block{true};
void SetNonBlock(bool non_block) {
#if defined(_WIN32)
u_long mode = non_block ? 1 : 0;
xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
Expand Down Expand Up @@ -530,10 +551,20 @@ class TCPSocket {
};

/**
* \brief Connect to remote address, returns the error code if failed (no exception is
* raised so that we can retry).
* @brief Connect to remote address, returns the error code if failed.
*
* @param host Host IP address.
* @param port Connection port.
* @param retry Number of retries to attempt.
* @param timeout Timeout of each connection attempt.
* @param out_conn Output socket if the connection is successful. Value is invalid and undefined if
* the connection failed.
*
* @return Connection status.
*/
std::error_code Connect(SockAddress const &addr, TCPSocket *out);
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
std::chrono::seconds timeout,
xgboost::collective::TCPSocket *out_conn);

/**
* \brief Get the local host name.
Expand Down
4 changes: 4 additions & 0 deletions python-package/xgboost/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def no_ipv6() -> PytestSkip:
return {"condition": not has_ipv6(), "reason": "IPv6 is required to be enabled."}


def not_linux() -> PytestSkip:
    """Build the skip arguments whose condition holds when ``system()`` is not "Linux"."""
    is_other_platform = system() != "Linux"
    return {"condition": is_other_platform, "reason": "Linux is required."}


def no_ubjson() -> PytestSkip:
    """Return what ``no_mod`` produces for the ``ubjson`` module."""
    module_name = "ubjson"
    return no_mod(module_name)

Expand Down
30 changes: 16 additions & 14 deletions rabit/include/rabit/internal/socket.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
/*!
* Copyright (c) 2014-2022 by XGBoost Contributors
/**
* Copyright 2014-2023, XGBoost Contributors
* \file socket.h
* \author Tianqi Chen
*/
#ifndef RABIT_INTERNAL_SOCKET_H_
#define RABIT_INTERNAL_SOCKET_H_
#include "xgboost/collective/result.h"
#include "xgboost/collective/socket.h"

#if defined(_WIN32)
Expand Down Expand Up @@ -77,7 +78,7 @@ namespace rabit {
namespace utils {

template <typename PollFD>
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) {
int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true) {
#if defined(_WIN32)

#if IS_MINGW()
Expand Down Expand Up @@ -135,41 +136,41 @@ struct PollHelper {
* \brief Check if the descriptor is ready for read
* \param fd file descriptor to check status
*/
inline bool CheckRead(SOCKET fd) const {
[[nodiscard]] bool CheckRead(SOCKET fd) const {
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
}
bool CheckRead(xgboost::collective::TCPSocket const &socket) const {
[[nodiscard]] bool CheckRead(xgboost::collective::TCPSocket const& socket) const {
return this->CheckRead(socket.Handle());
}

/*!
* \brief Check if the descriptor is ready for write
* \param fd file descriptor to check status
*/
inline bool CheckWrite(SOCKET fd) const {
[[nodiscard]] bool CheckWrite(SOCKET fd) const {
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
}
bool CheckWrite(xgboost::collective::TCPSocket const &socket) const {
[[nodiscard]] bool CheckWrite(xgboost::collective::TCPSocket const& socket) const {
return this->CheckWrite(socket.Handle());
}
/*!
* \brief perform poll on the set defined, read, write, exception
* \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
* \return
/**
* @brief perform poll on the set defined, read, write, exception
*
* @param timeout specify timeout in seconds. Block if negative.
*/
inline void Poll(std::chrono::seconds timeout) { // NOLINT(*)
[[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout) {
std::vector<pollfd> fdset;
fdset.reserve(fds.size());
for (auto kv : fds) {
fdset.push_back(kv.second);
}
int ret = PollImpl(fdset.data(), fdset.size(), timeout);
if (ret == 0) {
LOG(FATAL) << "Poll timeout";
return xgboost::collective::Fail("Poll timeout.");
} else if (ret < 0) {
LOG(FATAL) << "Failed to poll.";
return xgboost::system::FailWithCode("Poll failed.");
} else {
for (auto& pfd : fdset) {
auto revents = pfd.revents & pfd.events;
Expand All @@ -180,6 +181,7 @@ struct PollHelper {
}
}
}
return xgboost::collective::Success();
}

std::unordered_map<SOCKET, pollfd> fds;
Expand Down
Loading
Loading