#15060: moved new all gather to separate all gather v2 codebase
1 parent 08867ea · commit 2ea46ad
Showing 8 changed files with 320 additions and 21 deletions.
ttnn/cpp/ttnn/operations/ccl/all_gather_v2/device/all_gather_op.cpp
203 additions & 0 deletions
@@ -0,0 +1,203 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/operations/ccl/all_gather_v2/device/all_gather_op.hpp"
#include "ttnn/operations/math.hpp"

#include "tt_metal/host_api.hpp"

#include "ttnn/tensor/tensor_utils.hpp"

#include "eth_l1_address_map.h"

namespace ttnn {
namespace ccl {
namespace all_gather_detail {

AllGatherV2 create_all_gather_struct(
    const Tensor& input_tensor,
    const uint32_t dim,
    const uint32_t num_links,
    const std::optional<MemoryConfig>& memory_config,
    const std::vector<Device*>& devices,
    const ttnn::ccl::Topology topology
) {
    uint32_t num_devices = devices.size();
    auto [device_index, sender_device_id, receiver_device_id] =
        get_device_index_and_sender_receiver_ids(input_tensor, devices, topology);

    return ttnn::AllGatherV2{
        dim, num_links, num_devices, device_index, receiver_device_id, sender_device_id, memory_config.value_or(input_tensor.memory_config()), topology};
}
}  // namespace all_gather_detail
}  // namespace ccl

||
void AllGatherV2::validate(const std::vector<Tensor> &input_tensors) const { | ||
TT_FATAL(input_tensors.size() == 1, "Error, Input tensor size should be 1 but has {}", input_tensors.size()); | ||
const auto& input_tensor = input_tensors[0]; | ||
const auto& layout = input_tensors[0].get_layout(); | ||
const auto& dtype = input_tensors[0].get_dtype(); | ||
const auto& page_size = input_tensors[0].buffer()->page_size(); | ||
TT_FATAL(page_size % input_tensors[0].buffer()->alignment() == 0, "All Gather currently requires aligned pages"); | ||
|
||
// TODO: This can be removed by passing two page sizes, actual and aligned to be used for address offsets | ||
// Buffer sizes also need to take this aligned page size into consideration | ||
// TODO: Validate ring | ||
TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to all_gather need to be on device!"); | ||
TT_FATAL(input_tensor.buffer() != nullptr , "Operands to all_gather need to be allocated in buffers on device!"); | ||
TT_FATAL(this->num_links > 0, "Error, num_links should be more than 0 but has {}", this->num_links); | ||
TT_FATAL(this->num_links <= input_tensor.device()->compute_with_storage_grid_size().y, "Worker cores used by links are parallelizaed over rows"); | ||
TT_FATAL(this->receiver_device_id.has_value() || this->sender_device_id.has_value(), "Error, All-gather was unable to identify either a sender or receiver device ID and atleast one must be identified for a valid all-gather configuration. The input mesh tensor or all-gather arguments may be incorrect"); | ||
|
||
TT_FATAL(input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED || | ||
input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED || | ||
input_tensor.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED || | ||
input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, | ||
"Unsupported memory layout {}.", input_tensor.memory_config().memory_layout); | ||
|
||
// Sharding Config checks | ||
bool input_sharded = input_tensor.is_sharded(); | ||
if (input_sharded) { | ||
// TODO(snijjar) | ||
} | ||
} | ||

std::vector<ttnn::SimpleShape> AllGatherV2::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
    auto shape = input_tensors[0].get_padded_shape(); // TODO: Replace with get_logical_shape()
    shape[this->dim] *= this->ring_size;
    return std::vector<ttnn::SimpleShape>(input_tensors.size(), shape);
}
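// E.g., with ring_size == 4, gathering a [1, 1, 32, 128] tensor on dim 3
// produces an output shape of [1, 1, 32, 512] on every device.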

std::vector<Tensor> AllGatherV2::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
    const auto& input_tensor = input_tensors[0];
    if (this->output_mem_config.is_sharded()) {
        return {create_device_tensor(
            this->compute_output_shapes(input_tensors).at(0),
            input_tensor.get_dtype(),
            input_tensor.get_layout(),
            input_tensor.device(),
            this->output_mem_config,
            input_tensor.get_tile()
        )};
    } else {
        return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config, input_tensor.get_tile());
    }
}

operation::ProgramWithCallbacks AllGatherV2::create_program(const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const {
    return all_gather_multi_core_with_workers_new(input_tensors[0], output_tensors[0], this->dim, this->num_links, this->ring_size, this->ring_index, this->receiver_device_id, this->sender_device_id, this->topology);
}

namespace operations {
namespace ccl {

Tensor all_gather_v2(
    const Tensor& input_tensor,
    const uint32_t dim,
    const uint32_t num_links,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<size_t> user_defined_num_workers,
    const std::optional<size_t> user_defined_num_buffers_per_channel,
    const ttnn::ccl::Topology topology) {

    tt::log_info(tt::LogOp, "DEBUG: all_gather is called");

    TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "all_gather op is only supported for Fast Dispatch");
    auto devices = input_tensor.get_workers();
    uint32_t num_devices = devices.size();
    TT_FATAL(num_devices > 1, "all_gather op will only work for num_devices > 1, but has {}", num_devices);
    ttnn::ccl::Topology ccl_topology = topology;

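    // A two-device ring has the same links as a two-device line, so the op
    // falls back to the Linear topology in that case.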
    if (num_devices == 2) {
        ccl_topology = ttnn::ccl::Topology::Linear;
    }
    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};

    auto programs = std::vector<Program>(devices.size());
    auto program_ptrs = std::vector<Program*>(devices.size());
    std::transform(programs.begin(), programs.end(), program_ptrs.begin(), [](auto& program) { return &program; });
    TT_FATAL(num_links == 1, "all_gather op is only supported for num_links == 1, but has {}", num_links);
    tt::log_info(tt::LogOp, "DEBUG: creating line_fabric with num devices: {}, num links: {}", devices.size(), num_links);
    auto line_fabric = ttnn::ccl::EdmLineFabricOpInterface(devices, program_ptrs, num_links);
    tt::log_info(tt::LogOp, "DEBUG: line_fabric is created");

    operation::launch_op(
        [dim, num_links, memory_config, devices, ccl_topology](
            const std::vector<Tensor>& input_tensors,
            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
            const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {

            const auto& input_tensor = input_tensors.at(0);

            return operation::run(
                ttnn::ccl::all_gather_detail::create_all_gather_struct(input_tensor, dim, num_links, memory_config, devices, ccl_topology),
                {input_tensor});
        },
        {input_tensor},
        output_tensors);
    return output_tensors.at(0);
}

Tensor all_gather_v2(
    const Tensor& input_tensor,
    const uint32_t dim,
    const uint32_t cluster_axis,
    const MeshDevice& mesh_device,
    const uint32_t num_links,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<size_t> user_defined_num_workers,
    const std::optional<size_t> user_defined_num_buffers_per_channel,
    const ttnn::ccl::Topology topology) {

    tt::log_info(tt::LogOp, "DEBUG: all_gather with cluster_axis is called");

    TT_FATAL(topology == ttnn::ccl::Topology::Linear, "This all_gather API with cluster_axis is currently supported only for the Linear topology");
    const auto mesh_view = mesh_device.get_view();
    std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols();

    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};

    operation::launch_op(
        [dim, num_links, memory_config, mesh_view, cluster_axis, num_devices, topology](
            const std::vector<Tensor>& input_tensors,
            const std::vector<std::optional<const Tensor>>& optional_input_tensors,
            const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {

            const auto& input_device_tensor = input_tensors.at(0);

            const auto coordinate = mesh_view->find_device(input_device_tensor.device()->id());
            const auto view_index = (cluster_axis == 0) ? coordinate.col : coordinate.row;
            const auto device_index = (cluster_axis == 0) ? coordinate.row : coordinate.col;

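            // Map a logical position along the gather line to a physical chip id:
            // walk the mesh along cluster_axis, wrapping with modulo so indices
            // past the end of the row/column fold back to the start.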
            auto get_chip_id = [&](std::size_t line_index) -> std::optional<chip_id_t> {
                auto new_coord = coordinate;
                if (cluster_axis == 0) {
                    new_coord.row = line_index % num_devices;
                } else {
                    new_coord.col = line_index % num_devices;
                }
                return mesh_view->find_device_id(new_coord);
            };

            // Endpoints of the line have only one neighbor: the last chip gets no
            // receiver and the first chip gets no sender. device_index + num_devices - 1
            // is device_index - 1 modulo num_devices, i.e. the counter-clockwise neighbor.
            bool is_last_chip_in_clockwise_direction = device_index == (num_devices - 1);
            bool is_last_chip_in_counter_clockwise_direction = device_index == 0;
            auto receiver_device_id = is_last_chip_in_clockwise_direction ? std::nullopt : get_chip_id(device_index + 1);
            auto sender_device_id = is_last_chip_in_counter_clockwise_direction ? std::nullopt : get_chip_id(device_index + num_devices - 1);

            return operation::run(
                ttnn::AllGatherV2{
                    dim, num_links, num_devices, device_index, receiver_device_id, sender_device_id, memory_config.value_or(input_device_tensor.memory_config()), topology},
                {input_device_tensor});
        },
        {input_tensor},
        output_tensors);
    return output_tensors.at(0);
}

}  // namespace ccl
}  // namespace operations

}  // namespace ttnn
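For orientation, a minimal call-site sketch of the ring variant defined above. Everything outside the all_gather_v2 call is a placeholder: make_multi_device_tensor stands in for whatever produces a tensor sharded across a multi-device line, and is not part of this commit.

#include "ttnn/operations/ccl/all_gather_v2/device/all_gather_op.hpp"

namespace ttnn {
Tensor make_multi_device_tensor();  // hypothetical helper, assumed to exist elsewhere

Tensor example_ring_gather() {
    Tensor input = make_multi_device_tensor();
    // Gather along dim 3 over a single link; the output memory config
    // defaults to the input's, and the topology defaults to Ring.
    return operations::ccl::all_gather_v2(input, /*dim=*/3, /*num_links=*/1);
}
}  // namespace ttnn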
ttnn/cpp/ttnn/operations/ccl/all_gather_v2/device/all_gather_op.hpp
98 additions & 0 deletions
@@ -0,0 +1,98 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <cstdint>
#include "common/core_coord.hpp"
#include "impl/buffers/buffer.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"
#include "ttnn/operations/ccl/ccl_host_datastructures.hpp"
#include "ttnn/operations/ccl/ccl_common.hpp"
#include "ttnn/operations/ccl/ccl_op_fusion.hpp"

#include "ttnn/run_operation.hpp"

#include <optional>
#include <vector>
#include <algorithm>

namespace ttnn {

using ccl::EriscDatamoverBuilder;

struct AllGatherV2 {
    const uint32_t dim;
    const uint32_t num_links;
    const uint32_t ring_size;
    const uint32_t ring_index;
    const std::optional<chip_id_t> receiver_device_id;
    const std::optional<chip_id_t> sender_device_id;
    const MemoryConfig output_mem_config;
    const ccl::Topology topology;

    void validate(const std::vector<Tensor> &input_tensors) const;
    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
    std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
    operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
};

namespace ccl {
namespace all_gather_detail {
AllGatherV2 create_all_gather_struct(
    const Tensor& input_tensor,
    const uint32_t dim,
    const uint32_t num_links,
    const std::optional<MemoryConfig>& memory_config,
    const std::vector<Device*>& devices,
    const ccl::Topology topology
);
}  // namespace all_gather_detail
}  // namespace ccl

// All Gather Variants
operation::ProgramWithCallbacks all_gather_multi_core_with_workers_new(
    const Tensor& input_tensor,
    Tensor& output_tensor,
    const uint32_t dim,
    const uint32_t num_links,
    const uint32_t ring_size,
    const uint32_t ring_index,
    const std::optional<chip_id_t> receiver_device_id,
    const std::optional<chip_id_t> sender_device_id,
    ccl::Topology topology);

namespace operations {
namespace ccl {

Tensor all_gather_v2(
    const Tensor& input_tensor,
    const uint32_t dim,
    const uint32_t num_links = 1,
    const std::optional<MemoryConfig>& memory_config = std::nullopt,
    const std::optional<size_t> user_defined_num_workers = std::nullopt,
    const std::optional<size_t> user_defined_num_buffers_per_channel = std::nullopt,
    const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Ring);

Tensor all_gather_v2(
    const Tensor& input_tensor,
    const uint32_t dim,
    const uint32_t cluster_axis,
    const MeshDevice& mesh_device,
    const uint32_t num_links = 1,
    const std::optional<MemoryConfig>& memory_config = std::nullopt,
    const std::optional<size_t> user_defined_num_workers = std::nullopt,
    const std::optional<size_t> user_defined_num_buffers_per_channel = std::nullopt,
    const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Linear);

}  // namespace ccl
}  // namespace operations

}  // namespace ttnn
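A sketch of the mesh-axis overload declared above, under the assumption that a 2D MeshDevice and a mesh-mapped tensor are set up elsewhere; the wrapper name and both parameters are placeholders.

#include "ttnn/operations/ccl/all_gather_v2/device/all_gather_op.hpp"

namespace ttnn {
// Placeholder wrapper: gathers a mesh-mapped tensor across one mesh axis.
Tensor gather_down_columns(const Tensor& mesh_tensor, const MeshDevice& mesh) {
    // dim 3 is the tensor dim being gathered; cluster_axis 0 selects the
    // devices that share a mesh column (num_devices == number of mesh rows).
    // Linear is the only topology this overload currently accepts, and its default.
    return operations::ccl::all_gather_v2(mesh_tensor, /*dim=*/3, /*cluster_axis=*/0, mesh);
}
}  // namespace ttnn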