Skip to content

Commit

Permalink
Make "Rendezvous" variadic (#2443)
Browse files Browse the repository at this point in the history
This PR allows each Process in the ProcessGrid to contribute more than
one tensor at `ProcessGrid::rendezvous`.
This is the first set of changes; it will be followed by interpreter updates
to the collectives.

Note: The PR does **not** make collectives interpreter variadic. 

Tested: 
1. No existing tests fail, which indicates no change in behavior for ops
using `rendezvous`.
2. The diff was tested with a new variadic interpreter for the `all_reduce` op;
that PR will be uploaded soon.


ref: #2099
  • Loading branch information
abhigunj authored Jul 22, 2024
1 parent 70c210d commit 50cdc03
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 52 deletions.
34 changes: 16 additions & 18 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -1130,12 +1131,11 @@ Tensor allGatherOp(const Tensor &operand, int64_t allGatherDim,
llvm::report_fatal_error(invalidArgument(
"Failed to find process group with process_id: (%d, %d)",
process->getId().replicaId, process->getId().partitionId));

auto rendezvousResult =
process->rendezvous(*processGroup, channelId, operand);
process->rendezvous(*processGroup, channelId, {operand});
auto groupOperands = llvm::map_to_vector(
*processGroup,
[&](const ProcessId &id) { return rendezvousResult.lookup(id); });
[&](const ProcessId &id) { return rendezvousResult.lookup(id).front(); });

return concatenateOp(groupOperands, allGatherDim, resultType);
}
Expand All @@ -1162,16 +1162,15 @@ Tensor allReduceOp(const Tensor &operand,
llvm::report_fatal_error(invalidArgument(
"Failed to find process group with process_id: (%d, %d)",
process->getId().replicaId, process->getId().partitionId));

auto groupOperands =
process->rendezvous(*processGroup, channelId, operand).getSortedTensors();
auto groupOperands = process->rendezvous(*processGroup, channelId, {operand})
.getSortedTensors();

Tensor result(resultType);
for (auto resultIt = result.index_begin(); resultIt != result.index_end();
++resultIt) {
Tensor resultElement;
for (const auto &groupOperand : groupOperands) {
auto groupOperandElement = constant(groupOperand.get(*resultIt));
auto groupOperandElement = constant(groupOperand.front().get(*resultIt));
if (resultElement)
resultElement = eval(computation, {resultElement, groupOperandElement},
/*fallback=*/nullptr, process, &scope)[0]
Expand Down Expand Up @@ -1203,13 +1202,12 @@ Tensor allToAllOp(const Tensor &operand, Axis splitDimension,
llvm::report_fatal_error(invalidArgument(
"Failed to find process group with process_id: (%d, %d)",
process->getId().replicaId, process->getId().partitionId));

auto groupOperands =
process->rendezvous(*processGroup, channelId, operand).getSortedTensors();
auto groupOperands = process->rendezvous(*processGroup, channelId, {operand})
.getSortedTensors();

SmallVector<Tensor> scatteredParts;
for (const auto &groupOperand : groupOperands) {
auto splitParts = split(groupOperand, splitCount, splitDimension,
auto splitParts = split(groupOperand.front(), splitCount, splitDimension,
operand.getType().getContext());
for (auto [i, processId] : llvm::enumerate(*processGroup))
if (processId == process->getId())
Expand Down Expand Up @@ -1346,10 +1344,11 @@ Tensor collectiveBroadcastOp(const Tensor &operand,
if (channelId > 0) processGroups = process->crossPartition(replicaGroups);

auto processGroup = processGroups.findGroup(process->getId());
if (processGroup)
return process->rendezvous(*processGroup, channelId, operand)
.lookup((*processGroup)[0]);

if (processGroup) {
return process->rendezvous(*processGroup, channelId, {operand})
.lookup((*processGroup)[0])
.front();
}
return broadcastInDimOp(constant(0.0, operand.getElementType()), {},
operand.getType());
}
Expand All @@ -1371,11 +1370,10 @@ Tensor collectivePermuteOp(const Tensor &operand,
auto from = processGroup[0];
auto to = processGroup[1];
if (from != process->getId() && to != process->getId()) continue;

auto rendezvousResult =
process->rendezvous(processGroup, channelId, operand);
process->rendezvous(processGroup, channelId, {operand});
if (to != process->getId()) continue;
result = rendezvousResult.lookup(from);
result = rendezvousResult.lookup(from).front();
}

if (result) return result;
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/reference/Process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ SmallVector<Tensor> Process::recv(ChannelId channelId) {

RendezvousResult Process::rendezvous(ProcessGroup processGroup,
ChannelId channelId,
const Tensor &operand) {
return grid_->rendezvous(processGroup, channelId, getId(), operand);
ArrayRef<Tensor> operands) {
return grid_->rendezvous(processGroup, channelId, getId(), operands);
}

void Process::send(ArrayRef<Tensor> inputs, ChannelId channelId) {
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/reference/Process.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Process {

/// See `ProcessGrid::rendezvous`.
RendezvousResult rendezvous(ProcessGroup processGroup, ChannelId channelId,
const Tensor &operand);
ArrayRef<Tensor> operands);

/// See `ProcessGrid::send`.
void send(ArrayRef<Tensor> inputs, ChannelId channelId);
Expand Down
41 changes: 28 additions & 13 deletions stablehlo/reference/ProcessGrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,33 @@ std::optional<ProcessGroup> ProcessGroups::findGroup(ProcessId processId) {
// RendezvousResult.
//===----------------------------------------------------------------------===//

RendezvousResult::RendezvousResult(std::map<ProcessId, Tensor> const &result)
: result_(result) {}
RendezvousResult::RendezvousResult(
std::map<ProcessId, SmallVector<Tensor>> const &results)
: results_(results) {}

void RendezvousResult::insert(ProcessId processId, Tensor tensor) {
result_[processId] = tensor;
void RendezvousResult::insert(ProcessId processId,
SmallVector<Tensor> tensors) {
results_[processId] = tensors;
}

Tensor RendezvousResult::lookup(ProcessId processId) const {
auto it = result_.find(processId);
if (it != result_.end()) return it->second;
SmallVector<Tensor> RendezvousResult::lookup(ProcessId processId) const {
auto it = results_.find(processId);
if (it != results_.end()) return it->second;
return {};
}

SmallVector<Tensor> RendezvousResult::getSortedTensors() const {
return llvm::map_to_vector(result_,
SmallVector<SmallVector<Tensor>> RendezvousResult::getSortedTensors() const {
return llvm::map_to_vector(results_,
[](const auto &pair) { return pair.second; });
}

/// Returns true if every entry in `results_` holds the same number of operand
/// tensors. A result with no entries trivially matches.
bool RendezvousResult::hasMatchingOperandsCount() const {
  // Dereferencing `begin()` of an empty map is undefined behaviour, so handle
  // the empty (e.g. default-constructed) result before reading the first entry.
  if (results_.empty()) return true;

  // Use the first entry's operand count as the reference for all others.
  const auto expectedCount = results_.begin()->second.size();
  for (const auto &entry : results_)
    if (entry.second.size() != expectedCount) return false;
  return true;
}

//===----------------------------------------------------------------------===//
// ThreadSafeMap.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -223,16 +232,19 @@ SmallVector<Tensor> ProcessGrid::recv(ChannelId channelId,
RendezvousResult ProcessGrid::rendezvous(ProcessGroup processGroup,
ChannelId channelId,
ProcessId processId,
const Tensor &operand) {
ArrayRef<Tensor> operands) {
// Process wait/notify logic below doesn't work for single process.
if (processGroup.size() == 1)
return RendezvousResult({std::pair{processId, operand}});
if (processGroup.size() == 1) {
std::map<ProcessId, SmallVector<Tensor>> results;
results[processId] = SmallVector<Tensor>(operands);
return RendezvousResult(results);
}

std::pair<ProcessGroup, ChannelId> channelKey(processGroup, channelId);
auto &state = channels_[channelKey];

std::unique_lock<std::mutex> lock(state.mutex);
state.values[processId] = operand;
state.values[processId] = SmallVector<Tensor>(operands);
state.useCount++;

// After each process contributes, wait for the last process to notify.
Expand All @@ -248,6 +260,9 @@ RendezvousResult ProcessGrid::rendezvous(ProcessGroup processGroup,

state.useCount--;

if (!state.result.hasMatchingOperandsCount())
llvm::report_fatal_error("Mismatched number of operands per process");

return state.useCount > 0 ? state.result : std::move(state.result);
}

Expand Down
40 changes: 22 additions & 18 deletions stablehlo/reference/ProcessGrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,33 @@ namespace stablehlo {
struct ProcessId;

/// Represents a result of a `ProcessGrid::rendezvous` where multiple processes
/// synchronize at a barrier and contribute a Tensor each.
/// This class is pretty much a map from ProcessId to Tensor, with the
/// synchronize at a barrier and contribute the same number of Tensors.
/// This class is pretty much a map from ProcessId to a Tensor vector, with the
/// map-like API.
class RendezvousResult {
public:
RendezvousResult() = default;
RendezvousResult(std::map<ProcessId, Tensor> const &result);
RendezvousResult(std::map<ProcessId, SmallVector<Tensor>> const &results);

/// Iterates through the (ProcessId, Tensor) map entires and returns a vector
/// of Tensors sorted by ProcessId--(replicaId, partitionId) pair--in
/// lexicographical order.
SmallVector<Tensor> getSortedTensors() const;
/// Iterates through the (ProcessId, SmallVector<Tensor>) map entries and
/// returns the per-process Tensor vectors sorted by ProcessId--(replicaId,
/// partitionId) pair--in lexicographical order.
SmallVector<SmallVector<Tensor>> getSortedTensors() const;

/// Inserts `tensor` into the map using the key `processId`.
void insert(ProcessId processId, Tensor tensor);
/// Inserts the tensors in `tensor` into the map using the key `processId`.
void insert(ProcessId processId, SmallVector<Tensor> tensor);

/// Iterates through the map and returns the value associated with the key
/// `processId`. If key is not found, return an empty `Tensor`.
Tensor lookup(ProcessId processId) const;
/// `processId`. If key is not found, return an empty `SmallVector<Tensor>`.
SmallVector<Tensor> lookup(ProcessId processId) const;

/// Iterates through the (ProcessId, SmallVector<Tensor>) map entries and
/// returns true if all processes contributed the same number of operand
/// Tensors.
bool hasMatchingOperandsCount() const;

private:
/// Internal map representation of the result of `ProcessGrid::rendezvous`.
std::map<ProcessId, Tensor> result_;
std::map<ProcessId, SmallVector<Tensor>> results_;
};

namespace detail {
Expand All @@ -72,7 +76,7 @@ struct RendezvousState {
std::mutex mutex;

/// Internal storage used to store data contributed by the processes.
std::map<ProcessId, Tensor> values;
std::map<ProcessId, SmallVector<Tensor>> values;

/// Internal state management counter which counts the number of processes
/// that contributed already.
Expand Down Expand Up @@ -245,12 +249,12 @@ class ProcessGrid {
/// underlying StableHLO programs or bugs in the StableHLO interpreter don't
/// deadlock the interpreter.
///
/// At the barrier, each StableHLO process contributes a tensor, and these
/// tensors are accumulated in `RendezvousResult` whose shared pointer is
/// returned to all callers once the barrier has been reached by all StableHLO
/// processes.
/// At the barrier, each StableHLO process contributes any number of tensors
/// (the same number for every process), and these tensors are accumulated in
/// a `RendezvousResult` which is returned to all callers once the barrier has
/// been reached by all StableHLO processes.
RendezvousResult rendezvous(ProcessGroup processGroup, ChannelId channelId,
ProcessId processId, const Tensor &operand);
ProcessId processId, ArrayRef<Tensor> operands);

/// Sends `inputs` to a channel with `channelId`.
/// The channel with `channelId` is emptied before the receiving process can
Expand Down

0 comments on commit 50cdc03

Please sign in to comment.