Skip to content

Commit

Permalink
#15060: host side code support for new all gather
Browse files Browse the repository at this point in the history
  • Loading branch information
caixunshiren committed Nov 14, 2024
1 parent 7b6c0a6 commit 3ea9beb
Show file tree
Hide file tree
Showing 12 changed files with 1,776 additions and 2 deletions.
232 changes: 232 additions & 0 deletions tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
from loguru import logger
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from models.utility_functions import skip_for_grayskull


def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
if layout == ttnn.ROW_MAJOR_LAYOUT and input_dtype == ttnn.bfloat8_b:
return True, "Invalid combination"

if input_shape[dim] % num_devices != 0 or (dim == 3 and input_shape[dim] // num_devices % 32 != 0):
return True, "Unsupported test case"

## Check that we can readback results
fast_dispatch_page_size_limit = 55 * 1024
elem_size = 2 if input_dtype == ttnn.bfloat16 else 1
if layout == ttnn.ROW_MAJOR_LAYOUT and (input_shape[dim] * elem_size) > fast_dispatch_page_size_limit:
# Fast dispatch currently can't breakup readback of large pages into multiple smaller pages and is
# limited to ~55K pages.
return True, "Fast dispatch can't support reading back this page size in one shot"

# Check that we can fit in L1 (if L1 config)
tensor_size_bytes = elem_size
for i in input_shape:
tensor_size_bytes *= i
num_l1_banks = 64
if mem_config.buffer_type == ttnn.BufferType.L1 and tensor_size_bytes > num_l1_banks * 50 * 1024:
return True, "L1 buffer can't support large tensor sizes"

# Check that each chip has a non-zero amount of data available
min_sized_chunks_on_dim = input_shape[dim]
if dim == 3:
min_sized_chunks_on_dim //= 32
if dim == 2:
if layout == ttnn.TILE_LAYOUT:
min_sized_chunks_on_dim //= 32
if min_sized_chunks_on_dim < num_devices:
return (
True,
f"Input shape {input_shape} incompatible with {num_devices} on dim {dim} because some chips will have no tensor",
)

if input_shape == [8, 8, 256, 384] and dim == 1 and layout == ttnn.TILE_LAYOUT and input_dtype == ttnn.bfloat8_b:
return True, "Known failure"

return False, ""


def run_all_gather_impl(
mesh_device,
num_devices,
output_shape,
dim,
num_links,
input_dtype,
layout,
mem_config,
use_program_cache,
function_level_defaults,
all_gather_topology,
num_iters=1,
enable_async=False,
trace_mode=False,
rand_tensor=True,
):
if num_iters < 1:
pytest.fail("num_iters must be >= 1")
# Use Async mode based on test input config
mesh_device.enable_async(enable_async)

if enable_async:
logger.info(f"Using Async Mode for All Gather Op Dispatch")

logger.info(f"Output shape: {output_shape}")
logger.info(f"dim: {dim}")

if rand_tensor:
output_tensor = torch.rand(output_shape).bfloat16()
else:
output_tensor = torch.zeros(output_shape)
tile_id = 1
for w in range(output_shape[0]):
for z in range(output_shape[1]):
for y in range(0, output_shape[2], 32):
for x in range(0, output_shape[3], 32):
output_tensor[w, z, y : y + 32, x : x + 32] = tile_id
tile_id += 1

input_tensors = torch.chunk(output_tensor, num_devices, dim)
tt_input_tensors = []
for i, t in enumerate(input_tensors):
tt_input_tensors.append(ttnn.Tensor(t, input_dtype).to(layout).to(mesh_device.get_devices()[i], mem_config))
logger.info(f"using device {mesh_device.get_devices()[i].id()}")

input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors)
if trace_mode:
tt_out_tensor = run_with_trace(
mesh_device,
all_gather_topology,
input_tensor_mesh,
dim,
num_links,
mem_config,
)
else:
for i in range(num_iters):
tt_out_tensor = ttnn.all_gather(
input_tensor_mesh, dim, num_links=num_links, memory_config=mem_config, topology=all_gather_topology
)

for d in mesh_device.get_devices():
ttnn.synchronize_device(d)
logger.info(f"Done iteration {i}")

for i, t in enumerate(ttnn.get_device_tensors(tt_out_tensor)):
tt_output_tensor = t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
logger.info(f"Checking for device {t.device().id()}")

# breakpoint()

# for current non-edm version of all gather only:
chunked_output_tensor = torch.chunk(tt_output_tensor, num_devices, dim)

if input_dtype == ttnn.bfloat16:
# eq, output = comp_equal(tt_output_tensor, output_tensor)
eq, output = comp_equal(chunked_output_tensor[i], input_tensors[i])
else:
# eq, output = comp_pcc(tt_output_tensor, output_tensor)
eq, output = comp_pcc(chunked_output_tensor[i], input_tensors[i])
if not eq:
logger.error(f"output mismatch for tensor {i}")
assert eq, f"{i} FAILED: {output}"


