Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support bitwise allreduce operations in the communicator #8623

Merged
merged 2 commits into from
Dec 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions plugin/federated/federated.proto
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ enum ReduceOperation {
MAX = 0;
MIN = 1;
SUM = 2;
BITWISE_AND = 3;
BITWISE_OR = 4;
BITWISE_XOR = 5;
}

message AllreduceRequest {
Expand Down
3 changes: 3 additions & 0 deletions python-package/xgboost/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ class Op(IntEnum):
MAX = 0
MIN = 1
SUM = 2
BITWISE_AND = 3
BITWISE_OR = 4
BITWISE_XOR = 5


def allreduce( # pylint:disable=invalid-name
Expand Down
4 changes: 3 additions & 1 deletion rabit/include/rabit/internal/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ enum OpType {
kMax = 0,
kMin = 1,
kSum = 2,
kBitwiseOR = 3
kBitwiseAND = 3,
kBitwiseOR = 4,
kBitwiseXOR = 5,
};
/*!\brief enum of supported data types */
enum DataType {
Expand Down
14 changes: 14 additions & 0 deletions rabit/include/rabit/internal/rabit-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,27 @@ struct Sum {
dst += src;
}
};
struct BitAND {
static const engine::mpi::OpType kType = engine::mpi::kBitwiseAND;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
dst &= src;
}
};
struct BitOR {
static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
dst |= src;
}
};
struct BitXOR {
static const engine::mpi::OpType kType = engine::mpi::kBitwiseXOR;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
dst ^= src;
}
};
template <typename OP, typename DType>
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &) {
const DType *src = static_cast<const DType *>(src_);
Expand Down
10 changes: 10 additions & 0 deletions rabit/include/rabit/rabit.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,21 @@ struct Min;
* \brief sum reduction operator
*/
struct Sum;
/*!
* \class rabit::op::BitAND
* \brief bitwise AND reduction operator
*/
struct BitAND;
/*!
* \class rabit::op::BitOR
* \brief bitwise OR reduction operator
*/
struct BitOR;
/*!
* \class rabit::op::BitXOR
* \brief bitwise XOR reduction operator
*/
struct BitXOR;
} // namespace op
/*!
* \brief initializes rabit, call this once at the beginning of your program
Expand Down
36 changes: 35 additions & 1 deletion rabit/src/rabit_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,36 @@ struct FHelper {
}
};

template<typename DType>
struct FHelper<op::BitAND, DType> {
static void
Allreduce(DType *,
size_t ,
void (*)(void *arg),
void *) {
utils::Error("DataType does not support bitwise AND operation");
}
};

template<typename DType>
struct FHelper<op::BitOR, DType> {
static void
Allreduce(DType *,
size_t ,
void (*)(void *arg),
void *) {
utils::Error("DataType does not support bitwise or operation");
utils::Error("DataType does not support bitwise OR operation");
}
};

template<typename DType>
struct FHelper<op::BitXOR, DType> {
static void
Allreduce(DType *,
size_t ,
void (*)(void *arg),
void *) {
utils::Error("DataType does not support bitwise XOR operation");
}
};

Expand Down Expand Up @@ -111,12 +133,24 @@ void Allreduce(void *sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kBitwiseAND:
Allreduce<op::BitAND>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kBitwiseOR:
Allreduce<op::BitOR>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
case kBitwiseXOR:
Allreduce<op::BitXOR>
(sendrecvbuf,
count, enum_dtype,
prepare_fun, prepare_arg);
return;
default: utils::Error("unknown enum_op");
}
}
Expand Down
9 changes: 8 additions & 1 deletion src/collective/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,14 @@ inline std::size_t GetTypeSize(DataType data_type) {
}

/** @brief Defines the reduction operation. */
enum class Operation { kMax = 0, kMin = 1, kSum = 2 };
enum class Operation {
kMax = 0,
kMin = 1,
kSum = 2,
kBitwiseAND = 3,
kBitwiseOR = 4,
kBitwiseXOR = 5
};

class DeviceCommunicator;

Expand Down
28 changes: 28 additions & 0 deletions src/collective/in_memory_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,29 @@ class AllreduceFunctor {
}

private:
template <class T, std::enable_if_t<std::is_integral<T>::value>* = nullptr>
void AccumulateBitwise(T* buffer, T const* input, std::size_t size,
Operation reduce_operation) const {
switch (reduce_operation) {
case Operation::kBitwiseAND:
std::transform(buffer, buffer + size, input, buffer, std::bit_and<T>());
break;
case Operation::kBitwiseOR:
std::transform(buffer, buffer + size, input, buffer, std::bit_or<T>());
break;
case Operation::kBitwiseXOR:
std::transform(buffer, buffer + size, input, buffer, std::bit_xor<T>());
break;
default:
throw std::invalid_argument("Invalid reduce operation");
}
}

template <class T, std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
void AccumulateBitwise(T*, T const*, std::size_t, Operation) const {
LOG(FATAL) << "Floating point types do not support bitwise operations.";
}

