Skip to content

Commit

Permalink
#0: Rebased
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Oct 29, 2024
1 parent cc7efe4 commit 5a92890
Showing 1 changed file with 3 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,9 @@ ReduceScatter create_reduce_scatter_struct (
){
uint32_t num_devices = devices.size();

bool is_linear = topology == ttnn::ccl::Topology::Linear;

uint32_t device_index = 0; // Initialize device index
std::optional<chip_id_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
std::optional<chip_id_t> sender_device_id = std::nullopt; // Initialize sender device ID
for (uint32_t i = 0; i < num_devices; ++i) {
if (devices.at(i) == input_tensor.device()) {

bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1);
bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0;
device_index = i;
receiver_device_id = is_last_chip_in_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at((i + 1) % num_devices)->id());
sender_device_id = is_last_chip_in_counter_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at((i + num_devices - 1) % num_devices)->id());
break;
}
}
auto [device_index, sender_device_id, receiver_device_id] =
get_device_index_and_sender_receiver_ids(input_tensor, devices, topology);

TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect");

return ttnn::ReduceScatter{
Expand Down Expand Up @@ -145,16 +128,7 @@ Tensor reduce_scatter(
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {

uint32_t num_devices = devices.size();
if (num_devices == 2){
topology = ttnn::ccl::Topology::Linear;
}

const auto& input_tensor = input_tensors.at(0);
auto [device_index, sender_device_id, receiver_device_id] =
get_device_index_and_sender_receiver_ids(input_tensor, devices, topology);

TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect");

return operation::run(
ttnn::ccl::reduce_scatter_detail::create_reduce_scatter_struct(
Expand Down

0 comments on commit 5a92890

Please sign in to comment.