diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp index 4061dbd429f..6727c3a2349 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp @@ -24,26 +24,9 @@ ReduceScatter create_reduce_scatter_struct ( ){ uint32_t num_devices = devices.size(); - bool is_linear = topology == ttnn::ccl::Topology::Linear; - - uint32_t device_index = 0; // Initialize device index - std::optional receiver_device_id = std::nullopt; // Initialize receiver device ID - std::optional sender_device_id = std::nullopt; // Initialize sender device ID - for (uint32_t i = 0; i < num_devices; ++i) { - if (devices.at(i) == input_tensor.device()) { - - bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1); - bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0; - device_index = i; - receiver_device_id = is_last_chip_in_clockwise_direction ? - std::nullopt : - std::optional(devices.at((i + 1) % num_devices)->id()); - sender_device_id = is_last_chip_in_counter_clockwise_direction ? - std::nullopt : - std::optional(devices.at((i + num_devices - 1) % num_devices)->id()); - break; - } - } + auto [device_index, sender_device_id, receiver_device_id] = + get_device_index_and_sender_receiver_ids(input_tensor, devices, topology); + TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect"); return ttnn::ReduceScatter{ @@ -145,16 +128,7 @@ Tensor reduce_scatter( const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { - uint32_t num_devices = devices.size(); - if (num_devices == 2){ - topology = ttnn::ccl::Topology::Linear; - } - const auto& input_tensor = input_tensors.at(0); - auto [device_index, sender_device_id, receiver_device_id] = - get_device_index_and_sender_receiver_ids(input_tensor, devices, topology); - - TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect"); return operation::run( ttnn::ccl::reduce_scatter_detail::create_reduce_scatter_struct(