template <class T>
void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const {
switch (reduce_operation) {
Expand All @@ -44,6 +67,11 @@ class AllreduceFunctor {
case Operation::kSum:
std::transform(buffer, buffer + size, input, buffer, std::plus<T>());
break;
case Operation::kBitwiseAND:
case Operation::kBitwiseOR:
case Operation::kBitwiseXOR:
AccumulateBitwise(buffer, input, size, reduce_operation);
break;
default:
throw std::invalid_argument("Invalid reduce operation");
}
Expand Down
33 changes: 30 additions & 3 deletions src/collective/rabit_communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,33 @@ class RabitCommunicator : public Communicator {
void Print(const std::string &message) override { rabit::TrackerPrint(message); }

protected:
void Shutdown() override {
rabit::Finalize();
}
void Shutdown() override { rabit::Finalize(); }

private:
template <typename DType, std::enable_if_t<std::is_integral<DType>::value> * = nullptr>
void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
switch (op) {
case Operation::kBitwiseAND:
rabit::Allreduce<rabit::op::BitAND, DType>(static_cast<DType *>(send_receive_buffer),
count);
break;
case Operation::kBitwiseOR:
rabit::Allreduce<rabit::op::BitOR, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kBitwiseXOR:
rabit::Allreduce<rabit::op::BitXOR, DType>(static_cast<DType *>(send_receive_buffer),
count);
break;
default:
LOG(FATAL) << "Unknown allreduce operation";
}
}

template <typename DType, std::enable_if_t<std::is_floating_point<DType>::value> * = nullptr>
void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
LOG(FATAL) << "Floating point types do not support bitwise operations.";
}

template <typename DType>
void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
switch (op) {
Expand All @@ -113,6 +135,11 @@ class RabitCommunicator : public Communicator {
case Operation::kSum:
rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kBitwiseAND:
case Operation::kBitwiseOR:
case Operation::kBitwiseXOR:
DoBitwiseAllReduce<DType>(send_receive_buffer, count, op);
break;
default:
LOG(FATAL) << "Unknown allreduce operation";
}
Expand Down
93 changes: 73 additions & 20 deletions tests/cpp/collective/test_in_memory_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <dmlc/parameter.h>
#include <gtest/gtest.h>

#include <bitset>
#include <thread>

#include "../../../src/collective/in_memory_communicator.h"
Expand All @@ -13,7 +14,37 @@ namespace collective {

class InMemoryCommunicatorTest : public ::testing::Test {
public:
static void VerifyAllreduce(int rank) {
static void Verify(void (*function)(int)) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(function, rank);
}
for (auto &thread : threads) {
thread.join();
}
}

static void AllreduceMax(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMax);
int expected[] = {3, 4, 5, 6, 7};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(buffer[i], expected[i]);
}
}

static void AllreduceMin(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMin);
int expected[] = {1, 2, 3, 4, 5};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(buffer[i], expected[i]);
}
}

static void AllreduceSum(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
int buffer[] = {1, 2, 3, 4, 5};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
Expand All @@ -23,7 +54,35 @@ class InMemoryCommunicatorTest : public ::testing::Test {
}
}

static void VerifyBroadcast(int rank) {
static void AllreduceBitwiseAND(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
std::bitset<2> original(rank);
auto buffer = original.to_ulong();
comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseAND);
EXPECT_EQ(buffer, 0UL);
}

static void AllreduceBitwiseOR(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
std::bitset<2> original(rank);
auto buffer = original.to_ulong();
comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseOR);
std::bitset<2> actual(buffer);
std::bitset<2> expected{0b11};
EXPECT_EQ(actual, expected);
}

static void AllreduceBitwiseXOR(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
std::bitset<3> original(rank * 2);
auto buffer = original.to_ulong();
comm.AllReduce(&buffer, 1, DataType::kUInt32, Operation::kBitwiseXOR);
std::bitset<3> actual(buffer);
std::bitset<3> expected{0b110};
EXPECT_EQ(actual, expected);
}

static void Broadcast(int rank) {
InMemoryCommunicator comm{kWorldSize, rank};
if (rank == 0) {
std::string buffer{"hello"};
Expand Down Expand Up @@ -88,25 +147,19 @@ TEST(InMemoryCommunicatorSimpleTest, IsDistributed) {
EXPECT_TRUE(comm.IsDistributed());
}

TEST_F(InMemoryCommunicatorTest, Allreduce) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&InMemoryCommunicatorTest::VerifyAllreduce, rank));
}
for (auto &thread : threads) {
thread.join();
}
}
TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); }

TEST_F(InMemoryCommunicatorTest, Broadcast) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
threads.emplace_back(std::thread(&InMemoryCommunicatorTest::VerifyBroadcast, rank));
}
for (auto &thread : threads) {
thread.join();
}
}
TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); }

TEST_F(InMemoryCommunicatorTest, AllreduceSum) { Verify(&AllreduceSum); }

TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseAND) { Verify(&AllreduceBitwiseAND); }

TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseOR) { Verify(&AllreduceBitwiseOR); }

TEST_F(InMemoryCommunicatorTest, AllreduceBitwiseXOR) { Verify(&AllreduceBitwiseXOR); }

TEST_F(InMemoryCommunicatorTest, Broadcast) { Verify(&Broadcast); }

} // namespace collective
} // namespace xgboost