diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py new file mode 100644 index 00000000000..31a03d62a74 --- /dev/null +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py @@ -0,0 +1,232 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +from loguru import logger +import ttnn +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc +from models.utility_functions import skip_for_grayskull + + +def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout): + if layout == ttnn.ROW_MAJOR_LAYOUT and input_dtype == ttnn.bfloat8_b: + return True, "Invalid combination" + + if input_shape[dim] % num_devices != 0 or (dim == 3 and input_shape[dim] // num_devices % 32 != 0): + return True, "Unsupported test case" + + ## Check that we can readback results + fast_dispatch_page_size_limit = 55 * 1024 + elem_size = 2 if input_dtype == ttnn.bfloat16 else 1 + if layout == ttnn.ROW_MAJOR_LAYOUT and (input_shape[dim] * elem_size) > fast_dispatch_page_size_limit: + # Fast dispatch currently can't breakup readback of large pages into multiple smaller pages and is + # limited to ~55K pages. + return True, "Fast dispatch can't support reading back this page size in one shot" + + # Check that we can fit in L1 (if L1 config) + tensor_size_bytes = elem_size + for i in input_shape: + tensor_size_bytes *= i + num_l1_banks = 64 + if mem_config.buffer_type == ttnn.BufferType.L1 and tensor_size_bytes > num_l1_banks * 50 * 1024: + return True, "L1 buffer can't support large tensor sizes" + + # Check that each chip has a non-zero amount of data available + min_sized_chunks_on_dim = input_shape[dim] + if dim == 3: + min_sized_chunks_on_dim //= 32 + if dim == 2: + if layout == ttnn.TILE_LAYOUT: + min_sized_chunks_on_dim //= 32 + if min_sized_chunks_on_dim < num_devices: + return ( + True, + f"Input shape {input_shape} incompatible with {num_devices} on dim {dim} because some chips will have no tensor", + ) + + if input_shape == [8, 8, 256, 384] and dim == 1 and layout == ttnn.TILE_LAYOUT and input_dtype == ttnn.bfloat8_b: + return True, "Known failure" + + return False, "" + + +def run_all_gather_impl( + mesh_device, + num_devices, + output_shape, + dim, + num_links, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, + all_gather_topology, + num_iters=1, + enable_async=False, + trace_mode=False, + rand_tensor=True, +): + if num_iters < 1: + pytest.fail("num_iters must be >= 1") + # Use Async mode based on test input config + mesh_device.enable_async(enable_async) + + if enable_async: + logger.info(f"Using Async Mode for All Gather Op Dispatch") + + logger.info(f"Output shape: {output_shape}") + logger.info(f"dim: {dim}") + + if rand_tensor: + output_tensor = torch.rand(output_shape).bfloat16() + else: + output_tensor = torch.zeros(output_shape) + tile_id = 1 + for w in range(output_shape[0]): + for z in range(output_shape[1]): + for y in range(0, output_shape[2], 32): + for x in range(0, output_shape[3], 32): + output_tensor[w, z, y : y + 32, x : x + 32] = tile_id + tile_id += 1 + + input_tensors = torch.chunk(output_tensor, num_devices, dim) + tt_input_tensors = [] + for i, t in enumerate(input_tensors): + tt_input_tensors.append(ttnn.Tensor(t, input_dtype).to(layout).to(mesh_device.get_devices()[i], mem_config)) + logger.info(f"using device 
{mesh_device.get_devices()[i].id()}") + + input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors) + if trace_mode: + tt_out_tensor = run_with_trace( + mesh_device, + all_gather_topology, + input_tensor_mesh, + dim, + num_links, + mem_config, + ) + else: + for i in range(num_iters): + tt_out_tensor = ttnn.all_gather( + input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config, topology=all_gather_topology + ) + + for d in mesh_device.get_devices(): + ttnn.synchronize_device(d) + logger.info(f"Done iteration {i}") + + for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)): + tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() + logger.info(f"Checking for device {t.device().id()}") + + # breakpoint() + + # for current non-edm version of all gather only: + chunked_output_tensor = torch.chunk(tt_output_tensor, num_devices, dim) + + if input_dtype == ttnn.bfloat16: + # eq, output = comp_equal(tt_output_tensor, output_tensor) + eq, output = comp_equal(chunked_output_tensor[i], input_tensors[i]) + else: + # eq, output = comp_pcc(tt_output_tensor, output_tensor) + eq, output = comp_pcc(chunked_output_tensor[i], input_tensors[i]) + if not eq: + logger.error(f"output mismatch for tensor {i}") + assert eq, f"{i} FAILED: {output}" + + +# Enumerate the post-commit cases explicitly +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.parametrize( + "num_devices, num_links, output_shape, dim, layout", + [ + # Known errors + # - double/tripple buffers in cb not working + # (4, 2, [4, 1, 256, 32], 0, ttnn.TILE_LAYOUT), # failed: device not connected # https://github.com/tenstorrent/tt-metal/issues/9686 + (2, 1, [1, 1, 32, 256], 3, ttnn.TILE_LAYOUT), + (2, 1, [1, 1, 64, 256], 2, ttnn.TILE_LAYOUT), + (8, 1, [8, 1, 256, 32], 0, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + (8, 1, [1, 8, 256, 32], 1, ttnn.TILE_LAYOUT), + (2, 2, [1, 1, 32, 256], 3, ttnn.TILE_LAYOUT), + (2, 2, [1, 1, 64, 256], 2, ttnn.TILE_LAYOUT), + (2, 2, [1, 1, 32, 320], 3, ttnn.TILE_LAYOUT), + (2, 1, [1, 1, 32, 320], 3, ttnn.TILE_LAYOUT), + # (4, 3, [1, 1, 32, 16384 * 4], 3, ttnn.TILE_LAYOUT), # failed: device not connected + (8, 4, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT), + (8, 3, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT), + (8, 2, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT), + # untested cases + # (4, 2, [1, 1, 32, 32768], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (4, 2, [4, 1, 256, 32], 0, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [8, 1, 256, 32], 0, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (8, 1, [1, 1, 32, 16384], 3, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + # (4, 2, [1, 1, 32, 32768], 3, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686 + ], +) +@pytest.mark.parametrize( + "input_dtype", + [ + ttnn.bfloat16, + # ttnn.bfloat8_b, # https://github.com/tenstorrent/tt-metal/issues/9686 + ], +) +@pytest.mark.parametrize( + "mem_config", + [ + ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM), # https://github.com/tenstorrent/tt-metal/issues/9686 + # ttnn.MemoryConfig(buffer_type=ttnn.BufferType.L1), + ], +) +@pytest.mark.parametrize("num_iters", [1]) # restore to 500: https://github.com/tenstorrent/tt-metal/issues/9686 +@pytest.mark.parametrize("enable_async", [True]) +def test_all_gather( + t3k_mesh_device, + # pcie_mesh_device, + num_devices, + output_shape, + dim, + 
num_links, + input_dtype, + layout, + mem_config, + num_iters, + use_program_cache, + function_level_defaults, + enable_async, +): + run_all_gather_impl( + t3k_mesh_device, + num_devices, + output_shape, + dim, + num_links, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, + all_gather_topology=ttnn.Topology.Ring, + num_iters=num_iters, + enable_async=enable_async, + rand_tensor=True, + ) + + run_all_gather_impl( + t3k_mesh_device, + num_devices, + output_shape, + dim, + num_links, + input_dtype, + layout, + mem_config, + use_program_cache, + function_level_defaults, + all_gather_topology=ttnn.Topology.Ring, + num_iters=num_iters, + enable_async=enable_async, + rand_tensor=False, + ) diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 8dbf03025a2..1b5b793e098 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -15,6 +15,7 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core_new.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp @@ -27,6 +28,8 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/common/uops/ccl_command.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.cpp diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index 4957355cf7e..5672e7165a9 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -167,7 +167,7 @@ std::vector AllGather::create_output_tensors(const std::vector & } operation::ProgramWithCallbacks AllGather::create_program(const std::vector & input_tensors, std::vector &output_tensors) const { - return all_gather_multi_core_with_workers(input_tensors[0], output_tensors[0], this->dim, this->num_links, this->ring_size, this->ring_index, this->receiver_device_id, this->sender_device_id, this->topology, this->user_defined_num_workers, this->user_defined_num_buffers_per_channel); + return all_gather_multi_core_with_workers_new(input_tensors[0], output_tensors[0], this->dim, this->num_links, this->ring_size, this->ring_index, this->receiver_device_id, this->sender_device_id, this->topology, this->user_defined_num_workers, this->user_defined_num_buffers_per_channel); } namespace operations { 
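Note: the end-to-end call pattern exercised by run_all_gather_impl above reduces to the short sketch below. This is illustrative commentary only, not part of the patch; it assumes the same ttnn APIs already used in the test (ttnn.aggregate_as_tensor, ttnn.all_gather, ttnn.get_device_tensors, ttnn.Topology.Ring) and a mesh-device fixture such as t3k_mesh_device. The helper name, shape, and default arguments are made up for the example.

import torch
import ttnn

def minimal_all_gather(mesh_device, num_devices=2, dim=3, num_links=1):
    # Host-side golden tensor; each device receives one chunk along `dim`.
    mem_config = ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM)
    golden = torch.rand([1, 1, 32, 256]).bfloat16()
    shards = torch.chunk(golden, num_devices, dim)
    tt_shards = [
        ttnn.Tensor(t, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(mesh_device.get_devices()[i], mem_config)
        for i, t in enumerate(shards)
    ]
    input_tensor_mesh = ttnn.aggregate_as_tensor(tt_shards)
    # Ring all-gather across the mesh, gathering along `dim`.
    tt_out_tensor = ttnn.all_gather(
        input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config, topology=ttnn.Topology.Ring
    )
    # Per-device readback; the test above compares each device's local chunk
    # against the corresponding input shard.
    device_outputs = [t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() for t in ttnn.get_device_tensors(tt_out_tensor)]
    return device_outputs, golden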
diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp index b0a162f2a1f..0e3ff0ff8bf 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp @@ -192,6 +192,18 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( const std::optional user_defined_num_buffers_per_channel, std::optional& fused_op_signaler, const CoreCoord core_grid_offset = CoreCoord(0, 0)); +operation::ProgramWithCallbacks all_gather_multi_core_with_workers_new( + const Tensor& input_tensor, + Tensor& output_tensor, + const uint32_t dim, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + const std::optional receiver_device_id, + const std::optional sender_device_id, + ccl::Topology topology, + const std::optional user_defined_num_workers, + const std::optional user_defined_num_buffers_per_channel); diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp index 75606f8b23c..e7d4ba65679 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp @@ -8,6 +8,7 @@ #include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" #include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" #include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp" +#include "debug/dprint.h" using ttnn::ccl::ShardType; using ttnn::ccl::UNINITIALIZED_VALUE_U16; @@ -535,7 +536,7 @@ template FORCE_INLINE void write_wrapped_chunk( uint32_t& curr_page_idx, uint32_t& offset_into_worker_slice, - ttnn::ccl::coord_t& offset_worker_slice, + const ttnn::ccl::coord_t& offset_worker_slice, const ttnn::ccl::coord_t& worker_slice_shape, // In tiles for tile layout diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core_new.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core_new.cpp new file mode 100644 index 00000000000..e73ff4ad121 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core_new.cpp @@ -0,0 +1,199 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+//
+// SPDX-License-Identifier: Apache-2.0
+///
+#include
+
+#include "tt_metal/common/core_coord.hpp"
+#include "eth_l1_address_map.h"
+#include "impl/buffers/buffer.hpp"
+#include "ttnn/tensor/tensor_impl.hpp"
+#include "ttnn/operations/ccl/all_gather/device/all_gather_op.hpp"
+#include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
+#include "ttnn/operations/ccl/ccl_host_datastructures.hpp"
+#include "ttnn/operations/ccl/ccl_common.hpp"
+#include "ttnn/operations/math.hpp"
+#include "tt_metal/common/work_split.hpp"
+#include "tt_metal/common/constants.hpp"
+#include "tt_metal/detail/util.hpp"
+#include "tt_metal/host_api.hpp"
+#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp"
+
+#include
+#include
+#include
+
+#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp"
+
+
+using namespace tt::constants;
+
+namespace ttnn {
+
+using namespace ccl;
+// For ring all-gather, we can send sub-sections of input tensor in opposite directions
+// For linear all-gather though, we must ensure we send full tensors in BOTH directions
+// (in other words, disable the "bidirectional" send flag)
+operation::ProgramWithCallbacks all_gather_multi_core_with_workers_new(
+    const Tensor& input_tensor,
+    Tensor& output_tensor,
+    const uint32_t dim,
+    const uint32_t num_links,
+    const uint32_t ring_size,
+    const uint32_t ring_index,
+    const std::optional receiver_device_id,
+    const std::optional sender_device_id,
+    ccl::Topology topology,
+    const std::optional user_defined_num_workers,
+    const std::optional user_defined_num_buffers_per_channel) {
+
+    // Sleep for 5 * ring_index seconds (for DEBUG only)
+    // std::chrono::seconds sleep_duration(5 * ring_index);
+    // std::this_thread::sleep_for(sleep_duration);
+
+    // // Log device id and ring index
+    // log_info(tt::LogOp, "Generating log for Ring Index: {}", ring_index);
+
+    tt::tt_metal::Program program{};
+
+    TT_FATAL(!(receiver_device_id == std::nullopt && sender_device_id == std::nullopt), "At least one of receiver_device_id or sender_device_id must be specified");
+
+    std::unique_ptr input_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(input_tensor);
+    std::unique_ptr output_tensor_config = ttnn::ccl::CclOpTensorConfig::build_all_gather_tensor_config(output_tensor);
+
+    const auto& device = input_tensor.device();
+
+    bool is_sharded = input_tensor.is_sharded();
+
+    const auto input_buffer = input_tensor.buffer();
+    const auto output_buffer = output_tensor.buffer();
+
+    // Get OP Config, topology config
+    std::vector input_tensors = {input_tensor};
+    std::vector output_tensors = {output_tensor};
+    auto const& op_config = ttnn::ccl::CCLOpConfig(input_tensors, output_tensors, topology);
+    auto const& input_tensor_partition = ttnn::ccl::TensorPartition(1, 0); // one partition, 0 index
+    auto const& output_tensor_partition = ttnn::ccl::TensorPartition(ring_size, ring_index); // ring_size partitions, ring_index index
+
+    // Get worker cores, assuming 1 worker per link
+    uint32_t num_workers_per_link = 1;
+    auto const& sender_worker_core_range = CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(num_workers_per_link-1, num_links - 1)));
+    auto const& sender_worker_cores = corerange_to_cores(sender_worker_core_range, std::nullopt, true);
+
+    // CB Creation
+    uint32_t num_pages_per_packet = 1; // we assume 1 page per packet for now
+    uint32_t cb_num_pages = num_pages_per_packet; // There is a bug with double/triple buffering. Still debugging.
+ uint32_t src0_cb_index = tt::CB::c_in0; + uint32_t page_size_bytes = op_config.get_page_size(); + tt::DataFormat df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(cb_num_pages * page_size_bytes, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, page_size_bytes); + CBHandle cb_src0_workers = CreateCircularBuffer(program, sender_worker_core_range, cb_src0_config); + + // Create Tensor slicer + auto input_tensor_slicer = ttnn::ccl::GenericWrappedTensorSlicer ( // can be used for all gather as well if we set ring_size to 1, which means read the entire input tensor in reduce scatter sense. + input_tensor, + input_tensor, + dim, + 0, // ring_index, set as 0 to "trick" the reduce scatter tensor slicer to read the entire input tensor + 1, // ring_size, set as 1 to "trick" the reduce scatter tensor slicer to read the entire input tensor + num_links, // num_workers_per_slicer, set 1 per link for now + UINT32_MAX, // max_worker_slice_in_bytes, set as infinite for now + cb_num_pages / 2); + auto output_tensor_slicer = ttnn::ccl::GenericWrappedTensorSlicer ( + output_tensor, + output_tensor, + dim, + ring_index, + ring_size, + num_links, // num_workers_per_slicer, set 1 per link for now + UINT32_MAX, // max_worker_slice_in_bytes, set as infinite for now + cb_num_pages / 2); + + // KERNEL CREATION + auto worker_arg_builder = ccl::worker_detail::CCLWorkerArgBuilder( + device, + op_config, + input_tensor_partition, + output_tensor_partition, + dim); + + auto const& worker_defines = op_config.emit_worker_defines(); + static std::string const& sender_kernel_reader_path = "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader.cpp"; + static std::string const& sender_kernel_writer_path = "ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_writer.cpp"; + + KernelHandle worker_sender_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + sender_kernel_reader_path, + sender_worker_core_range, + tt::tt_metal::ReaderDataMovementConfig(worker_arg_builder.generate_sender_reader_kernel_ct_args(), worker_defines)); + + KernelHandle worker_sender_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + sender_kernel_writer_path, + sender_worker_core_range, + tt::tt_metal::WriterDataMovementConfig(worker_arg_builder.generate_sender_writer_kernel_ct_args(), worker_defines)); + + // RT Args + for (std::size_t link = 0; link < num_links; link++) { + CoreCoord core = {num_workers_per_link-1, link}; + std::size_t worker_tensor_slice_index = link; + auto const& input_worker_slice = input_tensor_slicer.get_worker_slice(worker_tensor_slice_index); + auto const& output_worker_slice = output_tensor_slicer.get_worker_slice(worker_tensor_slice_index); + auto worker_arg_builder = ccl::worker_detail::CCLWorkerArgBuilder( + device, + op_config, + input_tensor_partition, + output_tensor_partition, + dim); + + // tt::log_info("Creating RT Args for worker core ({},{})", core.x, core.y); + // input_worker_slice.print(); + // output_worker_slice.print(); + + auto const sender_reader_rt_args = worker_arg_builder.generate_sender_reader_kernel_rt_args(input_worker_slice, worker_arg_builder.operating_dim, num_pages_per_packet, worker_tensor_slice_index); + tt::tt_metal::SetRuntimeArgs( + program, + worker_sender_reader_kernel_id, + core, + sender_reader_rt_args); + + auto const sender_writer_rt_args = worker_arg_builder.generate_sender_writer_kernel_rt_args(output_worker_slice, 
worker_arg_builder.operating_dim, num_pages_per_packet, worker_tensor_slice_index); + tt::tt_metal::SetRuntimeArgs( + program, + worker_sender_writer_kernel_id, + core, + sender_writer_rt_args); + } + + auto override_runtime_arguments_callback = + [worker_sender_reader_kernel_id, + worker_sender_writer_kernel_id, + sender_worker_cores] ( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors + ) { + const auto& input = input_tensors[0]; + const auto& output = output_tensors[0]; + + // update senders + auto &worker_reader_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_reader_kernel_id); + auto &worker_writer_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_writer_kernel_id); + for (auto const& core : sender_worker_cores) { + // reader + auto& worker_reader_sender_runtime_args = worker_reader_sender_runtime_args_by_core[core.x][core.y]; + worker_reader_sender_runtime_args.at(0) = input.buffer()->address(); + // writer + auto& worker_writer_sender_runtime_args = worker_writer_sender_runtime_args_by_core[core.x][core.y]; + worker_writer_sender_runtime_args.at(0) = output.buffer()->address(); + } + }; + + return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback}; +} + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp index 6c49072b809..0057bf91fc8 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.cpp @@ -858,5 +858,238 @@ std::vector generate_slice_sequence_on_dim( return slices; } + +/* + * @brief: Given a tensor shape, evenly break it into pieces along a given dimension and generate the slices accordingly. + * This can be fed into a CCL Send command generator + */ +std::vector generate_slice_sequence_on_dim_v2( + TensorSlice::ords_t tensor_shape, + TensorSlice::ords_t worker_slice_shape, + TensorSlice::ords_t worker_slice_offset, + std::size_t fracture_dim, + std::size_t num_slices, + std::int64_t start_slice_index, + std::int64_t end_slice_index_exclusive, + std::size_t worker_index +) { + static_assert(std::is_same_v, "generate_slice_sequence_on_dim_v2 not yet implemented for type not of tt_xy_pair"); + // We don't support 4D shapes in the CCL kernels yet, which are needed for proper reduction/concatenation in some cases + // so for now we subtract the outer dims from the fracture_dim since we only support 2D at the moment. + if (fracture_dim == 3) { + fracture_dim -= 2; + } else { + // dims are + fracture_dim = 0; + } + + TT_ASSERT(worker_slice_shape.y == 1); + + std::vector slices; + auto dim_size = fracture_dim == 1 ? tensor_shape.x : tensor_shape.y; + TT_ASSERT(dim_size % num_slices == 0); + auto slice_size_on_dim = dim_size / num_slices; + auto slice_shape = fracture_dim == 0 ? tt_xy_pair{tensor_shape.x, slice_size_on_dim} : tt_xy_pair{slice_size_on_dim, tensor_shape.y}; + + auto dim_start_offset = start_slice_index * slice_size_on_dim; + TensorSlice::ords_t tensor_slice_offset = fracture_dim == 0 ? tt_xy_pair{0, dim_start_offset} : tt_xy_pair{dim_start_offset, 0}; + + bool forward_direction = start_slice_index > end_slice_index_exclusive; // only for debug + auto incr = start_slice_index < end_slice_index_exclusive ? 
1 : -1; + if (forward_direction) { + log_trace(tt::LogOp, "slice_size_on_dim {}", slice_size_on_dim); + log_trace(tt::LogOp, "worker_index {}", worker_index); + } + + auto worker_slice_start_offset = worker_slice_offset; + + auto generate_slice = [forward_direction,incr, &slices, &tensor_shape, &slice_shape, &worker_slice_shape, tensor_slice_offset, &worker_slice_start_offset, fracture_dim, dim_start_offset, slice_size_on_dim](std::int64_t i){ + auto tensor_slice_offset_adjusted = tensor_slice_offset; + if (fracture_dim == 0) { + tensor_slice_offset_adjusted.y = slice_size_on_dim * i; + } else { + tensor_slice_offset_adjusted.x = slice_size_on_dim * i; + } + TT_ASSERT(tensor_shape.x > 0, "Invalid tensor shape. x = 0 but it must be > 0"); + TT_ASSERT(tensor_shape.y > 0, "Invalid tensor shape. y = 0 but it must be > 0"); + TT_ASSERT(slice_shape.x > 0, "Invalid tensor slice shape. x = 0 but it must be > 0"); + TT_ASSERT(slice_shape.y > 0, "Invalid tensor slice shape. x = 0 but it must be > 0"); + TT_ASSERT(tensor_slice_offset_adjusted.x < tensor_shape.x, "Invalid tensor slice offset. x = {} but it must be < tensor shape x={}. slice_offset: (y={},x={}), tensor_shape: (y={},x={}). slice_size_on_dim: {}, i: {}", tensor_slice_offset_adjusted.x, tensor_shape.x, tensor_slice_offset_adjusted.y, tensor_slice_offset_adjusted.x, tensor_shape.y, tensor_shape.x, slice_size_on_dim, i); + TT_ASSERT(tensor_slice_offset_adjusted.y < tensor_shape.y, "Invalid tensor slice offset. y = {} but it must be < tensor shape y={}. slice_offset: (y={},x={}), tensor_shape: (y={},x={}). slice_size_on_dim: {}, i: {}", tensor_slice_offset_adjusted.y, tensor_shape.y, tensor_slice_offset_adjusted.y, tensor_slice_offset_adjusted.x, tensor_shape.y, tensor_shape.x, slice_size_on_dim, i); + TT_ASSERT(worker_slice_shape.x > 0, "Invalid worker slice shape. x = 0 but it must be > 0"); + TT_ASSERT(worker_slice_shape.y > 0, "Invalid worker slice shape. 
y = 0 but it must be > 0"); + + auto const& tensor_slice = TensorSlice(tensor_shape, slice_shape, tensor_slice_offset_adjusted, worker_slice_shape, worker_slice_start_offset, fracture_dim); + if (forward_direction) { + log_trace( + tt::LogOp, + "generate_slice ({}):\n\ttensor_shape: (y={},x={})\n\ttensor_slice_shape: (y={},x={})\n\ttensor_slice_offset_adjusted: (y={},x={})\n\tslice_start_shape: (y={},x={})\n\tworker relative slice_start_offset: (y={},x={})\n\tfracture_dim: {}\n\tdim_start_offset: {}\n\tslice_size_on_dim: {}\n", + i, + tensor_slice.tensor_shape.y, + tensor_slice.tensor_shape.x, + tensor_slice.tensor_slice_shape.y, + tensor_slice.tensor_slice_shape.x, + tensor_slice.tensor_slice_offset.y, + tensor_slice.tensor_slice_offset.x, + tensor_slice.worker_slice_shape.y, + tensor_slice.worker_slice_shape.x, + tensor_slice.worker_slice_offset.y, + tensor_slice.worker_slice_offset.x, + fracture_dim, + dim_start_offset, + slice_size_on_dim); + } + + slices.push_back(tensor_slice); + }; + + for (int i = start_slice_index; i != end_slice_index_exclusive; i += incr) { + generate_slice(i); + } + + return slices; +} + + +GenericWrappedTensorSlicer::GenericWrappedTensorSlicer( + const Tensor& input_tensor, + const Tensor& output_tensor, + int slice_dim, + uint32_t partition_index, + uint32_t partition_size, + uint32_t total_num_workers, + uint32_t max_slice_size_in_bytes, + uint32_t half_cb_n_pages) +{ + this->initialize(input_tensor, output_tensor, slice_dim, partition_index, partition_size, total_num_workers, max_slice_size_in_bytes, half_cb_n_pages); +} + +tt_xy_pair GenericWrappedTensorSlicer::calculate_tensor_slice_shape(const Tensor& input_tensor, int slice_dim, uint32_t partition_size) { + const uint32_t num_tiles_x = input_tensor.get_legacy_shape()[-1] / tt::constants::TILE_WIDTH; + uint32_t num_tiles_y = (input_tensor.get_legacy_shape()[-2] / tt::constants::TILE_HEIGHT); + for (std::size_t i = 0; input_tensor.get_legacy_shape().rank() > 2 && i < input_tensor.get_legacy_shape().rank() - 2; i++) { + num_tiles_y *= input_tensor.get_legacy_shape()[i]; + } + TT_ASSERT(num_tiles_x >= partition_size); + tt_xy_pair tensor_slice_shape; + tensor_slice_shape.x = slice_dim == 3 ? (num_tiles_x / partition_size) : num_tiles_x; + tensor_slice_shape.y = slice_dim != 3 ? 
num_tiles_y / partition_size : num_tiles_y; + return tensor_slice_shape; +} + +void GenericWrappedTensorSlicer::initialize( + const Tensor& input_tensor, + const Tensor& output_tensor, + int slice_dim, + uint32_t partition_index, + uint32_t partition_size, + uint32_t total_num_workers, + uint32_t max_slice_size_in_bytes, + uint32_t half_cb_n_pages) +{ + // Configure layout parameters + this->row_major = (input_tensor.get_layout() == Layout::ROW_MAJOR); + this->input_page_size = input_tensor.buffer()->page_size(); + this->partition_index = partition_index; + this->partition_size = partition_size; + + // Assume everything in Tile layout for now, row major not supported yet + TT_FATAL(!this->row_major, "Row major not supported yet"); + + this->tensor_slice_shape = calculate_tensor_slice_shape(input_tensor, slice_dim, partition_size); + + // Calculate worker slice shapes (tile layout) + this->worker_slice_shapes = create_worker_slice_shapes_for_tile_layout( + input_tensor.get_legacy_shape(), + this->tensor_slice_shape, + total_num_workers, + max_slice_size_in_bytes / this->input_page_size, + half_cb_n_pages + ); + + // Flattened tensor shape (tile layout) + this->flattened_tensor_shape = tt_xy_pair{ + input_tensor.get_legacy_shape()[3] /tt::constants::TILE_WIDTH, + (input_tensor.get_legacy_shape()[0] * input_tensor.get_legacy_shape()[1] * + input_tensor.get_legacy_shape()[2]) / + tt::constants::TILE_HEIGHT}; + + this->worker_slice_offsets = compute_worker_slice_offsets(this->worker_slice_shapes, this->tensor_slice_shape); +} + +ccl::InterleavedTensorWorkerSlice GenericWrappedTensorSlicer::get_worker_slice(std::size_t global_worker_index) { + assert(global_worker_index < this->worker_slice_shapes.size()); + assert(global_worker_index < this->worker_slice_offsets.size()); + return ccl::InterleavedTensorWorkerSlice( + this->flattened_tensor_shape, + this->tensor_slice_shape, + this->worker_slice_shapes[global_worker_index], + this->worker_slice_offsets[global_worker_index], + true // wrapped + ); +} + +std::vector GenericWrappedTensorSlicer::compute_worker_slice_offsets( + std::vector const& worker_slice_shapes, tt_xy_pair const& tensor_slice_shape) { + return compute_worker_slice_offsets_for_wrapped_tensor_slicer(worker_slice_shapes, tensor_slice_shape); +} + +std::vector GenericWrappedTensorSlicer::create_worker_slice_shapes_for_tile_layout( + tt::tt_metal::LegacyShape const& tensor_shape, + tt_xy_pair const& tensor_slice_shape_in_tiles, + uint32_t num_workers, + uint32_t max_slice_size_in_pages, + uint32_t half_cb_n_pages) +{ + log_trace(tt::LogOp, "\tmax_slice_size_in_pages={}", max_slice_size_in_pages); + TT_ASSERT(max_slice_size_in_pages > 0); + std::vector worker_slice_shapes; + worker_slice_shapes.reserve(num_workers); + const uint32_t total_num_tiles = tensor_slice_shape_in_tiles.x * tensor_slice_shape_in_tiles.y; + if (num_workers > total_num_tiles) { + log_warning( + tt::LogOp, + "Reduce Scatter more workers instantiated than is work to be done. 
Some workers will be idle and do " + "nothing"); + for (uint32_t w = 0; w < total_num_tiles; ++w) { + worker_slice_shapes.emplace_back(1, 1); + } + for (uint32_t w = total_num_tiles; w < num_workers; ++w) { + worker_slice_shapes.emplace_back(0, 0); + } + return worker_slice_shapes; + } + + std::size_t max_slice_size_in_tiles = max_slice_size_in_pages; + + // Assign slices by assuming that the input tensor is flattened into a 1D Shape + std::size_t optim_worker_slice_len_tiles = std::ceil(static_cast(total_num_tiles) / num_workers); // Ceil so that the remainder worker will have a smaller slice + + log_trace(tt::LogOp, "---- GenericWrappedTensorSlicer::create_worker_slice_shapes_for_tile_layout ---- "); + log_trace(tt::LogOp, "total_num_tiles: {}", total_num_tiles); + log_trace(tt::LogOp, "num_workers: {}", num_workers); + log_trace(tt::LogOp, "optim_worker_slice_len_tiles: {}", optim_worker_slice_len_tiles); + + if (max_slice_size_in_tiles < optim_worker_slice_len_tiles) { // Each worker will have a full slice + for (uint32_t w = 0; w < num_workers; ++w) { + worker_slice_shapes.emplace_back(max_slice_size_in_tiles, 1); + } + } else { // Each worker will only have one slice + uint32_t remainder_worker_len_tiles = total_num_tiles % optim_worker_slice_len_tiles; + + for (uint32_t w = 0; w < num_workers; ++w) { + worker_slice_shapes.emplace_back(optim_worker_slice_len_tiles, 1); + } + // If there is a remainder worker, we need to adjust the last worker's slice shape to be smaller + if (remainder_worker_len_tiles > 0) { + worker_slice_shapes.back() = tt_xy_pair{remainder_worker_len_tiles, 1}; + } + } + + log_trace(tt::LogOp, "--------------------------------"); + + return worker_slice_shapes; +} + } // namespace ccl } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp index 3f71a810bb2..32bfdd12688 100644 --- a/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp @@ -51,6 +51,17 @@ struct RingTopology { bool is_linear; }; +struct TensorPartition { + TensorPartition( + uint32_t partition_size, + uint32_t partition_index) + : partition_size(partition_size), + partition_index(partition_index) {} + + uint32_t partition_size; + uint32_t partition_index; +}; + class CclOpTensorConfig { public: static std::unique_ptr build_all_gather_tensor_config(Tensor const& tensor); @@ -264,6 +275,16 @@ struct InterleavedTensorWorkerSlice { return worker_slice_shape.x * worker_slice_shape.y; } + void print() const { + tt::log_info("----- printing worker slice -----"); + tt::log_info("tensor_shape: ({},{})", tensor_shape.x, tensor_shape.y); + tt::log_info("tensor_slice_shape: ({},{})", tensor_slice_shape.x, tensor_slice_shape.y); + tt::log_info("worker_slice_shape: ({},{})", worker_slice_shape.x, worker_slice_shape.y); + tt::log_info("worker_slice_offset: ({},{})", worker_slice_offset.x, worker_slice_offset.y); + tt::log_info("worker_slice_is_wrapped: {}", worker_slice_is_wrapped); + tt::log_info("worker_slice_num_pages: {}", get_worker_slice_num_pages()); + } + tt_xy_pair tensor_shape; tt_xy_pair tensor_slice_shape; tt_xy_pair worker_slice_shape; @@ -500,5 +521,58 @@ ccl::EriscDatamoverBuilder create_erisc_datamover_builder( ccl::EriscDataMoverBufferSharingMode buffer_sharing_mode, EriscDataMoverTerminationMode termination_mode); + +class GenericWrappedTensorSlicer { +public: + GenericWrappedTensorSlicer( + const Tensor& input_tensor, + const Tensor& output_tensor, + int slice_dim, + uint32_t 
partition_index, + uint32_t partition_size, + uint32_t total_num_workers, + uint32_t max_slice_size_in_bytes, + uint32_t half_cb_n_pages); + + ccl::InterleavedTensorWorkerSlice get_worker_slice(std::size_t global_worker_index); + + // method to compute offsets in a wrapped layout + std::vector compute_worker_slice_offsets( + const std::vector& worker_slice_shapes, + tt_xy_pair const& tensor_slice_shape); + + // method to create worker slice shapes in a tile layout + std::vector create_worker_slice_shapes_for_tile_layout( + const tt::tt_metal::LegacyShape& tensor_shape, + tt_xy_pair const& tensor_slice_shape_in_tiles, + uint32_t num_workers, + uint32_t max_slice_size_in_pages, + uint32_t half_cb_n_pages); + +private: + void initialize( + const Tensor& input_tensor, + const Tensor& output_tensor, + int slice_dim, + uint32_t partition_index, + uint32_t partition_size, + uint32_t total_num_workers, + uint32_t max_slice_size_in_bytes, + uint32_t half_cb_n_pages); + + tt_xy_pair calculate_tensor_slice_shape(const Tensor& input_tensor, int slice_dim, uint32_t partition_size); + + // Class member variables + tt_xy_pair flattened_tensor_shape; + tt_xy_pair tensor_slice_shape; + std::vector worker_slice_shapes; + std::vector worker_slice_offsets; + uint32_t input_page_size; + bool row_major; + uint32_t partition_index; + uint32_t partition_size; +}; + + } // namespace ccl } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp new file mode 100644 index 00000000000..bf10e9dc429 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp @@ -0,0 +1,338 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "hostdevcommon/kernel_structs.h" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command.hpp" +#include "ttnn/operations/ccl/ccl_common.hpp" + +#include "ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" + +namespace ttnn { +namespace ccl { +namespace worker_detail { + +CCLWorkerArgBuilder::CCLWorkerArgBuilder ( + Device const* device, + ttnn::ccl::CCLOpConfig const& op_config, + ttnn::ccl::TensorPartition const& input_tensor_partition, + ttnn::ccl::TensorPartition const& output_tensor_partition, + std::size_t operating_dim): + device(device), + op_config(op_config), + input_tensor_partition(input_tensor_partition), + output_tensor_partition(output_tensor_partition), + operating_dim(operating_dim) { +} + +void emit_ccl_send_slice_sequence_commands(std::vector const& slices, std::vector& args_out) { + for (std::size_t i = 0; i < slices.size(); i++) { + auto const& slice = slices[i]; + // Copy the header + if (i == 0) { + const std::size_t args_index_old = args_out.size(); + // push back Command Header + args_out.push_back(static_cast(ttnn::ccl::cmd::CclCommandHeader::to_uint32( + ttnn::ccl::cmd::CclCommandHeader{ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_EDM, 1}))); + + // push back arg 0 header + args_out.push_back(static_cast(ttnn::ccl::cmd::CclCommandArgCode::SET_FULL_TENSOR_SLICE_SPEC_IN_PAGES)); + auto const& ccl_command_tensor = ttnn::ccl::cmd::CclCommandTensor{ + Shape4D(1, 1, slice.tensor_shape.y, slice.tensor_shape.x), + Shape4D(1, 1, slice.tensor_slice_shape.y, slice.tensor_slice_shape.x), + Shape4D(0, 0, slice.tensor_slice_offset.y, slice.tensor_slice_offset.x), + Shape4D(0, 
0, slice.worker_slice_offset.y, slice.worker_slice_offset.x), + slice.worker_slice_shape.x * slice.worker_slice_shape.y}; + const auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for full tensor slice command", num_words_for_args); + args_out.resize(args_out.size() + num_words_for_args); + // push_back arg 0 payload + ttnn::ccl::cmd::CclCommandArg:: + pack_to( + &args_out[args_out.size() - num_words_for_args], + ccl_command_tensor + ); + const std::size_t args_index_new = args_out.size(); + + TT_ASSERT(i < slices.size(), "Internal Error"); + std::stringstream ss; ss << "ccl_send command " << std::to_string(i) << " has " << args_index_new - args_index_old << " args:\n"; + for (std::size_t j = args_index_old; j < args_index_new; j++) { + ss << "\targ " << j << ":" << args_out[j] << "\n"; + } + log_trace(tt::LogOp, "{}", ss.str()); + // We can reused cached values for the first slice + } else { + auto const& last_slice = slices[i - 1]; + const std::size_t args_index_old = args_out.size(); + auto header_index = args_out.size(); + args_out.push_back(static_cast(ttnn::ccl::cmd::CclCommandHeader::to_uint32( + ttnn::ccl::cmd::CclCommandHeader{ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_EDM, 1}))); + std::size_t num_args = 0; + + // tensor shape + if (last_slice.tensor_shape != slice.tensor_shape) { + args_out.push_back(static_cast(ttnn::ccl::cmd::CclCommandArgCode::SET_TENSOR_SHAPE_IN_PAGES)); + auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for tensor_shape field", num_words_for_args); + args_out.resize(args_out.size() + num_words_for_args); + ttnn::ccl::cmd::CclCommandArg::pack_to( + &args_out[args_out.size() - num_words_for_args], + Shape4D(1, 1, slice.tensor_shape.y, slice.tensor_shape.x) + ); + for (std::size_t j = args_out.size() - num_words_for_args; j < args_out.size(); j++) { + log_trace(tt::LogOp, "\t{}", args_out[j]); + } + + num_args++; + } + + // tensor slice shape + if (last_slice.tensor_slice_shape != slice.tensor_slice_shape) { + args_out.push_back(static_cast(ttnn::ccl::cmd::CclCommandArgCode::SET_TENSOR_SLICE_SHAPE_IN_PAGES)); + auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for tensor_slice_shape field", num_words_for_args); + args_out.resize(args_out.size() + num_words_for_args); + ttnn::ccl::cmd::CclCommandArg::pack_to( + &args_out[args_out.size() - num_words_for_args], + Shape4D(1, 1, slice.tensor_slice_shape.y, slice.tensor_slice_shape.x) + ); + for (std::size_t i = args_out.size() - num_words_for_args; i < args_out.size(); i++) { + log_trace(tt::LogOp, "\t{}", args_out[i]); + } + + num_args++; + } + + // tensor slice offset + if (last_slice.tensor_slice_offset != slice.tensor_slice_offset) { + args_out.push_back(static_cast(ttnn::ccl::cmd::CclCommandArgCode::SET_TENSOR_SLICE_OFFSET_IN_PAGES)); + auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for tensor_slice_offset field", num_words_for_args); + args_out.resize(args_out.size() + num_words_for_args); + ttnn::ccl::cmd::CclCommandArg::pack_to( + &args_out[args_out.size() - num_words_for_args], + Shape4D(0, 0, slice.tensor_slice_offset.y, slice.tensor_slice_offset.x) + ); + for (std::size_t j = args_out.size() - num_words_for_args; j < args_out.size(); j++) { + log_trace(tt::LogOp, "\t{}", args_out[j]); + } + + num_args++; + } + + // worker slice offset + if 
(last_slice.worker_slice_offset != slice.worker_slice_offset) { + args_out.push_back(static_cast(ttnn::ccl::cmd::CclCommandArgCode::SET_WORKER_START_OFFSET_IN_SLICE_IN_PAGES)); + auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for worker_slice_offset field", num_words_for_args); + args_out.resize(args_out.size() + num_words_for_args); + ttnn::ccl::cmd::CclCommandArg::pack_to( + &args_out[args_out.size() - num_words_for_args], + Shape4D(0, 0, slice.worker_slice_offset.y, slice.worker_slice_offset.x) + ); + + for (std::size_t j = args_out.size() - num_words_for_args; j < args_out.size(); j++) { + log_trace(tt::LogOp, "\t{}", args_out[j]); + } + num_args++; + } + + // worker_pages_per_slice + if (last_slice.worker_slice_shape != slice.worker_slice_shape) { + args_out.push_back(static_cast(ttnn::ccl::cmd::CclCommandArgCode::SET_WORKER_PAGES_PER_SLICE)); + auto num_words_for_args = ttnn::ccl::cmd::CclCommandArg::size_in_words(); + log_trace(tt::LogOp, "Emitting {} args for worker_pages_per_slice field", num_words_for_args); + args_out.resize(args_out.size() + num_words_for_args); + ttnn::ccl::cmd::CclCommandArg::pack_to( + &args_out[args_out.size() - num_words_for_args], + slice.worker_slice_shape.y * slice.worker_slice_shape.x + ); + for (std::size_t j = args_out.size() - num_words_for_args; j < args_out.size(); j++) { + log_trace(tt::LogOp, "\t{}", args_out[j]); + } + + num_args++; + } + + args_out[header_index] = static_cast(ttnn::ccl::cmd::CclCommandHeader::to_uint32( + ttnn::ccl::cmd::CclCommandHeader{ttnn::ccl::cmd::CclCommandCode::STREAM_TENSOR_TO_EDM, 1})); + + std::size_t args_index_new = args_out.size(); + std::stringstream ss; ss << "ccl_send command " << i << " has " << args_index_new - args_index_old << " args:\n"; + for (std::size_t j = args_index_old; j < args_index_new; j++) { + ss << "\targ " << j << ":" << args_out[j] << "\n"; + } + log_trace(tt::LogOp, "{}", ss.str()); + } + } +} + +std::vector CCLWorkerArgBuilder::generate_sender_reader_kernel_rt_args( + ttnn::ccl::InterleavedTensorWorkerSlice worker_slice, + std::size_t operating_dim, + uint32_t num_pages_per_packet, + uint32_t worker_slice_index) const +{ + const std::size_t num_commands_expected = this->input_tensor_partition.partition_size - 1; + + auto const& tensor_shape = worker_slice.tensor_shape; + auto const& tensor_slice_shape = worker_slice.tensor_slice_shape; + + auto num_slices = input_tensor_partition.partition_size; + auto start_slice_index = input_tensor_partition.partition_index; + std::int64_t end_slice_index_exclusive = input_tensor_partition.partition_index + 1; + + if (input_tensor_partition.partition_index==0){ + log_trace(tt::LogOp, "ccl_send_writer start_slice_index = {}", start_slice_index); + log_trace(tt::LogOp, "ccl_send_writer end_slice_index_exclusive = {}", end_slice_index_exclusive); + } + + // Add the command args + auto const& slices = generate_slice_sequence_on_dim_v2( + tensor_shape, + worker_slice.worker_slice_shape, + worker_slice.worker_slice_offset, + operating_dim, + num_slices, + start_slice_index, + end_slice_index_exclusive, + worker_slice_index + ); + TT_ASSERT(num_commands_expected == slices.size()); + + // If we are on device zero, we send n-1 chunks in ascending order + auto &input_tensor = this->op_config.get_input_tensor(0); + TT_ASSERT(input_tensor.get_legacy_shape().size() == 4, "Only 4D tensors are supported for ccl"); + ttnn::ccl::Shape4D input_tensor_shape = {input_tensor.get_legacy_shape()[0], 
input_tensor.get_legacy_shape()[1],input_tensor.get_legacy_shape()[2],input_tensor.get_legacy_shape()[3]}; + + std::vector args = { + static_cast(input_tensor.buffer()->address()), + static_cast(slices.size()) + }; + std::size_t logged_arg_idx = 0; + if (input_tensor_partition.partition_index==0) log_trace(tt::LogOp, "ccl_send_reader arg[{}]: buffer_address = {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; + if (input_tensor_partition.partition_index==0) log_trace(tt::LogOp, "ccl_send_reader arg[{}]: num_commands = {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; + + std::ranges::copy(std::vector{num_pages_per_packet}, std::back_inserter(args)); + if (input_tensor_partition.partition_index==0) log_trace(tt::LogOp, "ccl_send_reader arg[{}]: pages_per_packet {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; + + std::ranges::copy(std::vector{this->op_config.get_page_size()}, std::back_inserter(args)); + if (input_tensor_partition.partition_index==0) log_trace(tt::LogOp, "ccl_send_reader arg[{}]: page_size {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; + + auto const& addr_gen_rt_args = ttnn::ccl::emit_address_generator_runtime_args(this->device, input_tensor); + std::ranges::copy(addr_gen_rt_args, std::back_inserter(args)); + for (auto const& arg : addr_gen_rt_args) { + if (input_tensor_partition.partition_index==0) log_trace(tt::LogOp, "ccl_send_reader arg[{}]: addr_gen_rt_args[] {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; + } + + if (input_tensor_partition.partition_index==0) log_trace(tt::LogOp, "ccl_send_reader Generating {} ccl send commands", slices.size()); + emit_ccl_send_slice_sequence_commands(slices, args); + + if (input_tensor_partition.partition_index==0) log_trace(tt::LogOp, "ccl_send_reader Sender Worker has {} RT Args: {}", args.size(), args); + + return args; +} + +std::vector CCLWorkerArgBuilder::generate_sender_writer_kernel_rt_args( + ttnn::ccl::InterleavedTensorWorkerSlice worker_slice, + std::size_t operating_dim, + uint32_t num_pages_per_packet, + uint32_t worker_slice_index) const +{ + const std::size_t num_commands_expected = this->output_tensor_partition.partition_size - 1; + + auto const& tensor_shape = worker_slice.tensor_shape; + auto const& tensor_slice_shape = worker_slice.tensor_slice_shape; + + auto num_slices = output_tensor_partition.partition_size; + auto start_slice_index = output_tensor_partition.partition_index; + std::int64_t end_slice_index_exclusive = output_tensor_partition.partition_index + 1; + + log_trace(tt::LogOp, "ccl_send_writer start_slice_index = {}", start_slice_index); + log_trace(tt::LogOp, "ccl_send_writer end_slice_index_exclusive = {}", end_slice_index_exclusive); + + // Add the command args + auto const& slices = generate_slice_sequence_on_dim_v2( + tensor_shape, + worker_slice.worker_slice_shape, + worker_slice.worker_slice_offset, + operating_dim, + num_slices, + start_slice_index, + end_slice_index_exclusive, + worker_slice_index + ); + TT_ASSERT(num_commands_expected == slices.size()); + + // If we are on device zero, we send n-1 chunks in ascending order + auto &output_tensor = this->op_config.get_output_tensor(0); + TT_ASSERT(output_tensor.get_legacy_shape().size() == 4, "Only 4D tensors are supported for ccl"); + ttnn::ccl::Shape4D output_tensor_shape = {output_tensor.get_legacy_shape()[0], output_tensor.get_legacy_shape()[1],output_tensor.get_legacy_shape()[2],output_tensor.get_legacy_shape()[3]}; + + std::vector args = { + 
static_cast(output_tensor.buffer()->address()), + static_cast(slices.size()) + }; + std::size_t logged_arg_idx = 0; + log_trace(tt::LogOp, "ccl_send_writer arg[{}]: buffer_address = {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; + log_trace(tt::LogOp, "ccl_send_writer arg[{}]: num_commands = {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; + + std::ranges::copy(std::vector{num_pages_per_packet}, std::back_inserter(args)); + log_trace(tt::LogOp, "ccl_send_writer arg[{}]: pages_per_packet {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; + + std::ranges::copy(std::vector{this->op_config.get_page_size()}, std::back_inserter(args)); + log_trace(tt::LogOp, "ccl_send_writer arg[{}]: page_size {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; + + auto const& addr_gen_rt_args = ttnn::ccl::emit_address_generator_runtime_args(this->device, output_tensor); + std::ranges::copy(addr_gen_rt_args, std::back_inserter(args)); + for (auto const& arg : addr_gen_rt_args) { + log_trace(tt::LogOp, "ccl_send_writer arg[{}]: addr_gen_rt_args[] {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; + } + + log_trace(tt::LogOp, "ccl_send_writer Generating {} ccl send commands", slices.size()); + emit_ccl_send_slice_sequence_commands(slices, args); + + log_trace(tt::LogOp, "ccl_send_writer Sender Worker has {} RT Args: {}", args.size(), args); + + return args; +} + +std::vector CCLWorkerArgBuilder::generate_sender_reader_kernel_ct_args() const +{ + std::vector args = { + static_cast(this->op_config.get_input_tensor(0).memory_config().memory_layout), // tensor memory layout + static_cast(this->op_config.get_input_tensor(0).buffer()->buffer_type()), // buffer type + static_cast(this->op_config.get_input_tensor(0).layout()), // page layout + static_cast(tt::CB::c_in0) // cb_id + }; + + auto const& input_tensor = this->op_config.get_input_tensor(0); + auto const& addr_gen_rt_args = ttnn::ccl::emit_address_generator_compile_time_args(input_tensor); + std::ranges::copy(addr_gen_rt_args, std::back_inserter(args)); + + return args; +} + +std::vector CCLWorkerArgBuilder::generate_sender_writer_kernel_ct_args() const +{ + std::vector args = { + static_cast(this->op_config.get_output_tensor(0).memory_config().memory_layout), // tensor memory layout + static_cast(this->op_config.get_output_tensor(0).buffer()->buffer_type()), // buffer type + static_cast(this->op_config.get_output_tensor(0).layout()), // page layout + static_cast(tt::CB::c_in0) // cb_id + }; + + auto const& output_tensor = this->op_config.get_output_tensor(0); + auto const& addr_gen_rt_args = ttnn::ccl::emit_address_generator_compile_time_args(output_tensor); + std::ranges::copy(addr_gen_rt_args, std::back_inserter(args)); + + return args; +} + +} // namespace worker_detail +} // namespace ccl +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp new file mode 100644 index 00000000000..b5893b2c486 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/cpp/ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/ccl_common.hpp" + +#include + +namespace tt { +namespace tt_metal { +inline namespace v0 { + +// Forward declarations +class Device; + +} // namespace v0 +} // namespace tt_metal +} // namespace tt + +namespace ttnn { +namespace ccl { +class WorkerEdmInterfaceArgs; + +namespace worker_detail { + +void emit_ccl_send_slice_sequence_commands(std::vector const& slices, std::vector& args_out); + +struct CCLWorkerArgBuilder { + CCLWorkerArgBuilder ( + tt::tt_metal::Device const* device, + ttnn::ccl::CCLOpConfig const& op_config, + ttnn::ccl::TensorPartition const& input_tensor_partition, + ttnn::ccl::TensorPartition const& output_tensor_partition, + std::size_t operating_dim); + + std::vector generate_sender_reader_kernel_rt_args( + ttnn::ccl::InterleavedTensorWorkerSlice worker_slice, + std::size_t operating_dim, + uint32_t num_pages_per_packet, + uint32_t worker_slice_index) const; + + std::vector generate_sender_writer_kernel_rt_args( + ttnn::ccl::InterleavedTensorWorkerSlice worker_slice, + std::size_t operating_dim, + uint32_t num_pages_per_packet, + uint32_t worker_slice_index) const; + + std::vector generate_sender_reader_kernel_ct_args() const; + + std::vector generate_sender_writer_kernel_ct_args() const; + + tt::tt_metal::Device const*device; + ttnn::ccl::TensorPartition const input_tensor_partition; + ttnn::ccl::TensorPartition const output_tensor_partition; + ttnn::ccl::CCLOpConfig const op_config; + std::size_t operating_dim; + bool src_is_dram; + bool dst_is_dram; +}; + +} // namespace worker_detail +} // namespace ccl +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader.cpp b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader.cpp new file mode 100644 index 00000000000..72e183fe734 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader.cpp @@ -0,0 +1,308 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include "impl/buffers/buffer_constants.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" +#include "tt_metal/impl/buffers/buffer_constants.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_device.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "debug/dprint.h" +#include "ttnn/cpp/ttnn/tensor/enum_types.hpp" +#include + +using ttnn::ccl::coord_t; +// For the future +using address_t = uint32_t; + +using ttnn::ccl::Shape4D; +using tt::tt_metal::TensorMemoryLayout; +using shape_t = Shape4D; + +void dprint(ttnn::ccl::cmd::CclCommandTensor const& command_tensor) { + DPRINT << "\ttensor_shape.w: " << (uint32_t)command_tensor.tensor_shape.w << "\n"; + DPRINT << "\ttensor_shape.z: " << (uint32_t)command_tensor.tensor_shape.z << "\n"; + DPRINT << "\ttensor_shape.y: " << (uint32_t)command_tensor.tensor_shape.y << "\n"; + DPRINT << "\ttensor_shape.x: " << (uint32_t)command_tensor.tensor_shape.x << "\n"; + DPRINT << "\ttensor_slice_shape.w: " << (uint32_t)command_tensor.tensor_slice_shape.w << "\n"; + DPRINT << "\ttensor_slice_shape.z: " << (uint32_t)command_tensor.tensor_slice_shape.z << "\n"; + DPRINT << "\ttensor_slice_shape.y: " << (uint32_t)command_tensor.tensor_slice_shape.y << "\n"; + DPRINT << "\ttensor_slice_shape.x: " << (uint32_t)command_tensor.tensor_slice_shape.x << "\n"; + DPRINT << "\ttensor_slice_offset.w: " << (uint32_t)command_tensor.tensor_slice_offset.w << "\n"; + DPRINT << "\ttensor_slice_offset.z: " << (uint32_t)command_tensor.tensor_slice_offset.z << "\n"; + DPRINT << "\ttensor_slice_offset.y: " << (uint32_t)command_tensor.tensor_slice_offset.y << "\n"; + DPRINT << "\ttensor_slice_offset.x: " << (uint32_t)command_tensor.tensor_slice_offset.x << "\n"; + DPRINT << "\tworker_start_offset_in_slice.w: " << (uint32_t)command_tensor.worker_start_offset_in_slice.w << "\n"; + DPRINT << "\tworker_start_offset_in_slice.z: " << (uint32_t)command_tensor.worker_start_offset_in_slice.z << "\n"; + DPRINT << "\tworker_start_offset_in_slice.y: " << (uint32_t)command_tensor.worker_start_offset_in_slice.y << "\n"; + DPRINT << "\tworker_start_offset_in_slice.x: " << (uint32_t)command_tensor.worker_start_offset_in_slice.x << "\n"; + DPRINT << "\tworker_pages_per_slice: " << (uint32_t)command_tensor.worker_pages_per_slice << "\n"; +} + +void print_tensor_command(uint32_t command_index, ttnn::ccl::cmd::CclCommandTensor const& command_tensor) { +#ifdef DEBUG_PRINT_ENABLED + DPRINT << "cmd[" << (uint32_t)command_index << "]:\n"; + dprint(command_tensor); +#endif +} + +/* + * Convert a flattened worker offset coord value (assumed 0,0,0, worker offset in pages into tensor slice) + * into a 4D coordinate value + */ +inline shape_t worker_wrapped_offset_to_coord(shape_t const& slice_shape, shape_t const& worker_slice_offset) { + static_assert(sizeof(coord_t) == 2 * sizeof(uint32_t), "worker_wrapped_offset_to_coord not updated to work with 4d shape"); + auto const y = worker_slice_offset.x / slice_shape.x; + return shape_t(0, 0, y, worker_slice_offset.x - (y * slice_shape.x)); +} + +std::size_t 
get_flat_index_from_shape(const Shape4D<uint32_t>& shape, const Shape4D<uint32_t>& index) {
+    std::size_t offset = index.x;
+    std::size_t inner_volume = shape.x;
+    offset += index.y * inner_volume;
+    inner_volume *= shape.y;
+    offset += index.z * inner_volume;
+    inner_volume *= shape.z;
+    offset += index.w * inner_volume;
+    return offset;
+}
+
+using tt::tt_metal::BufferType;
+using tt::tt_metal::Layout;
+
+template <tt::tt_metal::TensorMemoryLayout tensor_layout, tt::tt_metal::BufferType buffer_type, tt::tt_metal::Layout page_layout>
+struct source_tensor_addrgen {
+    static constexpr char name[] = "Uninitialized";
+};
+template <tt::tt_metal::BufferType buffer_type>
+struct source_tensor_addrgen<TensorMemoryLayout::INTERLEAVED, buffer_type, Layout::ROW_MAJOR> {
+    static constexpr bool is_dram = buffer_type == tt::tt_metal::BufferType::DRAM;
+    static constexpr char name[] = "InterleavedAddrGen(default)";
+    using type = InterleavedAddrGen<is_dram>;
+};
+template <tt::tt_metal::BufferType buffer_type>
+struct source_tensor_addrgen<TensorMemoryLayout::INTERLEAVED, buffer_type, Layout::TILE> {
+    static constexpr bool is_dram = buffer_type == tt::tt_metal::BufferType::DRAM;
+    static constexpr char name[] = "InterleavedAddrGen(Tile)";
+    using type = InterleavedAddrGenFast<is_dram>;
+};
+template <tt::tt_metal::BufferType buffer_type, tt::tt_metal::Layout page_layout>
+struct source_tensor_addrgen<TensorMemoryLayout::WIDTH_SHARDED, buffer_type, page_layout> {
+    static constexpr char name[] = "WidthSharded";
+    using type = tt::tt_metal::address_generators::DefaultWidthShardedAddressGenerator;
+};
+template <tt::tt_metal::BufferType buffer_type, tt::tt_metal::Layout page_layout>
+struct source_tensor_addrgen<TensorMemoryLayout::HEIGHT_SHARDED, buffer_type, page_layout> {
+    static constexpr char name[] = "HeightSharded";
+    using type = tt::tt_metal::address_generators::DefaultHeightShardedAddressGenerator;
+};
+template <tt::tt_metal::BufferType buffer_type, tt::tt_metal::Layout page_layout>
+struct source_tensor_addrgen<TensorMemoryLayout::BLOCK_SHARDED, buffer_type, page_layout> {
+    static constexpr char name[] = "BlockSharded";
+    using type = tt::tt_metal::address_generators::DefaultBlockShardedAddressGenerator;
+};
+
+
+constexpr bool is_sharded_tensor_layout(tt::tt_metal::TensorMemoryLayout tensor_layout) {
+    return tensor_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED ||
+           tensor_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED ||
+           tensor_layout == tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED;
+}
+
+// reader code
+template <typename T>
+constexpr Shape4D<T> build_wrapped_row_tensor_slice(T n_pages) {
+    return Shape4D<T>{1, 1, 1, n_pages};
+}
+
+///////////////////////////////////////////////////
+// COMPILE TIME ARGS
+///////////////////////////////////////////////////
+
+constexpr TensorMemoryLayout tensor_layout = static_cast<TensorMemoryLayout>(get_compile_time_arg_val(0));
+constexpr BufferType buffer_type = static_cast<BufferType>(get_compile_time_arg_val(1));
+constexpr Layout page_layout = static_cast<Layout>(get_compile_time_arg_val(2));
+constexpr uint32_t cb_id = get_compile_time_arg_val(3);
+
+
+#ifdef SHARDED_MEM_LAYOUT
+static constexpr bool is_sharded_mode = true;
+static constexpr uint32_t input_tensor_shard_grid_height = get_compile_time_arg_val(5);
+static constexpr uint32_t input_tensor_shard_grid_width = get_compile_time_arg_val(6);
+static constexpr uint32_t input_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(7);
+static constexpr uint32_t input_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(8);
+static constexpr uint32_t input_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(9);
+static constexpr uint32_t input_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(10);
+static constexpr bool input_tensor_shard_grid_transposed = get_compile_time_arg_val(11) != 0;
+#else
+static constexpr bool is_sharded_mode = false;
+static constexpr uint32_t input_tensor_shard_grid_height = 0;
+static constexpr uint32_t input_tensor_shard_grid_width = 0;
+static constexpr uint32_t input_tensor_shard_grid_start_y_logical = 0;
+static constexpr uint32_t input_tensor_shard_grid_start_x_logical = 0;
+static constexpr uint32_t input_tensor_shard_pages_per_shard_y = 0;
+static constexpr uint32_t input_tensor_shard_pages_per_shard_x = 0;
+static constexpr bool input_tensor_shard_grid_transposed = false;
+#endif
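+// The compile-time args above are consumed positionally. As a rough host-side sketch (illustrative
+// only; the actual packing is produced by CCLWorkerArgBuilder::generate_sender_reader_kernel_ct_args
+// on the host), an interleaved, DRAM, tile-layout configuration would be passed as:
+//
+//   std::vector<uint32_t> ct_args = {
+//       static_cast<uint32_t>(TensorMemoryLayout::INTERLEAVED),  // arg 0: tensor memory layout
+//       static_cast<uint32_t>(BufferType::DRAM),                 // arg 1: buffer type
+//       static_cast<uint32_t>(Layout::TILE),                     // arg 2: page layout
+//       cb_id};                                                  // arg 3: scratch CB id
+//
+// The shard-grid fields (args 5-11) are only read when SHARDED_MEM_LAYOUT is defined.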
+
+
+template <tt::tt_metal::TensorMemoryLayout tensor_layout, tt::tt_metal::BufferType buffer_type, tt::tt_metal::Layout page_layout>
+auto build_source_address_generator(std::size_t& arg_idx, address_t tensor_address, std::size_t page_size, uint32_t cb_id_in0) -> typename source_tensor_addrgen<tensor_layout, buffer_type, page_layout>::type {
+    constexpr bool is_sharded = is_sharded_tensor_layout(tensor_layout);
+    constexpr bool is_interleaved = tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED;
+    constexpr bool is_tile_page_layout = page_layout == tt::tt_metal::Layout::TILE;
+    constexpr bool is_row_major_layout = page_layout == tt::tt_metal::Layout::ROW_MAJOR;
+    static_assert(is_sharded || is_interleaved, "Only sharded and interleaved tensor layouts are supported by the unified address generator. A tensor layout not matching TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::HEIGHT_SHARDED, TensorMemoryLayout::BLOCK_SHARDED, or TensorMemoryLayout::INTERLEAVED was specified.");
+
+    using addrgen_type = typename source_tensor_addrgen<tensor_layout, buffer_type, page_layout>::type;
+
+    if constexpr (tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) {
+        if constexpr (is_row_major_layout) {
+            return addrgen_type{
+                .bank_base_address = tensor_address, .page_size = page_size};
+        } else {
+            return addrgen_type{
+                .bank_base_address = tensor_address, .page_size = page_size, .data_format = get_dataformat(cb_id_in0)};
+        }
+    } else if constexpr (
+        tensor_layout == tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED ||
+        tensor_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED ||
+        tensor_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) {
+        size_t input_shard_grid_nrows = get_arg_val<uint32_t>(arg_idx++);
+        const auto* const input_shard_grid_row_map = reinterpret_cast<const uint32_t* const>(get_arg_addr(arg_idx));
+        arg_idx += input_shard_grid_nrows;
+        size_t input_shard_grid_ncols = get_arg_val<uint32_t>(arg_idx++);
+        const auto* const input_shard_grid_col_map = reinterpret_cast<const uint32_t* const>(get_arg_addr(arg_idx));
+        arg_idx += input_shard_grid_ncols;
+
+        return tt::tt_metal::address_generators::build_sharded_addr_gen(
+            tt::tt_metal::address_generators::HarvestedWormholeWorkerToNocLookup(
+                input_shard_grid_nrows,
+                input_shard_grid_row_map,
+                input_shard_grid_ncols,
+                input_shard_grid_col_map),
+            typename tt::tt_metal::address_generators::DeviceShardSpecTypeGetter<tensor_layout>::type(
+                input_tensor_shard_pages_per_shard_y,
+                input_tensor_shard_pages_per_shard_x,
+                input_tensor_shard_grid_height,
+                input_tensor_shard_grid_width,
+                input_tensor_shard_grid_start_y_logical,
+                input_tensor_shard_grid_start_x_logical,
+                input_tensor_shard_grid_transposed
+            ),
+            page_size,
+            tensor_address
+        );
+    } else {
+        ASSERT(false);
+    }
+}
+
+/*
+* CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) dispatch
+* implementations depending on those invocation parameters.
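+*
+* Runtime args are consumed positionally by kernel_main below (a sketch of the expected layout,
+* mirroring the get_arg_val calls rather than a separate spec):
+*   0: tensor base address
+*   1: number of commands
+*   2: packet size in pages
+*   3: page size in bytes
+*   then, for sharded tensor layouts only, the shard-grid row/column NOC maps read by
+*   build_source_address_generator, followed by one serialized CclCommandTensor per command.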
+*/
+void kernel_main() {
+    std::size_t arg_idx = 0;
+
+    ///////////////////////////////////////////////////
+    // ARGS
+    ///////////////////////////////////////////////////
+
+    // Load the input tensor spec
+    address_t tensor_address = get_arg_val<address_t>(arg_idx++);
+    address_t num_commands = get_arg_val<address_t>(arg_idx++);
+
+    // Assuming whole page transmissions (which is the only mode we support at the moment)
+    // -> however, wanted to call it out here to make it clear that we need to pull this
+    //    out when we start enabling other modes
+    const uint32_t packet_size_in_pages = get_arg_val<uint32_t>(arg_idx++);
+    const uint32_t page_size = get_arg_val<uint32_t>(arg_idx++);
+    auto tensor_addrgen = build_source_address_generator<tensor_layout, buffer_type, page_layout>(arg_idx, tensor_address, page_size, tt::CB::c_in0);
+
+    ttnn::ccl::cmd::CclCommandTensor command_tensor;
+
+    // Don't use CBs because there appears to be a bug if we have the same producer/consumer core to a given CB
+    // Instead, open up the CB and use it as a raw scratch space
+    const uint32_t local_l1_scratch_buffer_address = get_write_ptr(cb_id);
+
+    // #ifdef DEBUG_PRINT_ENABLED
+    // DPRINT << "ccl_send_reader has " << (uint32_t)num_commands << " commands" << ENDL();
+    // #endif
+
+    for (std::size_t i = 0; i < num_commands; ++i) {
+        // Generalized would be to get the command header info and then dispatch accordingly - if the command type is singular
+        //
+        std::size_t old_arg_idx = arg_idx;
+        ttnn::ccl::cmd::update_command_tensor(arg_idx, command_tensor);
+        std::size_t new_arg_idx = arg_idx;
+
+        {
+            // print_tensor_command(i, command_tensor);
+            ASSERT(command_tensor.worker_pages_per_slice > 0);
+
+            // CURRENTLY ONLY SUPPORTS WRAPPED TENSOR ITERATION COMMANDS
+            // Implemented really inefficiently for now - in the future we can do more efficient packing and also change
+            // the tensor read API to require the information in a more efficient way (less intermediate calculations)
+            // const shape_t tensor_slice_start_offset = ttnn::ccl::build_from_args(arg_idx);  // Should be RT
+            shape_t valid_worker_slice_shape = build_wrapped_row_tensor_slice(command_tensor.worker_pages_per_slice);  // Parametrizable by ct arg
+
+            shape_t const& worker_start_offset_global = worker_wrapped_offset_to_coord(command_tensor.tensor_slice_shape, command_tensor.worker_start_offset_in_slice);
+            shape_t const& global_offset = command_tensor.tensor_slice_offset + worker_start_offset_global;
+
+            uint32_t curr_tile_id = get_flat_index_from_shape(command_tensor.tensor_shape, global_offset);
+
+            // DPRINT << "valid_worker_slice_shape.w: " << valid_worker_slice_shape.w << ENDL();
+            // DPRINT << "valid_worker_slice_shape.z: " << valid_worker_slice_shape.z << ENDL();
+            // DPRINT << "valid_worker_slice_shape.y: " << valid_worker_slice_shape.y << ENDL();
+            // DPRINT << "valid_worker_slice_shape.x: " << valid_worker_slice_shape.x << ENDL();
+            // DPRINT << "global_offset.w: " << global_offset.w << ENDL();
+            // DPRINT << "global_offset.z: " << global_offset.z << ENDL();
+            // DPRINT << "global_offset.y: " << global_offset.y << ENDL();
+            // DPRINT << "global_offset.x: " << global_offset.x << ENDL();
+            // DPRINT << "curr_tile_id: " << curr_tile_id << ENDL();
+
+            uint32_t offset_into_worker_slice = 0;
+            bool last_page_of_worker = false;
+            for (uint32_t p = 0; p < command_tensor.worker_pages_per_slice; p += packet_size_in_pages) {
+                cb_reserve_back(cb_id, packet_size_in_pages);
+
+                uint32_t n_pages = std::min(packet_size_in_pages, command_tensor.worker_pages_per_slice - p);
+                ASSERT(command_tensor.worker_start_offset_in_slice.w == 0);
ASSERT(command_tensor.worker_start_offset_in_slice.z == 0); + ASSERT(valid_worker_slice_shape.w == 1); + ASSERT(valid_worker_slice_shape.z == 1); + ASSERT(command_tensor.tensor_shape.w == 1); + ASSERT(command_tensor.tensor_shape.z == 1); + ASSERT(command_tensor.tensor_slice_shape.w == 1); + ASSERT(command_tensor.tensor_slice_shape.z == 1); + + // DPRINT << "iter "<< p << " curr_tile_id: " << curr_tile_id << ENDL(); + + read_wrapped_chunk_from_output_tensor_to_address( + curr_tile_id, + offset_into_worker_slice, + ttnn::ccl::coord_t(command_tensor.worker_start_offset_in_slice.x, command_tensor.worker_start_offset_in_slice.y), // Offset into tensor slice + ttnn::ccl::coord_t(valid_worker_slice_shape.x, valid_worker_slice_shape.y), + // In tiles for tile layout + ttnn::ccl::coord_t(command_tensor.tensor_shape.x, command_tensor.tensor_shape.y), + ttnn::ccl::coord_t(command_tensor.tensor_slice_shape.x, command_tensor.tensor_slice_shape.y), + local_l1_scratch_buffer_address, + tensor_addrgen, + n_pages, + page_size, + last_page_of_worker + ); + + cb_push_back(cb_id, packet_size_in_pages); + } + } + } + //////////////////////////////////////////////////////////////////////////////////// + +} diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_writer.cpp b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_writer.cpp new file mode 100644 index 00000000000..a4f6ec3d15f --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_writer.cpp @@ -0,0 +1,308 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include "impl/buffers/buffer_constants.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types.hpp" +#include "tt_metal/impl/buffers/buffer_constants.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/uops/ccl_command_device.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_device.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp" +#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp" +#include "debug/dprint.h" +#include "ttnn/cpp/ttnn/tensor/enum_types.hpp" +#include + +using ttnn::ccl::coord_t; +// For the future +using address_t = uint32_t; + +using ttnn::ccl::Shape4D; +using tt::tt_metal::TensorMemoryLayout; +using shape_t = Shape4D; + +void dprint(ttnn::ccl::cmd::CclCommandTensor const& command_tensor) { + DPRINT << "\ttensor_shape.w: " << (uint32_t)command_tensor.tensor_shape.w << "\n"; + DPRINT << "\ttensor_shape.z: " << (uint32_t)command_tensor.tensor_shape.z << "\n"; + DPRINT << "\ttensor_shape.y: " << (uint32_t)command_tensor.tensor_shape.y << "\n"; + DPRINT << "\ttensor_shape.x: " << (uint32_t)command_tensor.tensor_shape.x << "\n"; + DPRINT << "\ttensor_slice_shape.w: " << (uint32_t)command_tensor.tensor_slice_shape.w << "\n"; + DPRINT << "\ttensor_slice_shape.z: " << (uint32_t)command_tensor.tensor_slice_shape.z << "\n"; + DPRINT << "\ttensor_slice_shape.y: " << (uint32_t)command_tensor.tensor_slice_shape.y << "\n"; + DPRINT << "\ttensor_slice_shape.x: " << (uint32_t)command_tensor.tensor_slice_shape.x << "\n"; + DPRINT << "\ttensor_slice_offset.w: " << (uint32_t)command_tensor.tensor_slice_offset.w << "\n"; + DPRINT << "\ttensor_slice_offset.z: " << (uint32_t)command_tensor.tensor_slice_offset.z << 
"\n"; + DPRINT << "\ttensor_slice_offset.y: " << (uint32_t)command_tensor.tensor_slice_offset.y << "\n"; + DPRINT << "\ttensor_slice_offset.x: " << (uint32_t)command_tensor.tensor_slice_offset.x << "\n"; + DPRINT << "\tworker_start_offset_in_slice.w: " << (uint32_t)command_tensor.worker_start_offset_in_slice.w << "\n"; + DPRINT << "\tworker_start_offset_in_slice.z: " << (uint32_t)command_tensor.worker_start_offset_in_slice.z << "\n"; + DPRINT << "\tworker_start_offset_in_slice.y: " << (uint32_t)command_tensor.worker_start_offset_in_slice.y << "\n"; + DPRINT << "\tworker_start_offset_in_slice.x: " << (uint32_t)command_tensor.worker_start_offset_in_slice.x << "\n"; + DPRINT << "\tworker_pages_per_slice: " << (uint32_t)command_tensor.worker_pages_per_slice << "\n"; +} + +void print_tensor_command(uint32_t command_index, ttnn::ccl::cmd::CclCommandTensor const& command_tensor) { +#ifdef DEBUG_PRINT_ENABLED + DPRINT << "cmd[" << (uint32_t)command_index << "]:\n"; + dprint(command_tensor); +#endif +} + +/* + * Convert a flattened worker offset coord value (assumed 0,0,0, worker offset in pages into tensor slice) + * into a 4D coordinate value + */ +inline shape_t worker_wrapped_offset_to_coord(shape_t const& slice_shape, shape_t const& worker_slice_offset) { + static_assert(sizeof(coord_t) == 2 * sizeof(uint32_t), "worker_wrapped_offset_to_coord not updated to work with 4d shape"); + auto const y = worker_slice_offset.x / slice_shape.x; + return shape_t(0, 0, y, worker_slice_offset.x - (y * slice_shape.x)); +} + +std::size_t get_flat_index_from_shape(const Shape4D &shape, const Shape4D &index) { + std::size_t offset = index.x; + std::size_t inner_volume = shape.x; + offset += index.y * inner_volume; + inner_volume *= shape.y; + offset += index.z * inner_volume; + inner_volume *= shape.z; + offset += index.w * inner_volume; + return offset; +} + +using tt::tt_metal::BufferType; +using tt::tt_metal::Layout; + +template +struct source_tensor_addrgen { + static constexpr char name[] = "Uninitialized"; +}; +template +struct source_tensor_addrgen { + static constexpr bool is_dram = buffer_type == tt::tt_metal::BufferType::DRAM; + static constexpr char name[] = "InterleavedAddrGen(default)"; + using type = InterleavedAddrGen; +}; +template +struct source_tensor_addrgen { + static constexpr bool is_dram = buffer_type == tt::tt_metal::BufferType::DRAM; + static constexpr char name[] = "InterleavedAddrGen(Tile)"; + using type = InterleavedAddrGenFast; +}; +template +struct source_tensor_addrgen { + static constexpr char name[] = "WidthSharded"; + using type = tt::tt_metal::address_generators::DefaultWidthShardedAddressGenerator; +}; +template +struct source_tensor_addrgen { + static constexpr char name[] = "HeightSharded"; + using type = tt::tt_metal::address_generators::DefaultHeightShardedAddressGenerator; +}; +template +struct source_tensor_addrgen { + static constexpr char name[] = "BlockSharded"; + using type = tt::tt_metal::address_generators::DefaultBlockShardedAddressGenerator; +}; + + +constexpr bool is_sharded_tensor_layout(tt::tt_metal::TensorMemoryLayout tensor_layout) { + return tensor_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED || + tensor_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED || + tensor_layout == tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED; +} + +// reader code +template +constexpr Shape4D build_wrapped_row_tensor_slice(T n_pages) { + return Shape4D{1, 1, 1, n_pages}; +} + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS 
+/////////////////////////////////////////////////// + +constexpr TensorMemoryLayout tensor_layout = static_cast(get_compile_time_arg_val(0)); +constexpr BufferType buffer_type = static_cast(get_compile_time_arg_val(1)); +constexpr Layout page_layout = static_cast(get_compile_time_arg_val(2)); +constexpr uint32_t cb_id = get_compile_time_arg_val(3); + + +#ifdef SHARDED_MEM_LAYOUT +static constexpr bool is_sharded_mode = true; +static constexpr uint32_t input_tensor_shard_grid_height = get_compile_time_arg_val(5); +static constexpr uint32_t input_tensor_shard_grid_width = get_compile_time_arg_val(6); +static constexpr uint32_t input_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(7); +static constexpr uint32_t input_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(8); +static constexpr uint32_t input_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(9); +static constexpr uint32_t input_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(10); +static constexpr bool input_tensor_shard_grid_transposed = get_compile_time_arg_val(11) != 0; +#else +static constexpr bool is_sharded_mode = false; +static constexpr uint32_t input_tensor_shard_grid_height = 0; +static constexpr uint32_t input_tensor_shard_grid_width = 0; +static constexpr uint32_t input_tensor_shard_grid_start_y_logical = 0; +static constexpr uint32_t input_tensor_shard_grid_start_x_logical = 0; +static constexpr uint32_t input_tensor_shard_pages_per_shard_y = 0; +static constexpr uint32_t input_tensor_shard_pages_per_shard_x = 0; +static constexpr bool input_tensor_shard_grid_transposed = false; +#endif + + +template +auto build_source_address_generator(std::size_t &arg_idx, address_t tensor_address, std::size_t page_size, uint32_t cb_id_in0) -> typename source_tensor_addrgen::type { + constexpr bool is_sharded = is_sharded_tensor_layout(tensor_layout); + constexpr bool is_interleaved = tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED; + constexpr bool is_tile_page_layout = page_layout == tt::tt_metal::Layout::TILE; + constexpr bool is_row_major_layout = page_layout == tt::tt_metal::Layout::ROW_MAJOR; + static_assert(is_sharded || is_interleaved, "Only sharded and interleaved tensor layouts are supported but the unified address generator. 
A tensor layout not matching TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::HEIGHT_SHARDED, TensorMemoryLayout::BLOCK_SHARDED, or TensorMemoryLayout::INTERLEAVED was specified."); + + using addrgen_type = typename source_tensor_addrgen::type; + + if constexpr (tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { + if constexpr (is_row_major_layout) { + return addrgen_type{ + .bank_base_address = tensor_address, .page_size = page_size}; + } else { + return addrgen_type{ + .bank_base_address = tensor_address, .page_size = page_size, .data_format = get_dataformat(cb_id_in0)}; + } + } else if constexpr ( + tensor_layout == tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED || + tensor_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED || + tensor_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) { + size_t input_shard_grid_nrows = get_arg_val(arg_idx++); + const auto * const input_shard_grid_row_map = reinterpret_cast(get_arg_addr(arg_idx)); + arg_idx += input_shard_grid_nrows; + size_t input_shard_grid_ncols = get_arg_val(arg_idx++); + const auto * const input_shard_grid_col_map = reinterpret_cast(get_arg_addr(arg_idx)); + arg_idx += input_shard_grid_ncols; + + return tt::tt_metal::address_generators::build_sharded_addr_gen( + tt::tt_metal::address_generators::HarvestedWormholeWorkerToNocLookup( + input_shard_grid_nrows, + input_shard_grid_row_map, + input_shard_grid_ncols, + input_shard_grid_col_map), + typename tt::tt_metal::address_generators::DeviceShardSpecTypeGetter::type( + input_tensor_shard_pages_per_shard_y, + input_tensor_shard_pages_per_shard_x, + input_tensor_shard_grid_height, + input_tensor_shard_grid_width, + input_tensor_shard_grid_start_y_logical, + input_tensor_shard_grid_start_x_logical, + input_tensor_shard_grid_transposed + ), + page_size, + tensor_address + ); + } else { + ASSERT(false); + } +} + +/* +* CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) dispatch +* implementations depending on those invocation parameters. 
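+*
+* Runtime args are consumed positionally by kernel_main below (a sketch of the expected layout,
+* mirroring the get_arg_val calls rather than a separate spec):
+*   0: destination tensor base address
+*   1: number of commands
+*   2: packet size in pages
+*   3: page size in bytes
+*   then, for sharded tensor layouts only, the shard-grid row/column NOC maps read by
+*   build_source_address_generator, followed by one serialized CclCommandTensor per command.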
+*/ +void kernel_main() { + std::size_t arg_idx = 0; + + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + // Load the input tensor spec + address_t dest_address = get_arg_val(arg_idx++); + address_t num_commands = get_arg_val(arg_idx++); + + // Assuming whole page transmissions (which is the only mode we support at the moment) + // -> however, wanted to call it out here to make it clear that we need to pull this + // out when we start enabling other modes + const uint32_t packet_size_in_pages = get_arg_val(arg_idx++); + const uint32_t page_size = get_arg_val(arg_idx++); + auto tensor_addrgen = build_source_address_generator(arg_idx, dest_address, page_size, tt::CB::c_in0); + + ttnn::ccl::cmd::CclCommandTensor command_tensor; + + // Don't use CBs because there appears to be a bug if we have the same producer/consumer core to a given CB + // Instead, open up the CB and use it as a raw scratch space6 + cb_reserve_back(cb_id, packet_size_in_pages); + + #ifdef DEBUG_PRINT_ENABLED + // DPRINT << "ccl_send_writer has " << (uint32_t)num_commands << " commands" << ENDL(); + #endif + + for (std::size_t i = 0; i < num_commands; ++i) { + // Generalized would be to get the command header info and then dispatch accordingly - if the command type is singular + // + std::size_t old_arg_idx = arg_idx; + ttnn::ccl::cmd::update_command_tensor(arg_idx, command_tensor); + std::size_t new_arg_idx = arg_idx; + + { + // print_tensor_command(i, command_tensor); + ASSERT(command_tensor.worker_pages_per_slice > 0); + + // CURRENTLY ONLY SUPPORTS WRAPPED TENSOR ITERATION COMMANDS + // Implemented really inefficiently for now - in the future we can do more efficient packing and also change + // the tensor read API to require the information in a more efficient way (less intermediate calculations) + // const shape_t tensor_slice_start_offset = ttnn::ccl::build_from_args(arg_idx); // Should be RT + shape_t valid_worker_slice_shape = build_wrapped_row_tensor_slice(command_tensor.worker_pages_per_slice); // Parametrizable by ct arg + + shape_t const& worker_start_offset_global = worker_wrapped_offset_to_coord(command_tensor.tensor_slice_shape, command_tensor.worker_start_offset_in_slice); + shape_t const& global_offset = command_tensor.tensor_slice_offset + worker_start_offset_global; + + uint32_t curr_tile_id = get_flat_index_from_shape(command_tensor.tensor_shape, global_offset); + + // DPRINT << "valid_worker_slice_shape.w: " << valid_worker_slice_shape.w << ENDL(); + // DPRINT << "valid_worker_slice_shape.z: " << valid_worker_slice_shape.z << ENDL(); + // DPRINT << "valid_worker_slice_shape.y: " << valid_worker_slice_shape.y << ENDL(); + // DPRINT << "valid_worker_slice_shape.x: " << valid_worker_slice_shape.x << ENDL(); + // DPRINT << "global_offset.w: " << global_offset.w << ENDL(); + // DPRINT << "global_offset.z: " << global_offset.z << ENDL(); + // DPRINT << "global_offset.y: " << global_offset.y << ENDL(); + // DPRINT << "global_offset.x: " << global_offset.x << ENDL(); + // DPRINT << "curr_tile_id: " << curr_tile_id << ENDL(); + + uint32_t offset_into_worker_slice = 0; + bool last_page_of_worker = false; + for (uint32_t p = 0; p < command_tensor.worker_pages_per_slice; p += packet_size_in_pages) { + uint32_t n_pages = std::min(packet_size_in_pages, command_tensor.worker_pages_per_slice - p); + ASSERT(command_tensor.worker_start_offset_in_slice.w == 0); + ASSERT(command_tensor.worker_start_offset_in_slice.z == 0); + 
ASSERT(valid_worker_slice_shape.w == 1); + ASSERT(valid_worker_slice_shape.z == 1); + ASSERT(command_tensor.tensor_shape.w == 1); + ASSERT(command_tensor.tensor_shape.z == 1); + ASSERT(command_tensor.tensor_slice_shape.w == 1); + ASSERT(command_tensor.tensor_slice_shape.z == 1); + + // cb_wait_front(cb_id, packet_size_in_pages); + + // DPRINT << "iter "<< p << " curr_tile_id: " << curr_tile_id << ENDL(); + + write_wrapped_chunk( + curr_tile_id, + offset_into_worker_slice, + ttnn::ccl::coord_t(command_tensor.worker_start_offset_in_slice.x, command_tensor.worker_start_offset_in_slice.y), // Offset into tensor slice + ttnn::ccl::coord_t(valid_worker_slice_shape.x, valid_worker_slice_shape.y), + // In tiles for tile layout + ttnn::ccl::coord_t(command_tensor.tensor_shape.x, command_tensor.tensor_shape.y), + ttnn::ccl::coord_t(command_tensor.tensor_slice_shape.x, command_tensor.tensor_slice_shape.y), + cb_id, + tensor_addrgen, + n_pages, + page_size, + last_page_of_worker + ); + // // build headers and write to the output cb + // cb_pop_front(cb_id, packet_size_in_pages); + + } + } + } + //////////////////////////////////////////////////////////////////////////////////// +}