# Enumerate the post-commit cases explicitly
@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize(
"num_devices, num_links, output_shape, dim, layout",
[
# Known errors
# - double/tripple buffers in cb not working
# (4, 2, [4, 1, 256, 32], 0, ttnn.TILE_LAYOUT), # failed: device not connected # https://github.com/tenstorrent/tt-metal/issues/9686
(2, 1, [1, 1, 32, 256], 3, ttnn.TILE_LAYOUT),
(2, 1, [1, 1, 64, 256], 2, ttnn.TILE_LAYOUT),
(8, 1, [8, 1, 256, 32], 0, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686
(8, 1, [1, 8, 256, 32], 1, ttnn.TILE_LAYOUT),
(2, 2, [1, 1, 32, 256], 3, ttnn.TILE_LAYOUT),
(2, 2, [1, 1, 64, 256], 2, ttnn.TILE_LAYOUT),
(2, 2, [1, 1, 32, 320], 3, ttnn.TILE_LAYOUT),
(2, 1, [1, 1, 32, 320], 3, ttnn.TILE_LAYOUT),
# (4, 3, [1, 1, 32, 16384 * 4], 3, ttnn.TILE_LAYOUT), # failed: device not connected
(8, 4, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT),
(8, 3, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT),
(8, 2, [1, 8, 32, 2304], 1, ttnn.TILE_LAYOUT),
# untested cases
# (4, 2, [1, 1, 32, 32768], 3, ttnn.TILE_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686
# (4, 2, [4, 1, 256, 32], 0, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686
# (8, 1, [8, 1, 256, 32], 0, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686
# (8, 1, [1, 1, 32, 16384], 3, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686
# (4, 2, [1, 1, 32, 32768], 3, ttnn.ROW_MAJOR_LAYOUT), # https://github.com/tenstorrent/tt-metal/issues/9686
],
)
@pytest.mark.parametrize(
"input_dtype",
[
ttnn.bfloat16,
# ttnn.bfloat8_b, # https://github.com/tenstorrent/tt-metal/issues/9686
],
)
@pytest.mark.parametrize(
"mem_config",
[
ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM), # https://github.com/tenstorrent/tt-metal/issues/9686
# ttnn.MemoryConfig(buffer_type=ttnn.BufferType.L1),
],
)
@pytest.mark.parametrize("num_iters", [1]) # restore to 500: https://github.com/tenstorrent/tt-metal/issues/9686
@pytest.mark.parametrize("enable_async", [True])
def test_all_gather(
t3k_mesh_device,
# pcie_mesh_device,
num_devices,
output_shape,
dim,
num_links,
input_dtype,
layout,
mem_config,
num_iters,
use_program_cache,
function_level_defaults,
enable_async,
):
run_all_gather_impl(
t3k_mesh_device,
num_devices,
output_shape,
dim,
num_links,
input_dtype,
layout,
mem_config,
use_program_cache,
function_level_defaults,
all_gather_topology=ttnn.Topology.Ring,
num_iters=num_iters,
enable_async=enable_async,
rand_tensor=True,
)

run_all_gather_impl(
t3k_mesh_device,
num_devices,
output_shape,
dim,
num_links,
input_dtype,
layout,
mem_config,
use_program_cache,
function_level_defaults,
all_gather_topology=ttnn.Topology.Ring,
num_iters=num_iters,
enable_async=enable_async,
rand_tensor=False,
)
3 changes: 3 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core_new.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp
Expand All @@ -26,6 +27,8 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/common/uops/ccl_command.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/device/host/reduce_scatter_full_worker_grid.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ std::vector<Tensor> AllGather::create_output_tensors(const std::vector<Tensor> &
}

operation::ProgramWithCallbacks AllGather::create_program(const std::vector<Tensor> & input_tensors, std::vector<Tensor> &output_tensors) const {
return all_gather_multi_core_with_workers(input_tensors[0], output_tensors[0], this->dim, this->num_links, this->ring_size, this->ring_index, this->receiver_device_id, this->sender_device_id, this->topology, this->user_defined_num_workers, this->user_defined_num_buffers_per_channel);
return all_gather_multi_core_with_workers_new(input_tensors[0], output_tensors[0], this->dim, this->num_links, this->ring_size, this->ring_index, this->receiver_device_id, this->sender_device_id, this->topology, this->user_defined_num_workers, this->user_defined_num_buffers_per_channel);
}

namespace operations {
Expand Down
12 changes: 12 additions & 0 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,18 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
const std::optional<size_t> user_defined_num_buffers_per_channel,
std::optional<experimental::ccl::AllGatherFusedOpSignaler>& fused_op_signaler,
const CoreCoord core_grid_offset = CoreCoord(0, 0));
operation::ProgramWithCallbacks all_gather_multi_core_with_workers_new(
const Tensor& input_tensor,
Tensor& output_tensor,
const uint32_t dim,
const uint32_t num_links,
const uint32_t ring_size,
const uint32_t ring_index,
const std::optional<chip_id_t> receiver_device_id,
const std::optional<chip_id_t> sender_device_id,
ccl::Topology topology,
const std::optional<size_t> user_defined_num_workers,
const std::optional<size_t> user_defined_num_buffers_per_channel);



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp"
#include "debug/dprint.h"

using ttnn::ccl::ShardType;
using ttnn::ccl::UNINITIALIZED_VALUE_U16;
Expand Down Expand Up @@ -535,7 +536,7 @@ template <typename AddrGen>
FORCE_INLINE void write_wrapped_chunk(
uint32_t& curr_page_idx,
uint32_t& offset_into_worker_slice,
ttnn::ccl::coord_t& offset_worker_slice,
const ttnn::ccl::coord_t& offset_worker_slice,
const ttnn::ccl::coord_t& worker_slice_shape,

// In tiles for tile layout
Expand Down
Loading

0 comments on commit 3ea9beb

Please sign in to comment.