Skip to content

Commit

Permalink
#15060: moved new all gather to separate all gather v2 codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
caixunshiren committed Nov 14, 2024
1 parent 08867ea commit 2ea46ad
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 21 deletions.
3 changes: 2 additions & 1 deletion ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core_new.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather_v2/device/multi_core/all_gather_op_multi_core_new.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather_v2/device/all_gather_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp
Expand Down
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/ccl/all_gather/all_gather.hpp"
#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp"
#include "ttnn/operations/ccl/all_gather_v2/device/all_gather_op.hpp"
#include "ttnn/distributed/types.hpp"

namespace ttnn::operations::ccl {
Expand All @@ -15,7 +15,7 @@ ttnn::Tensor ExecuteAllGather::invoke(const ttnn::Tensor& input_tensor,
const std::optional<size_t> num_workers,
const std::optional<size_t> num_buffers_per_channel,
const ttnn::ccl::Topology topology) {
return ttnn::operations::ccl::all_gather(
return ttnn::operations::ccl::all_gather_v2(
input_tensor, dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
}

Expand All @@ -29,7 +29,7 @@ ttnn::Tensor ExecuteAllGather::invoke(
const std::optional<size_t> num_workers,
const std::optional<size_t> num_buffers_per_channel,
const ttnn::ccl::Topology topology) {
return ttnn::operations::ccl::all_gather(
return ttnn::operations::ccl::all_gather_v2(
input_tensor, dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ std::vector<Tensor> AllGather::create_output_tensors(const std::vector<Tensor> &
}

operation::ProgramWithCallbacks AllGather::create_program(const std::vector<Tensor> & input_tensors, std::vector<Tensor> &output_tensors) const {
return all_gather_multi_core_with_workers_new(input_tensors[0], output_tensors[0], this->dim, this->num_links, this->ring_size, this->ring_index, this->receiver_device_id, this->sender_device_id, this->topology, this->user_defined_num_workers, this->user_defined_num_buffers_per_channel);
return all_gather_multi_core_with_workers(input_tensors[0], output_tensors[0], this->dim, this->num_links, this->ring_size, this->ring_index, this->receiver_device_id, this->sender_device_id, this->topology, this->user_defined_num_workers, this->user_defined_num_buffers_per_channel);
}

namespace operations {
Expand Down
13 changes: 0 additions & 13 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,19 +192,6 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
const std::optional<size_t> user_defined_num_buffers_per_channel,
std::optional<experimental::ccl::AllGatherFusedOpSignaler>& fused_op_signaler,
const CoreCoord core_grid_offset = CoreCoord(0, 0));
operation::ProgramWithCallbacks all_gather_multi_core_with_workers_new(
const Tensor& input_tensor,
Tensor& output_tensor,
const uint32_t dim,
const uint32_t num_links,
const uint32_t ring_size,
const uint32_t ring_index,
const std::optional<chip_id_t> receiver_device_id,
const std::optional<chip_id_t> sender_device_id,
ccl::Topology topology,
const std::optional<size_t> user_defined_num_workers,
const std::optional<size_t> user_defined_num_buffers_per_channel);



namespace operations {
Expand Down
203 changes: 203 additions & 0 deletions ttnn/cpp/ttnn/operations/ccl/all_gather_v2/device/all_gather_op.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/ccl/all_gather_v2/device/all_gather_op.hpp"
#include "ttnn/operations/math.hpp"

#include "tt_metal/host_api.hpp"

#include "ttnn/tensor/tensor_utils.hpp"

#include "eth_l1_address_map.h"

namespace ttnn {
namespace ccl {
namespace all_gather_detail {

// Builds the AllGatherV2 device-operation struct for a gather across `devices`.
// This device's position in the topology and its sender/receiver neighbor IDs
// are derived from the input tensor's placement; when no explicit
// `memory_config` is supplied, the input tensor's memory config is reused.
AllGatherV2 create_all_gather_struct(
    const Tensor& input_tensor,
    const uint32_t dim,
    const uint32_t num_links,
    const std::optional<MemoryConfig>& memory_config,
    const std::vector<Device*>& devices,
    const ttnn::ccl::Topology topology
) {
    const uint32_t num_devices = devices.size();
    auto [device_index, sender_device_id, receiver_device_id] =
        get_device_index_and_sender_receiver_ids(input_tensor, devices, topology);

    return ttnn::AllGatherV2{
        dim, num_links, num_devices, device_index, receiver_device_id, sender_device_id, memory_config.value_or(input_tensor.memory_config()), topology};
}

} // namespace all_gather_detail
} // namespace ccl

void AllGatherV2::validate(const std::vector<Tensor> &input_tensors) const {
TT_FATAL(input_tensors.size() == 1, "Error, Input tensor size should be 1 but has {}", input_tensors.size());
const auto& input_tensor = input_tensors[0];
const auto& layout = input_tensors[0].get_layout();
const auto& dtype = input_tensors[0].get_dtype();
const auto& page_size = input_tensors[0].buffer()->page_size();
TT_FATAL(page_size % input_tensors[0].buffer()->alignment() == 0, "All Gather currently requires aligned pages");

// TODO: This can be removed by passing two page sizes, actual and aligned to be used for address offsets
// Buffer sizes also need to take this aligned page size into consideration
// TODO: Validate ring
TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to all_gather need to be on device!");
TT_FATAL(input_tensor.buffer() != nullptr , "Operands to all_gather need to be allocated in buffers on device!");
TT_FATAL(this->num_links > 0, "Error, num_links should be more than 0 but has {}", this->num_links);
TT_FATAL(this->num_links <= input_tensor.device()->compute_with_storage_grid_size().y, "Worker cores used by links are parallelizaed over rows");
TT_FATAL(this->receiver_device_id.has_value() || this->sender_device_id.has_value(), "Error, All-gather was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid all-gather configuration. The input mesh tensor or all-gather arguments may be incorrect");

TT_FATAL(input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED ||
input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED ||
input_tensor.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED ||
input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED,
"Unsupported memory layout {}.", input_tensor.memory_config().memory_layout);

// Sharding Config checks
bool input_sharded = input_tensor.is_sharded();
if (input_sharded) {
// TODO(snijjar)
}
}

std::vector<ttnn::SimpleShape> AllGatherV2::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
auto shape = input_tensors[0].get_padded_shape(); // TODO: Replace with get_logical_shape()
shape[this->dim] *= this->ring_size;
return std::vector<ttnn::SimpleShape>(input_tensors.size(), shape);
}

std::vector<Tensor> AllGatherV2::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
const auto& input_tensor = input_tensors[0];
if(this->output_mem_config.is_sharded()) {
return {create_device_tensor(
this->compute_output_shapes(input_tensors).at(0),
input_tensor.get_dtype(),
input_tensor.get_layout(),
input_tensor.device(),
this->output_mem_config,
input_tensor.get_tile()
)};
} else {
return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config, input_tensor.get_tile());
}
}

// Dispatches to the v2 multi-core program factory. Unlike the v1 path, no
// user-defined worker/buffer-per-channel counts are forwarded here.
operation::ProgramWithCallbacks AllGatherV2::create_program(const std::vector<Tensor> & input_tensors, std::vector<Tensor> &output_tensors) const {
    return all_gather_multi_core_with_workers_new(input_tensors[0], output_tensors[0], this->dim, this->num_links, this->ring_size, this->ring_index, this->receiver_device_id, this->sender_device_id, this->topology);
}

namespace operations {
namespace ccl {

// Host-side entry point for the v2 all-gather across the devices that hold
// `input_tensor`. Requires Fast Dispatch, more than one device, and exactly
// one link. Returns the gathered output tensor.
//
// NOTE(review): user_defined_num_workers / user_defined_num_buffers_per_channel
// are accepted for interface parity with the v1 op but are not forwarded to
// the v2 device operation.
Tensor all_gather_v2(
    const Tensor& input_tensor,
    const uint32_t dim,
    const uint32_t num_links,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<size_t> user_defined_num_workers,
    const std::optional<size_t> user_defined_num_buffers_per_channel,
    const ttnn::ccl::Topology topology) {

    tt::log_info(tt::LogOp, "DEBUG: all_gather_v2 is called");

    TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "all_gather op is only supported for Fast Dispatch");
    auto devices = input_tensor.get_workers();
    const uint32_t num_devices = devices.size();
    TT_FATAL(num_devices > 1, "all_gather op will only work for num_devices > 1, but has {}", num_devices);
    ttnn::ccl::Topology ccl_topology = topology;

    // A two-device ring degenerates to a line.
    if (num_devices == 2){
        ccl_topology = ttnn::ccl::Topology::Linear;
    }
    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};

    // NOTE(review): these programs and the line fabric are constructed but never
    // handed to launch_op below, and they are destroyed when this function
    // returns — confirm whether constructing EdmLineFabricOpInterface has
    // required side effects or whether this is leftover debug scaffolding.
    auto programs = std::vector<Program>(devices.size());
    auto program_ptrs = std::vector<Program*>(devices.size());
    std::transform(programs.begin(), programs.end(), program_ptrs.begin(), [](auto& program) { return &program; });
    TT_FATAL(num_links == 1, "all_gather op is only supported for num_links == 1, but has {}", num_links);
    tt::log_info(tt::LogOp, "DEBUG: creating line_fabric with num devices: {}, num links: {}", devices.size(), num_links);
    auto line_fabric = ttnn::ccl::EdmLineFabricOpInterface(devices, program_ptrs, num_links);
    tt::log_info(tt::LogOp, "DEBUG: line_fabric is created");

    operation::launch_op(
        [dim, num_links, memory_config, devices, ccl_topology](
            const std::vector<Tensor>& input_tensors,
            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
            const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {

            const auto& input_tensor = input_tensors.at(0);

            return operation::run(
                ttnn::ccl::all_gather_detail::create_all_gather_struct(input_tensor, dim, num_links, memory_config, devices, ccl_topology),
                {input_tensor});
        },
        {input_tensor},
        output_tensors);
    return output_tensors.at(0);
}

// Mesh-aware all_gather_v2: gathers along one axis of a 2D device mesh.
// `cluster_axis == 0` gathers down a mesh column (varying row), `cluster_axis
// == 1` gathers across a mesh row (varying column). Only the Linear topology
// is accepted. Returns the gathered output tensor.
Tensor all_gather_v2(
    const Tensor& input_tensor,
    const uint32_t dim,
    const uint32_t cluster_axis,
    const MeshDevice& mesh_device,
    const uint32_t num_links,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<size_t> user_defined_num_workers,
    const std::optional<size_t> user_defined_num_buffers_per_channel,
    const ttnn::ccl::Topology topology) {

    tt::log_info(tt::LogOp, "DEBUG: all_gather with cluster_axis is called");

    TT_FATAL(topology == ttnn::ccl::Topology::Linear, "This all_gather API with cluster_axis is currently supported only for the Linear topology");
    const auto mesh_view = mesh_device.get_view();
    // Number of devices participating in the line along the chosen axis.
    std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols();

    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};

    operation::launch_op(
        [dim, num_links, memory_config, mesh_view, cluster_axis, num_devices, topology](
            const std::vector<Tensor>& input_tensors,
            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
            const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {

            const auto& input_device_tensor = input_tensors.at(0);

            // Locate this device in the mesh; device_index is its position
            // along the gather line.
            const auto coordinate = mesh_view->find_device(input_device_tensor.device()->id());
            // NOTE(review): view_index is computed but never used below —
            // confirm whether it is needed or can be removed.
            const auto view_index = (cluster_axis == 0) ? coordinate.col : coordinate.row;
            const auto device_index = (cluster_axis == 0) ? coordinate.row : coordinate.col;

            // Maps a position on the line (modulo num_devices) back to the
            // chip id at that mesh coordinate.
            auto get_chip_id = [&](std::size_t line_index) -> std::optional<chip_id_t> {
                auto new_coord = coordinate;
                if (cluster_axis == 0) {
                    new_coord.row = line_index % num_devices;
                } else {
                    new_coord.col = line_index % num_devices;
                }
                return mesh_view->find_device_id(new_coord);
            };

            // Endpoints of the line have no neighbor in the corresponding
            // direction (Linear topology only).
            bool is_last_chip_in_clockwise_direction = device_index == (num_devices - 1);
            bool is_last_chip_in_counter_clockwise_direction = device_index == 0;
            auto receiver_device_id = is_last_chip_in_clockwise_direction ? std::nullopt : get_chip_id(device_index + 1);
            auto sender_device_id = is_last_chip_in_counter_clockwise_direction ? std::nullopt : get_chip_id(device_index + num_devices - 1);

            return operation::run(
                ttnn::AllGatherV2{
                    dim, num_links, num_devices, device_index, receiver_device_id, sender_device_id, memory_config.value_or(input_device_tensor.memory_config()), topology},
                {input_device_tensor});
        },
        {input_tensor},
        output_tensors);
    return output_tensors.at(0);

}


} // namespace ccl
} // namespace operations

} // namespace ttnn
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <cstdint>
#include "common/core_coord.hpp"
#include "impl/buffers/buffer.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"
#include "ttnn/operations/ccl/ccl_host_datastructures.hpp"
#include "ttnn/operations/ccl/ccl_common.hpp"
#include "ttnn/operations/ccl/ccl_op_fusion.hpp"


#include "ttnn/run_operation.hpp"

#include <optional>
#include <vector>
#include <algorithm>

namespace ttnn {

using ccl::EriscDatamoverBuilder;

// Device-operation descriptor for the v2 all-gather. One instance is created
// per participating device; validate/compute_output_shapes/
// create_output_tensors/create_program implement the op-infra contract.
struct AllGatherV2 {
    const uint32_t dim;                               // tensor dimension being gathered
    const uint32_t num_links;                         // number of ethernet links to use
    const uint32_t ring_size;                         // number of devices in the ring/line
    const uint32_t ring_index;                        // this device's position in the ring/line
    const std::optional<chip_id_t> receiver_device_id; // next neighbor; nullopt at a line endpoint
    const std::optional<chip_id_t> sender_device_id;   // previous neighbor; nullopt at a line endpoint
    const MemoryConfig output_mem_config;
    const ccl::Topology topology;

    void validate(const std::vector<Tensor> &input_tensors) const;
    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
    std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
    operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
};

namespace ccl {
// Namespace renamed from all_gather_v2_detail: the definition in
// all_gather_op.cpp lives in ttnn::ccl::all_gather_detail (and is called with
// that qualification), so the declaration must match.
namespace all_gather_detail {
// Builds the AllGatherV2 struct for this device from its position among
// `devices`; defined in all_gather_op.cpp.
AllGatherV2 create_all_gather_struct(
    const Tensor& input_tensor,
    const uint32_t dim,
    const uint32_t num_links,
    const std::optional<MemoryConfig>& memory_config,
    const std::vector<Device*>& devices,
    const ccl::Topology topology
);
} // namespace all_gather_detail
} // namespace ccl

// All Gather Variants
// Program factory for the v2 all-gather; defined in
// device/multi_core/all_gather_op_multi_core_new.cpp. Unlike the v1 variant,
// it takes no user-defined worker/buffer-per-channel counts.
operation::ProgramWithCallbacks all_gather_multi_core_with_workers_new(
    const Tensor& input_tensor,
    Tensor& output_tensor,
    const uint32_t dim,
    const uint32_t num_links,
    const uint32_t ring_size,
    const uint32_t ring_index,
    const std::optional<chip_id_t> receiver_device_id,
    const std::optional<chip_id_t> sender_device_id,
    ccl::Topology topology);



namespace operations {
namespace ccl {

// Gathers `input_tensor` across all devices that hold it, concatenating along
// `dim`. Topology defaults to Ring (degenerates to Linear for two devices).
// NOTE(review): the v2 implementation does not forward
// user_defined_num_workers / user_defined_num_buffers_per_channel — they are
// kept for interface parity with the v1 op.
Tensor all_gather_v2(
    const Tensor& input_tensor,
    const uint32_t dim,
    const uint32_t num_links = 1,
    const std::optional<MemoryConfig>& memory_config = std::nullopt,
    const std::optional<size_t> user_defined_num_workers = std::nullopt,
    const std::optional<size_t> user_defined_num_buffers_per_channel = std::nullopt,
    const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Ring);

// Mesh-aware variant: gathers along one axis of a 2D device mesh
// (cluster_axis 0 = down a column, 1 = across a row). Only the Linear
// topology is supported by this overload.
Tensor all_gather_v2(
    const Tensor& input_tensor,
    const uint32_t dim,
    const uint32_t cluster_axis,
    const MeshDevice& mesh_device,
    const uint32_t num_links = 1,
    const std::optional<MemoryConfig>& memory_config = std::nullopt,
    const std::optional<size_t> user_defined_num_workers = std::nullopt,
    const std::optional<size_t> user_defined_num_buffers_per_channel = std::nullopt,
    const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Linear);

} // namespace ccl
} // namespace operations

} // namespace ttnn
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_new(
const uint32_t ring_index,
const std::optional<chip_id_t> receiver_device_id,
const std::optional<chip_id_t> sender_device_id,
ccl::Topology topology,
const std::optional<size_t> user_defined_num_workers,
const std::optional<size_t> user_defined_num_buffers_per_channel) {
ccl::Topology topology) {

// Sleep for 5 * ring_index seconds (for DEBUG only)
// std::chrono::seconds sleep_duration(5 * ring_index);
Expand Down
12 changes: 12 additions & 0 deletions ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,18 @@ ccl::EriscDatamoverBuilder create_erisc_datamover_builder(
EriscDataMoverTerminationMode termination_mode);


// v2 slice-sequence generator used by the all-gather v2 path. Presumably
// partitions `tensor_shape` into per-worker slices along `fracture_dim` over
// the half-open slice range [start_slice_index, end_slice_index_exclusive) —
// TODO(review): confirm against the definition in ccl_common.cpp.
std::vector<TensorSlice> generate_slice_sequence_on_dim_v2(
    TensorSlice::ords_t tensor_shape,
    TensorSlice::ords_t worker_slice_shape,
    TensorSlice::ords_t worker_slice_offset,
    std::size_t fracture_dim,
    std::size_t num_slices,
    std::int64_t start_slice_index,
    std::int64_t end_slice_index_exclusive,
    std::size_t worker_index
);

class GenericWrappedTensorSlicer {
public:
GenericWrappedTensorSlicer(
Expand Down

0 comments on commit 2ea46ad

Please sign in to comment.