From a828f89ef7b908cee995be7e9da1f163bb427364 Mon Sep 17 00:00:00 2001 From: Abhinav Sarje Date: Wed, 16 Oct 2024 00:00:47 +0000 Subject: [PATCH] #9889: Remove oldversions of sliding window code and pybind cpp functions instead --- .../functional_unet/tt/unet_shallow_ttnn.py | 65 ++- ...configs_for_untilize_with_halo_and_conv.py | 237 -------- .../unit_tests/operations/test_maxpool2d.py | 20 +- ttnn/CMakeLists.txt | 1 + ttnn/cpp/pybind11/operations/__init__.hpp | 4 + .../operations/conv/conv2d/conv2d_pybind.cpp | 57 ++ .../sliding_window/sliding_window_pybind.cpp | 25 + .../sliding_window/sliding_window_pybind.hpp | 11 + ttnn/ttnn/__init__.py | 2 +- ...dow_op_config_generation_and_validation.py | 159 ----- .../conv/sliding_window_op_utils.py | 181 ------ .../operations/conv/tt_py_composite_conv.py | 345 ----------- ttnn/ttnn/operations/conv/tt_py_op.py | 25 - .../conv/tt_py_untilize_with_halo.py | 225 -------- ...h_halo_config_generation_and_validation.py | 543 ------------------ ttnn/ttnn/operations/conv2d.py | 218 +------ 16 files changed, 152 insertions(+), 1966 deletions(-) delete mode 100644 tests/tt_eager/python_api_testing/unit_testing/misc/test_configs_for_untilize_with_halo_and_conv.py create mode 100644 ttnn/cpp/ttnn/operations/sliding_window/sliding_window_pybind.cpp create mode 100644 ttnn/cpp/ttnn/operations/sliding_window/sliding_window_pybind.hpp delete mode 100644 ttnn/ttnn/operations/conv/sliding_window_op_config_generation_and_validation.py delete mode 100644 ttnn/ttnn/operations/conv/sliding_window_op_utils.py delete mode 100644 ttnn/ttnn/operations/conv/tt_py_composite_conv.py delete mode 100644 ttnn/ttnn/operations/conv/tt_py_op.py delete mode 100644 ttnn/ttnn/operations/conv/tt_py_untilize_with_halo.py delete mode 100644 ttnn/ttnn/operations/conv/untilize_with_halo_config_generation_and_validation.py diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py index 8f672f6d259..45b297d0054 100644 --- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py +++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py @@ -8,8 +8,6 @@ from typing import List -from ttnn.operations.conv2d import determine_parallel_config, create_sharded_memory_config_from_parallel_config - from models.utility_functions import nearest_32 from ttnn.model_preprocessing import fold_batch_norm2d_into_conv2d, ParameterDict @@ -261,23 +259,26 @@ def __init__( self.should_reshard = should_reshard if self.should_reshard: - parallel_config = determine_parallel_config( - is_1d_systolic=True, + parallel_config = ttnn._ttnn.operations.conv2d.determine_parallel_config( + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, batch_size=self.conv1.batch_size, input_channels=self.conv1.in_channels, output_height=self.conv2.input_height, output_width=self.conv2.input_width, output_channels=self.conv1.out_channels, device=device, + block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, is_out_tiled=True, ) - self.sharded_memory_config = create_sharded_memory_config_from_parallel_config( - tensor_shape=[ - 1, - 1, - self.conv1.input_width * self.conv1.input_height * self.conv1.batch_size, - nearest_32(self.conv1.in_channels), - ], + self.sharded_memory_config = ttnn._ttnn.operations.conv2d.create_sharded_memory_config_from_parallel_config( + tensor_shape=ttnn.Shape( + [ + 1, + 1, + self.conv1.input_width * self.conv1.input_height * self.conv1.batch_size, + nearest_32(self.conv1.in_channels), + ] + ), 
parallel_config=parallel_config, tile_size=32 if conv1.dtype == ttnn.bfloat8_b else 1, ) @@ -312,23 +313,26 @@ def __init__( self.should_reshard = should_reshard if self.should_reshard: - parallel_config = determine_parallel_config( - is_1d_systolic=True, + parallel_config = ttnn._ttnn.operations.conv2d.determine_parallel_config( + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, batch_size=self.conv1.batch_size, input_channels=self.conv1.in_channels, output_height=self.conv2.input_height, output_width=self.conv2.input_width, output_channels=self.conv1.out_channels, device=device, + block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, is_out_tiled=True, ) - self.sharded_memory_config = create_sharded_memory_config_from_parallel_config( - tensor_shape=[ - 1, - 1, - self.conv1.input_width * self.conv1.input_height * self.conv1.batch_size, - self.conv1.in_channels, - ], + self.sharded_memory_config = ttnn._ttnn.operations.conv2d.create_sharded_memory_config_from_parallel_config( + tensor_shape=ttnn.Shape( + [ + 1, + 1, + self.conv1.input_width * self.conv1.input_height * self.conv1.batch_size, + self.conv1.in_channels, + ] + ), parallel_config=parallel_config, tile_size=32 if conv1.dtype == ttnn.bfloat8_b else 1, ) @@ -437,23 +441,26 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: self.bnc2 = UNetConv2D( parameters.bnc_2, parameters.bnb_2, device, cache=self.conv_cache, mesh_mapper=mesh_mapper ) - bnc_parallel_config = determine_parallel_config( - is_1d_systolic=True, + bnc_parallel_config = ttnn._ttnn.operations.conv2d.determine_parallel_config( + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, batch_size=self.bnc.batch_size, input_channels=self.bnc.in_channels, output_height=self.bnc2.input_height, output_width=self.bnc2.input_width, output_channels=self.bnc.out_channels, device=device, + block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, is_out_tiled=True, ) - self.bnc_sharded_memory_config = create_sharded_memory_config_from_parallel_config( - tensor_shape=[ - 1, - 1, - self.bnc.input_width * self.bnc.input_height * self.bnc.batch_size, - self.bnc.in_channels, - ], + self.bnc_sharded_memory_config = ttnn._ttnn.operations.conv2d.create_sharded_memory_config_from_parallel_config( + tensor_shape=ttnn.Shape( + [ + 1, + 1, + self.bnc.input_width * self.bnc.input_height * self.bnc.batch_size, + self.bnc.in_channels, + ] + ), parallel_config=bnc_parallel_config, tile_size=(32 if self.bnc.conv_config.dtype == ttnn.bfloat8_b else 1), ) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_configs_for_untilize_with_halo_and_conv.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_configs_for_untilize_with_halo_and_conv.py deleted file mode 100644 index 07d75c32528..00000000000 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_configs_for_untilize_with_halo_and_conv.py +++ /dev/null @@ -1,237 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import pytest -import torch -import numpy -from loguru import logger - -from ttnn.operations.conv.untilize_with_halo_config_generation_and_validation import ( - trace_conv_to_generate_data_top_left_indices_and_pad_metadata, - validate_input_padded_tensor_and_data_top_left_indices_and_pad_metadata, - decompose_conv_into_shards_and_generate_tensor_metadata, - construct_utwh_output_shards, - validate_utwh_output_shards_and_req_conv_input_shard_start_end, - validate_tensor_metadata, - generate_untilize_with_halo_kernel_configs, - validate_untilize_with_halo_kernel_configs, -) -from ttnn.operations.conv.sliding_window_op_config_generation_and_validation import ( - generate_sliding_window_op_sharded_input_top_left_indices, - validate_conv_sharded_input_top_left_indices, - validate_max_pool_sharded_input_top_left_indices, -) -from ttnn.operations.conv.tt_py_untilize_with_halo import TTPyUntilizeWithHalo -from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_allclose_and_pcc, comp_pcc -from tt_lib.utils import _nearest_y - - -# conv params - output_channels, input_channels, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w, dilation, groups -@pytest.mark.parametrize( - "conv_params, batch_size, input_chw_shape, num_cores, test_max_pool", - ( - # ((1, 1, 2, 2, 1, 1, 0, 0, 1, 1), 8, (1, 8, 8), 1, False), - # ((1, 1, 2, 2, 1, 1, 0, 0, 1, 1), 8, (1, 8, 8), 2, False), - # ((1, 1, 2, 2, 1, 1, 1, 1, 1, 1), 8, (1, 8, 8), 1, False), - # ((1, 1, 2, 2, 1, 1, 1, 1, 1, 1), 8, (1, 8, 8), 2, False), - # resnet50 s1 convs - ((2, 2, 4, 4, 1, 1, 0, 0, 1, 1), 8, (2, 115, 115), 98, False), # first conv b8 - 98 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 8, (1, 56, 56), 98, False), # layer1 b8 - 98 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 8, (1, 56, 56), 98, False), # layer1 b8 - 98 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 8, (1, 28, 28), 98, False), # layer2 b8 - 98 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 8, (1, 14, 14), 10, False), # layer3 b8 - 10 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 8, (1, 7, 7), 7, False), # layer4 b8 - 7 cores for height slicing - ((1, 1, 4, 4, 1, 1, 0, 0, 1, 1), 16, (1, 115, 115), 98, False), # first conv b16 - 98 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 16, (1, 56, 56), 98, False), # layer1 b16 - 98 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 16, (1, 28, 28), 98, False), # layer2 b16 - 98 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 16, (1, 14, 14), 11, False), # layer3 b16 - 11 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 16, (1, 7, 7), 9, False), # layer4 b16 - 9 cores for height slicing - ((1, 1, 4, 4, 1, 1, 0, 0, 1, 1), 20, (1, 115, 115), 98, False), # first conv b16 - 98 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 20, (1, 56, 56), 98, False), # layer1 b20 - 98 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 20, (1, 28, 28), 98, False), # layer2 b20 - 98 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 20, (1, 14, 14), 12, False), # layer3 b20 - 12 cores for height slicing - ((1, 1, 3, 3, 1, 1, 1, 1, 1, 1), 20, (1, 7, 7), 11, False), # layer4 b20 - 11 cores for height slicing - # resnet50 s2 convs - ((1, 1, 3, 3, 2, 2, 1, 1, 1, 1), 8, (1, 56, 56), 98, False), # layer2 b8 - 98 cores for height slicing - ((1, 1, 3, 3, 2, 2, 1, 1, 1, 1), 8, (1, 28, 28), 10, False), # layer3 b8 - 10 cores for height slicing - 
((1, 1, 3, 3, 2, 2, 1, 1, 1, 1), 8, (1, 14, 14), 7, False), # layer4 b8 - 7 cores for height slicing - ((1, 1, 3, 3, 2, 2, 1, 1, 1, 1), 16, (1, 56, 56), 98, False), # layer2 b16 - 98 cores for height slicing - ((1, 1, 3, 3, 2, 2, 1, 1, 1, 1), 16, (1, 28, 28), 11, False), # layer3 b16 - 11 cores for height slicing - ((1, 1, 3, 3, 2, 2, 1, 1, 1, 1), 16, (1, 14, 14), 9, False), # layer3 b16 - 9 cores for height slicing - ((1, 1, 3, 3, 2, 2, 1, 1, 1, 1), 20, (1, 56, 56), 98, False), # layer2 b20 - 98 cores for height slicing - ((1, 1, 3, 3, 2, 2, 1, 1, 1, 1), 20, (1, 28, 28), 12, False), # layer3 b20 - 12 cores for height slicing - ((1, 1, 3, 3, 2, 2, 1, 1, 1, 1), 20, (1, 14, 14), 11, False), # layer3 b20 - 11 cores for height slicing - # resnet50 maxpool - ((2, 2, 3, 3, 2, 2, 1, 1, 1, 1), 8, (2, 112, 112), 98, True), - ((1, 1, 3, 3, 2, 2, 1, 1, 1, 1), 16, (1, 112, 112), 98, True), - ((1, 1, 3, 3, 2, 2, 1, 1, 1, 1), 20, (1, 112, 112), 98, True), - ), -) -def test_generate_all_configs_and_references( - device, conv_params, batch_size, input_chw_shape, num_cores, test_max_pool -): - assert len(conv_params) == 10 - output_channels, input_channels, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w, dilation, groups = [ - conv_params[i] for i in range(10) - ] - - torch.set_printoptions(threshold=10000, edgeitems=50, linewidth=400) - - # Construct conv inputs and filters and run pytorch conv for golden reference - # unpadded raw tensor - input_tensor = [] - assert len(input_chw_shape) == 3 - input_c, input_h, input_w = input_chw_shape - assert input_c == input_channels - input_nchw_shape = [batch_size, input_c, input_h, input_w] - input_volume = numpy.prod(input_nchw_shape) - input_nhw_size = batch_size * input_h * input_w - conv_output_h = ((int)((input_h + (2 * pad_h) - filter_h) / stride_h)) + 1 - conv_output_w = ((int)((input_w + (2 * pad_w) - filter_w) / stride_w)) + 1 - conv_output_nhw_size = batch_size * conv_output_h * conv_output_w - - input_size_to_shard_evenly = _nearest_y(input_nhw_size, num_cores * 32) - untilize_with_halo_input_shard_height = (int)(input_size_to_shard_evenly / num_cores) - output_size_to_shard_evenly = _nearest_y(conv_output_nhw_size, num_cores * 32) - conv_output_shard_height = (int)(output_size_to_shard_evenly / num_cores) - - logger.info(f"untilize with halo input shard height={untilize_with_halo_input_shard_height}") - logger.info(f"conv_output_shard_height={conv_output_shard_height}") - - # Initialize tensor with data - # Inserting sequential integer data - for val in range(1, input_volume + 1): - input_tensor.append(val) - input_pyt_tensor = torch.tensor(input_tensor) - # input_pyt_tensor = torch.rand(input_volume, dtype=torch.bfloat16) - input_pyt_tensor = torch.reshape(input_pyt_tensor, input_nchw_shape) - # Initializing filters with all 1s - filter_pyt_tensor = torch.full((output_channels, input_channels, filter_h, filter_w), 1) - # filter_pyt_tensor = torch.rand((output_channels, input_channels, filter_h, filter_w)) - # run conv pytorch - out_golden_pyt_tensor = torch.nn.functional.conv2d( - input_pyt_tensor, filter_pyt_tensor, stride=(stride_h, stride_w), padding=(pad_h, pad_w) - ) - - input_padded_width = input_w + 2 * pad_w - input_padded_height = input_h + 2 * pad_h - # Generate following configs by tracing conv - - logger.info("Trace conv and generate following configs - pad_metadata and data_top_left_indices.") - pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata( - conv_params, input_nchw_shape - ) - 
- logger.info("Generate input tensor") - input_padded_pyt_tensor = torch.nn.functional.pad(input_pyt_tensor, (pad_w, pad_w, pad_h, pad_h), value=0) - input_padded_pyt_tensor = input_padded_pyt_tensor.permute(0, 2, 3, 1) - input_padded_tensor = input_padded_pyt_tensor.reshape(-1).tolist() - # run trace conv reference to validate pad_metadata and data_top_left_indices - logger.info("Validate pad_metadata and data_top_left_indices.") - - validate_input_padded_tensor_and_data_top_left_indices_and_pad_metadata( - input_padded_tensor, - input_nchw_shape, - pad_h, - pad_w, - filter_pyt_tensor, - out_golden_pyt_tensor, - pad_metadata, - data_top_left_indices, - ) - - # Generate more configs - - logger.info( - "Decompose conv into shards and generate the required conv input shard start/end stick indices and tensor metadata." - ) - req_conv_input_shard_start_end, tensor_metadata = decompose_conv_into_shards_and_generate_tensor_metadata( - data_top_left_indices, - pad_metadata, - input_padded_width, - conv_output_shard_height, - untilize_with_halo_input_shard_height, - num_cores, - filter_h, - filter_w, - ) - logger.info("Validate required conv input shard start/end stick indices") - input_nchw_padded_shape = [batch_size, input_c, input_padded_height, input_padded_width] - golden_untilize_with_halo_output_shards = construct_utwh_output_shards( - input_padded_tensor, input_nchw_padded_shape, req_conv_input_shard_start_end - ) - - validate_utwh_output_shards_and_req_conv_input_shard_start_end( - input_nchw_padded_shape, - filter_pyt_tensor, - out_golden_pyt_tensor, - data_top_left_indices, - golden_untilize_with_halo_output_shards, - req_conv_input_shard_start_end, - ) - - logger.info("Validate tensor metadata") - untilize_with_halo_input_shards = validate_tensor_metadata( - input_tensor, - input_nchw_shape, - untilize_with_halo_input_shard_height, - tensor_metadata, - req_conv_input_shard_start_end, - golden_untilize_with_halo_output_shards, - ) - - # Generate and validate the final untilize with halo configs here - logger.info("Generate untilize with halo kernel configs") - ( - padding_config, - local_config, - remote_config, - max_out_nsticks_per_core, - ) = generate_untilize_with_halo_kernel_configs(tensor_metadata, req_conv_input_shard_start_end) - - logger.info("Validate reshards") - validate_untilize_with_halo_kernel_configs( - golden_untilize_with_halo_output_shards, - untilize_with_halo_input_shards, - req_conv_input_shard_start_end, - padding_config, - local_config, - remote_config, - max_out_nsticks_per_core, - ) - - # Generate sliding window op config - - logger.info("Generate sliding window op configs - top left positioned indices for input shards") - sliding_window_op_sharded_input_top_left_indices = generate_sliding_window_op_sharded_input_top_left_indices( - data_top_left_indices, req_conv_input_shard_start_end - ) - - if not test_max_pool: - logger.info("Validate conv_sharded_input_top_left_indices") - validate_conv_sharded_input_top_left_indices( - golden_untilize_with_halo_output_shards, - input_padded_width, - filter_pyt_tensor, - out_golden_pyt_tensor, - sliding_window_op_sharded_input_top_left_indices, - ) - else: - logger.info("Validate pool_sharded_input_top_left_indices") - # run max pool pytorch to get golden output - assert filter_h == filter_w and stride_h == stride_w and pad_h == pad_w - pool_out_golden_pyt_tensor = torch.nn.MaxPool2d( - filter_h, - stride=stride_h, - padding=pad_h, - dilation=1, - return_indices=False, - ceil_mode=False, - )(input_pyt_tensor.float()) - - 
validate_max_pool_sharded_input_top_left_indices( - golden_untilize_with_halo_output_shards, - input_padded_width, - filter_h, - filter_w, - pool_out_golden_pyt_tensor, - sliding_window_op_sharded_input_top_left_indices, - ) diff --git a/tests/ttnn/unit_tests/operations/test_maxpool2d.py b/tests/ttnn/unit_tests/operations/test_maxpool2d.py index 1abc25a2df1..192d47e2a78 100644 --- a/tests/ttnn/unit_tests/operations/test_maxpool2d.py +++ b/tests/ttnn/unit_tests/operations/test_maxpool2d.py @@ -12,7 +12,6 @@ from tests.ttnn.utils_for_testing import assert_with_pcc import ttnn -from ttnn.operations.conv2d import determine_parallel_config, create_sharded_memory_config_from_parallel_config def run_max_pool( @@ -82,18 +81,19 @@ def run_max_pool( ttact_device = ttnn.to_device(ttact, device) if pre_shard: - parallel_config = determine_parallel_config( - is_1d_systolic=True, + parallel_config = ttnn._ttnn.operations.conv2d.determine_parallel_config( + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, batch_size=in_n, input_channels=in_c, output_height=out_h, output_width=out_w, output_channels=in_c, device=device, + block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, is_out_tiled=False, ) - sharded_memory_config = create_sharded_memory_config_from_parallel_config( - tensor_shape=act_shape, + sharded_memory_config = ttnn._ttnn.operations.conv2d.create_sharded_memory_config_from_parallel_config( + tensor_shape=ttact_device.shape, parallel_config=parallel_config, tile_size=32 if dtype == ttnn.bfloat8_b else 1, ) @@ -470,19 +470,19 @@ def test_pool_core_nondivis( ttact_device = ttnn.to_device(ttact, device) if pre_shard: - parallel_config = determine_parallel_config( - is_1d_systolic=True, + parallel_config = ttnn._ttnn.operations.conv2d.determine_parallel_config( + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, batch_size=in_n, input_channels=in_c, output_height=out_h, output_width=out_w, output_channels=in_c, device=device, + block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR, is_out_tiled=True, - config_override=config_override, ) - sharded_memory_config = create_sharded_memory_config_from_parallel_config( - tensor_shape=act_shape, + sharded_memory_config = ttnn._ttnn.operations.conv2d.create_sharded_memory_config_from_parallel_config( + tensor_shape=ttact_device.shape, parallel_config=parallel_config, tile_size=32 if dtype == ttnn.bfloat8_b else 1, ) diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index aaafb8300d9..1ee883d7d00 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -320,6 +320,7 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/halo/halo.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/sliding_window.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/sliding_window_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sliding_window/reference_sliding_window.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/transformer_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/transformer/attention_softmax/attention_softmax.cpp diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index 187c1af31ff..8efabee9451 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -13,6 +13,7 @@ #include "ttnn/operations/ccl/all_gather/all_gather_pybind.hpp" #include 
"ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.hpp" #include "ttnn/operations/conv/conv2d/conv2d_pybind.hpp" +#include "ttnn/operations/sliding_window/sliding_window_pybind.hpp" #include "ttnn/operations/data_movement/data_movement_pybind.hpp" #include "ttnn/operations/eltwise/binary/binary_pybind.hpp" #include "ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp" @@ -104,6 +105,9 @@ void py_module(py::module& module) { auto m_data_movement = module.def_submodule("data_movement", "data_movement operations"); data_movement::py_module(m_data_movement); + auto m_sliding_window = module.def_submodule("sliding_window", "sliding_window operations"); + sliding_window::py_bind_sliding_window(m_sliding_window); + auto m_conv2d = module.def_submodule("conv2d", "conv2d operation"); conv::conv2d::py_bind_conv2d(m_conv2d); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp index de9ad49a9eb..2ac21017726 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp @@ -7,6 +7,7 @@ #include "ttnn/cpp/pybind11/decorators.hpp" #include "conv2d_pybind.hpp" +#include "ttnn/cpp/ttnn/operations/sliding_window/sliding_window_pybind.hpp" #include "conv2d.hpp" namespace py = pybind11; @@ -196,6 +197,62 @@ void py_bind_conv2d(py::module& module) { py::arg("num_groups"), py::arg("output_dtype").noconvert() = std::nullopt); + module.def( + "determine_parallel_config", + [](const ttnn::TensorMemoryLayout& shard_layout, + uint32_t batch_size, + uint32_t input_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t output_channels, + ttnn::Device* device, + ShardOrientation block_shard_orientation, + bool is_out_tiled) -> ttnn::operations::sliding_window::ParallelConfig { + return ttnn::operations::conv::conv2d::determine_parallel_config( + shard_layout, batch_size, input_channels, output_height, output_width, output_channels, device, block_shard_orientation, is_out_tiled); + }, + py::arg("shard_layout"), + py::arg("batch_size"), + py::arg("input_channels"), + py::arg("output_height"), + py::arg("output_width"), + py::arg("output_channels"), + py::arg("device"), + py::arg("block_shard_orientation"), + py::arg("is_out_tiled") = true); + + module.def( + "determine_parallel_config", + [](const ttnn::TensorMemoryLayout& shard_layout, + uint32_t batch_size, + uint32_t input_channels, + uint32_t output_height, + uint32_t output_width, + uint32_t output_channels, + ttnn::MeshDevice* device, + ShardOrientation block_shard_orientation, + bool is_out_tiled) -> ttnn::operations::sliding_window::ParallelConfig { + return ttnn::operations::conv::conv2d::determine_parallel_config( + shard_layout, batch_size, input_channels, output_height, output_width, output_channels, device, block_shard_orientation, is_out_tiled); + }, + py::arg("shard_layout"), + py::arg("batch_size"), + py::arg("input_channels"), + py::arg("output_height"), + py::arg("output_width"), + py::arg("output_channels"), + py::arg("device"), + py::arg("block_shard_orientation"), + py::arg("is_out_tiled") = true); + + module.def( + "create_sharded_memory_config_from_parallel_config", + &ttnn::operations::conv::conv2d::create_sharded_memory_config_from_parallel_config, + py::arg("tensor_shape"), + py::arg("parallel_config"), + py::arg("tile_size")); + + auto py_conv_config = py::class_(module, "Conv2dConfig"); py_conv_config.def( py::init, std::optional, bool, Layout, bool, bool, bool>(), 
diff --git a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window_pybind.cpp b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window_pybind.cpp new file mode 100644 index 00000000000..3206f2b6483 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window_pybind.cpp @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/pybind11/decorators.hpp" +#include "sliding_window.hpp" + +namespace py = pybind11; +namespace ttnn::operations::sliding_window { + +void py_bind_sliding_window(py::module& module) { + py::class_(module, "ParallelConfig") + .def( + py::init(), + py::kw_only(), + py::arg("grid"), + py::arg("shard_scheme"), + py::arg("shard_orientation") + ) + .def_readwrite("grid", &ParallelConfig::grid) + .def_readwrite("shard_scheme", &ParallelConfig::shard_scheme) + .def_readwrite("shard_orientation", &ParallelConfig::shard_orientation); +} + +} // namespace ttnn::operations::sliding_window diff --git a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window_pybind.hpp b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window_pybind.hpp new file mode 100644 index 00000000000..5afcc8c366d --- /dev/null +++ b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window_pybind.hpp @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" + +namespace ttnn::operations::sliding_window { + void py_bind_sliding_window(pybind11::module& module); +} diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index a0e3bf481fa..1ef02bb3f08 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -295,7 +295,7 @@ def prelu(*args, **kwargs): # Alias for leaky_relu. TODO(#8544): implement PReL Topology, ) -from ttnn.operations.conv2d import Conv2dConfig, get_conv_output_dim, get_conv_padded_input_shape_and_mem_config +from ttnn.operations.conv2d import Conv2dConfig, get_conv_padded_input_shape_and_mem_config, get_conv_output_dim from ttnn.operations.pool import avg_pool2d from ttnn.operations.conv1d import Conv1d, Conv1dConfig diff --git a/ttnn/ttnn/operations/conv/sliding_window_op_config_generation_and_validation.py b/ttnn/ttnn/operations/conv/sliding_window_op_config_generation_and_validation.py deleted file mode 100644 index 2cd4976b478..00000000000 --- a/ttnn/ttnn/operations/conv/sliding_window_op_config_generation_and_validation.py +++ /dev/null @@ -1,159 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import torch -import numpy - -from tt_lib._internal.comparison_funcs import comp_equal - - -def generate_sliding_window_op_sharded_input_top_left_indices( - data_top_left_indices, conv_shard_start_end, pad_tile=False, pad_last_core=False -): - # data_top_left_indices point to the global input tensor (padded) - # conv_shard_start_end has the start and end index (inclusive) for each shard in the global input tensor - # generate local indices (top left position in the sliding window) in the conv sharded input - conv_sharded_input_top_left_indices = [] - for item in conv_shard_start_end: - conv_output_shard_start, conv_output_shard_end = item[0] - conv_input_shard_start, conv_input_shard_end = item[1] - # sanity check to see that the first element in the input shard is at the top left position of sliding window - assert conv_output_shard_start < len(data_top_left_indices) - assert conv_input_shard_start == data_top_left_indices[conv_output_shard_start] - local_top_left_indices = [ - index - data_top_left_indices[conv_output_shard_start] - for index in data_top_left_indices[conv_output_shard_start : conv_output_shard_end + 1] - ] - conv_sharded_input_top_left_indices.append(local_top_left_indices) - - if pad_tile: - # Pad indices for last core if not equal to other cores - for i in range(len(conv_sharded_input_top_left_indices)): - tile_size = 32 - extend = len(conv_sharded_input_top_left_indices[i]) % tile_size - if extend != 0: - conv_sharded_input_top_left_indices[i].extend([0] * (tile_size - extend)) - - if pad_last_core: - # Pad indices for last core if not equal to other cores - indices_length_per_core = len(conv_sharded_input_top_left_indices[0]) - conv_sharded_input_top_left_indices[-1].extend( - [0] * (indices_length_per_core - len(conv_sharded_input_top_left_indices[-1])) - ) - - return conv_sharded_input_top_left_indices - - -def validate_conv_sharded_input_top_left_indices( - conv_input_shards, - input_padded_width, - filter_pyt_tensor, - out_golden_pyt_tensor, - conv_sharded_input_top_left_indices, -): - filter_k = filter_pyt_tensor.size()[0] - filter_c = filter_pyt_tensor.size()[1] - filter_h = filter_pyt_tensor.size()[2] - filter_w = filter_pyt_tensor.size()[3] - - output_n = out_golden_pyt_tensor.size()[0] - output_c = out_golden_pyt_tensor.size()[1] - output_h = out_golden_pyt_tensor.size()[2] - output_w = out_golden_pyt_tensor.size()[3] - assert output_c == filter_k - # permute filter tensor to be channels last - kchw --> khwc - filter_pyt_tensor_khwc = torch.permute(filter_pyt_tensor, (0, 2, 3, 1)) - # permute output golden pytorch tensor from nchw to cnhw shape - out_golden_pyt_tensor_cnhw = torch.permute(out_golden_pyt_tensor, (1, 0, 2, 3)) - # reshape cnhw to 2d shape = [c, nhw] - out_golden_pyt_tensor_cnhw = torch.reshape(out_golden_pyt_tensor_cnhw, (output_c, output_n * output_h * output_w)) - conv_output_shard_start = 0 - for shard_idx, local_top_left_indices in enumerate(conv_sharded_input_top_left_indices): - conv_shard_output = [] - assert shard_idx < len(conv_input_shards) - conv_input_shard = conv_input_shards[shard_idx] - output_shard_size = len(local_top_left_indices) - conv_output_shard_end = conv_output_shard_start + output_shard_size - for k in range(filter_k): - for local_output_idx, local_input_top_left_idx in enumerate(local_top_left_indices): - start_window_row_idx = local_input_top_left_idx - conv_input_window = [] - for fh in range(filter_h): - for fw in range(filter_w): - assert start_window_row_idx + fw < 
len(conv_input_shard) - conv_input_window.append(conv_input_shard[start_window_row_idx + fw, :]) - start_window_row_idx += input_padded_width - output_val = numpy.dot( - numpy.array(conv_input_window).flatten(), filter_pyt_tensor_khwc[k, :, :, :].reshape(-1).tolist() - ) - conv_shard_output.append(output_val) - - output_pyt_shard = torch.tensor(conv_shard_output).reshape((filter_k, output_shard_size)) - # compare output shard with golden output pytorch tensor - assert ( - output_pyt_shard.size() - == out_golden_pyt_tensor_cnhw[:, conv_output_shard_start:conv_output_shard_end].size() - ) - # print("out_golden_shard=", out_golden_pyt_tensor.reshape(-1)[conv_output_shard_start : conv_output_shard_end + 1]) - # print("out_shard=", output_pyt_shard) - passing_pcc, output_pcc = comp_equal( - out_golden_pyt_tensor_cnhw[:, conv_output_shard_start:conv_output_shard_end], output_pyt_shard - ) - # print("Passing=", passing_pcc) - # print("Output pcc=", output_pcc) - assert passing_pcc - conv_output_shard_start += output_shard_size - assert conv_output_shard_start == output_n * output_h * output_w - - -def validate_max_pool_sharded_input_top_left_indices( - pool_input_shards, - input_padded_width, - pool_window_h, - pool_window_w, - out_golden_pyt_tensor, - pool_sharded_input_top_left_indices, -): - output_n = out_golden_pyt_tensor.size()[0] - output_c = out_golden_pyt_tensor.size()[1] - output_h = out_golden_pyt_tensor.size()[2] - output_w = out_golden_pyt_tensor.size()[3] - # permute output golden pytorch tensor from nchw to cnhw shape - out_golden_pyt_tensor_cnhw = torch.permute(out_golden_pyt_tensor, (1, 0, 2, 3)) - # reshape cnhw to 2d shape = [c, nhw] - out_golden_pyt_tensor_cnhw = torch.reshape(out_golden_pyt_tensor_cnhw, (output_c, output_n * output_h * output_w)) - pool_output_shard_start = 0 - for shard_idx, local_top_left_indices in enumerate(pool_sharded_input_top_left_indices): - assert shard_idx < len(pool_input_shards) - pool_input_shard = pool_input_shards[shard_idx] - pool_shard_output = [] - output_shard_size = len(local_top_left_indices) - pool_output_shard_end = pool_output_shard_start + output_shard_size - for out_c in range(output_c): - for local_output_idx, local_input_top_left_idx in enumerate(local_top_left_indices): - start_window_row_idx = local_input_top_left_idx - pool_input_window = [] - for fh in range(pool_window_h): - for fw in range(pool_window_w): - assert start_window_row_idx + fw < len(pool_input_shard) - pool_input_window.append(pool_input_shard[start_window_row_idx + fw][out_c]) - start_window_row_idx += input_padded_width - max_val = max(pool_input_window) - pool_shard_output.append(max_val) - output_pyt_shard = torch.tensor(pool_shard_output).reshape((output_c, output_shard_size)) - # compare output shard with golden output pytorch tensor - assert ( - output_pyt_shard.size() - == out_golden_pyt_tensor_cnhw[:, pool_output_shard_start:pool_output_shard_end].size() - ) - # print("out_golden_shard=", out_golden_pyt_tensor.reshape(-1)[conv_output_shard_start : conv_output_shard_end + 1]) - # print("out_shard=", output_pyt_shard) - passing_pcc, output_pcc = comp_equal( - out_golden_pyt_tensor_cnhw[:, pool_output_shard_start:pool_output_shard_end], output_pyt_shard - ) - # print("Passing=", passing_pcc) - # print("Output pcc=", output_pcc) - assert passing_pcc - pool_output_shard_start += output_shard_size - assert pool_output_shard_start == output_n * output_h * output_w diff --git a/ttnn/ttnn/operations/conv/sliding_window_op_utils.py 
b/ttnn/ttnn/operations/conv/sliding_window_op_utils.py deleted file mode 100644 index 75789fe6fe2..00000000000 --- a/ttnn/ttnn/operations/conv/sliding_window_op_utils.py +++ /dev/null @@ -1,181 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -from collections import namedtuple - -import ttnn -from tt_lib.utils import _nearest_y, roundup - -SlidingWindowOpParams = namedtuple( - "SlidingWindowOpParams", "stride_h stride_w pad_h pad_w window_h window_w batch_size input_h input_w" -) -SlidingWindowOpParamsWithParallelConfig = namedtuple( - "SlidingWindowOpParamsWithParallelConfig", - "stride_h stride_w pad_h pad_w window_h window_w batch_size input_h input_w num_cores_w num_cores_h num_cores_nhw act_reshard_num_cores_nhw", - defaults=[0], -) - - -def get_output_dim(input, window, stride=1, pad=0, dilation=1): - return (input + (2 * pad) - dilation * (window - 1) - 1) // stride + 1 - - -def get_hash_from_sliding_window_op_params(sliding_window_op_params: SlidingWindowOpParamsWithParallelConfig): - return f"{sliding_window_op_params.stride_h}_{sliding_window_op_params.stride_w}_{sliding_window_op_params.pad_h}_{sliding_window_op_params.pad_w}_{sliding_window_op_params.window_h}_{sliding_window_op_params.window_w}_{sliding_window_op_params.batch_size}_{sliding_window_op_params.input_h}_{sliding_window_op_params.input_w}_{sliding_window_op_params.num_cores_w}_{sliding_window_op_params.num_cores_h}_{sliding_window_op_params.num_cores_nhw}_{sliding_window_op_params.act_reshard_num_cores_nhw}" - - -def get_sliding_window_op_output_nhw_shape_( - input_n, input_h, input_w, stride_h, stride_w, pad_h, pad_w, window_h, window_w -): - output_h = ((int)((input_h + (2 * pad_h) - window_h) / stride_h)) + 1 - output_w = ((int)((input_w + (2 * pad_w) - window_w) / stride_w)) + 1 - return [input_n, output_h, output_w] - - -def get_sliding_window_op_input_nhw_shape(sliding_window_op_params): - input_n = sliding_window_op_params.batch_size - input_h = sliding_window_op_params.input_h - input_w = sliding_window_op_params.input_w - return [input_n, input_h, input_w] - - -def get_sliding_window_op_output_nhw_shape(sliding_window_op_params): - stride_h = sliding_window_op_params.stride_h - stride_w = sliding_window_op_params.stride_w - pad_h = sliding_window_op_params.pad_h - pad_w = sliding_window_op_params.pad_w - window_h = sliding_window_op_params.window_h - window_w = sliding_window_op_params.window_w - input_n = sliding_window_op_params.batch_size - input_h = sliding_window_op_params.input_h - input_w = sliding_window_op_params.input_w - output_h = ((int)((input_h + (2 * pad_h) - window_h) / stride_h)) + 1 - output_w = ((int)((input_w + (2 * pad_w) - window_w) / stride_w)) + 1 - return [input_n, output_h, output_w] - - -def get_sliding_window_op_output_shard_nhw_size( - num_cores_nhw, input_n, input_h, input_w, stride_h, stride_w, pad_h, pad_w, window_h, window_w, is_out_tiled=True -): - output_nhw_shape = get_sliding_window_op_output_nhw_shape_( - input_n, input_h, input_w, stride_h, stride_w, pad_h, pad_w, window_h, window_w - ) - if is_out_tiled: - output_nhw_size_to_shard_evenly = _nearest_y(np.prod(output_nhw_shape), num_cores_nhw * 32) - else: - output_nhw_size_to_shard_evenly = _nearest_y(np.prod(output_nhw_shape), num_cores_nhw) - output_shard_nhw_size = (int)(output_nhw_size_to_shard_evenly / num_cores_nhw) - return output_shard_nhw_size - - -def calculate_shard_grid(grid_size, num_cores_nhw, transpose_mcast=True): - num_cores_w, 
num_cores_h = grid_size - if transpose_mcast: - shard_layout = ( - ttnn.TensorMemoryLayout.BLOCK_SHARDED - if (num_cores_nhw == num_cores_w and num_cores_h > 1) - else ttnn.TensorMemoryLayout.HEIGHT_SHARDED - ) - else: - shard_layout = ( - ttnn.TensorMemoryLayout.BLOCK_SHARDED - if (num_cores_nhw == num_cores_h and num_cores_w > 1) - else ttnn.TensorMemoryLayout.HEIGHT_SHARDED - ) - - if shard_layout == ttnn.TensorMemoryLayout.BLOCK_SHARDED: - core_range = ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(num_cores_w - 1, num_cores_h - 1)) - shard_grid = ttnn.CoreRangeSet({core_range}) - else: - if num_cores_nhw >= num_cores_w: - num_cores_height_excluding_remainder_last_row = num_cores_nhw // num_cores_w - assert num_cores_h >= num_cores_height_excluding_remainder_last_row - core_range_1 = ttnn.CoreRange( - ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(num_cores_w - 1, num_cores_height_excluding_remainder_last_row - 1), - ) - num_cores_last = num_cores_nhw % num_cores_w - if num_cores_last > 0: - assert num_cores_h == num_cores_height_excluding_remainder_last_row + 1 - core_range_2 = ttnn.CoreRange( - ttnn.CoreCoord(0, num_cores_height_excluding_remainder_last_row), - ttnn.CoreCoord(num_cores_last - 1, num_cores_height_excluding_remainder_last_row), - ) - shard_grid = ttnn.CoreRangeSet({core_range_1, core_range_2}) - else: - assert num_cores_h == num_cores_height_excluding_remainder_last_row - shard_grid = ttnn.CoreRangeSet({core_range_1}) - else: - core_range_1 = ttnn.CoreRange( - ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(num_cores_nhw - 1, 0), - ) - shard_grid = ttnn.CoreRangeSet({core_range_1}) - return shard_grid, shard_layout - - -def calculate_memory_config( - sliding_window_op_params, is_1d_systolic, padded_channels, calc_input=False, tile_size=1, transpose_mcast=True -): - tensor_shape = ( - get_sliding_window_op_input_nhw_shape(sliding_window_op_params) - if calc_input - else get_sliding_window_op_output_nhw_shape(sliding_window_op_params) - ) - tensor_shape.append(padded_channels) - # tensor_shape is [N, H, W, C] - assert len(tensor_shape) == 4 - needs_reshard = calc_input and sliding_window_op_params.act_reshard_num_cores_nhw > 0 - if needs_reshard: - num_cores_nhw = sliding_window_op_params.act_reshard_num_cores_nhw - if is_1d_systolic: - num_cores_w = min(sliding_window_op_params.num_cores_w, num_cores_nhw) - num_cores_h = (num_cores_nhw + num_cores_w - 1) // num_cores_w - else: - if transpose_mcast: - num_cores_w = num_cores_nhw - num_cores_h = sliding_window_op_params.num_cores_h - else: - num_cores_w = sliding_window_op_params.num_cores_h - num_cores_h = num_cores_nhw - else: - num_cores_nhw = sliding_window_op_params.num_cores_nhw - num_cores_w = sliding_window_op_params.num_cores_w - num_cores_h = sliding_window_op_params.num_cores_h - - logical_grid_size = None - grid_size = None - if is_1d_systolic: - logical_grid_size = (num_cores_nhw, 1) - grid_size = (num_cores_w, num_cores_h) - else: - if transpose_mcast: - logical_grid_size = (num_cores_w, num_cores_h) - grid_size = (num_cores_w, num_cores_h) - else: - logical_grid_size = (num_cores_w, num_cores_h) - grid_size = (num_cores_w, num_cores_h) - - shard_grid, shard_layout = calculate_shard_grid(grid_size, num_cores_nhw, transpose_mcast=transpose_mcast) - nhw_shape = tensor_shape[0] * tensor_shape[1] * tensor_shape[2] - nhw_padded = roundup(nhw_shape, num_cores_nhw * tile_size) - # if (nhw_padded - nhw_shape) > 32: - # breakpoint() - # assert (nhw_padded - nhw_shape) <= 32 - nhw_shard = nhw_padded // num_cores_nhw - if 
is_1d_systolic or (not is_1d_systolic and transpose_mcast): - assert padded_channels % logical_grid_size[1] == 0 - shard_shape = [nhw_shard, padded_channels // logical_grid_size[1]] - else: - assert padded_channels % logical_grid_size[0] == 0 - shard_shape = [nhw_shard, padded_channels // logical_grid_size[0]] - shard_orientation = ( - ttnn.ShardOrientation.ROW_MAJOR - if is_1d_systolic - else (ttnn.ShardOrientation.COL_MAJOR if transpose_mcast else ttnn.ShardOrientation.ROW_MAJOR) - ) - shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, shard_orientation, False) - shard_scheme = ttnn.TensorMemoryLayout.HEIGHT_SHARDED if is_1d_systolic else ttnn.TensorMemoryLayout.BLOCK_SHARDED - return ttnn.MemoryConfig(shard_scheme, ttnn.BufferType.L1, shard_spec) diff --git a/ttnn/ttnn/operations/conv/tt_py_composite_conv.py b/ttnn/ttnn/operations/conv/tt_py_composite_conv.py deleted file mode 100644 index e1f0e97dba7..00000000000 --- a/ttnn/ttnn/operations/conv/tt_py_composite_conv.py +++ /dev/null @@ -1,345 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -from loguru import logger -from typing import List, Union -from ttnn.operations.conv.tt_py_op import TTPyOp -from ttnn.operations.conv.tt_py_untilize_with_halo import TTPyUntilizeWithHalo -from ttnn.operations.conv.untilize_with_halo_config_generation_and_validation import ( - trace_conv_to_generate_data_top_left_indices_and_pad_metadata, - decompose_conv_into_shards_and_generate_tensor_metadata, -) -from ttnn.operations.conv.sliding_window_op_config_generation_and_validation import ( - generate_sliding_window_op_sharded_input_top_left_indices, -) -from ttnn.operations.conv.sliding_window_op_utils import ( - SlidingWindowOpParams, - SlidingWindowOpParamsWithParallelConfig, - get_hash_from_sliding_window_op_params, - calculate_shard_grid, - calculate_memory_config, -) -from tt_lib.utils import ( - _nearest_32, - _nearest_y, - find_closest_largest_divisor, - find_closest_largest_divisor_with_num_padding, - divup, -) - -import ttnn -import torch -import math -import warnings - - -def find_closest_common_largest_divisor(num1: int, num2: int, start_divisor: int): - divisor = start_divisor - while num1 % divisor != 0 or num2 % divisor != 0: - divisor = divisor - 1 - return divisor - - -def determine_largest_subblock_size(block_height, block_width, fp32_accum=False): - subblocks = [ - (2, 4), - (4, 2), - (1, 8), - (8, 1), - (1, 7), - (7, 1), - (2, 3), - (3, 2), - (1, 6), - (6, 1), - (1, 5), - (5, 1), - (2, 2), - (1, 4), - (4, 1), - (1, 3), - (3, 1), - (1, 2), - (2, 1), - (1, 1), - ] - for subblock_height, subblock_width in subblocks: - if fp32_accum and subblock_height * subblock_width > 4: - continue - if block_height % subblock_height == 0 and block_width % subblock_width == 0: - if subblock_width != block_width and subblock_height != 1: - continue - break - return subblock_height, subblock_width - - -def compute_conv_output_height_width(input_height, input_width, sliding_window_op_params): - stride_h = sliding_window_op_params.stride_h - stride_w = sliding_window_op_params.stride_w - pad_h = sliding_window_op_params.pad_h - pad_w = sliding_window_op_params.pad_w - filter_h = sliding_window_op_params.window_h - filter_w = sliding_window_op_params.window_w - output_height = ((int)((input_height - filter_h + 2 * pad_h) / stride_h)) + 1 - output_width = ((int)((input_width - filter_w + 2 * pad_w) / stride_w)) + 1 - return output_height, output_width - - -def determine_parallel_config( - is_1d_systolic, 
- output_channels, - input_channels, - sliding_window_op_params, - device, - config_override=None, - is_out_tiled=True, - transpose_mcast=True, -): - if config_override is None: - config_override = {} - - batch_size = sliding_window_op_params.batch_size - input_height = sliding_window_op_params.input_h - input_width = sliding_window_op_params.input_w - output_height, output_width = compute_conv_output_height_width(input_height, input_width, sliding_window_op_params) - conv_out_2d_matrix_height = batch_size * output_height * output_width - tile_size = 32 if is_out_tiled else 1 - conv_out_2d_matrix_width_ntiles = divup(output_channels, tile_size) - - compute_with_storage_grid_size = device.compute_with_storage_grid_size() - device_grid_size = (compute_with_storage_grid_size.x, compute_with_storage_grid_size.y) - max_num_cores = device_grid_size[0] * device_grid_size[1] - - if "grid_size" in config_override: - grid_size = config_override["grid_size"] - max_num_cores = grid_size[0] * grid_size[1] - # print(f"max_num_cores: {max_num_cores}") - - def calculate_num_cores_nhw(override): - conv_out_2d_matrix_height_ntiles = divup(conv_out_2d_matrix_height, tile_size) - num_cores_nhw = ( - find_closest_largest_divisor(conv_out_2d_matrix_height_ntiles, max_num_cores) - if is_1d_systolic - else ( - find_closest_largest_divisor_with_num_padding(conv_out_2d_matrix_height_ntiles, device_grid_size[0]) - if transpose_mcast - else find_closest_largest_divisor_with_num_padding( - conv_out_2d_matrix_height_ntiles, device_grid_size[1] - ) - ) - ) - if override is not None and num_cores_nhw != override: - warnings.warn(f"Overriding config: num_cores_nhw from {num_cores_nhw} to user provided config={override}") - num_cores_nhw = override - return num_cores_nhw - - def calculate_grid_size(num_cores_nhw, override): - if is_1d_systolic: - grid_size = [ - device_grid_size[0] if num_cores_nhw >= device_grid_size[0] else num_cores_nhw, - math.ceil(num_cores_nhw / device_grid_size[0]), - ] # for 1d systolic array, grid size is the tightest bound of num_cores_nhw as a rectangle (x,y) - assert ( - num_cores_nhw <= grid_size[0] * grid_size[1] - ), "Error: For 1d systolic conv, num_cores_nhw must be <= grid size" - else: - if transpose_mcast: - grid_size = [ - num_cores_nhw, - find_closest_common_largest_divisor( - conv_out_2d_matrix_width_ntiles, - _nearest_32(input_channels) // 32, - device_grid_size[1], - ), - ] - else: - grid_size = [ - find_closest_common_largest_divisor( - conv_out_2d_matrix_width_ntiles, - _nearest_32(input_channels) // 32, - device_grid_size[0], - ), - num_cores_nhw, - ] - if override is not None and grid_size != override: - warnings.warn(f"Overriding config: grid_size from {grid_size} to user provided config={override}") - grid_size = override - return grid_size - - def calculate_per_core_out_matrix_height_ntiles(logical_grid_x, override): - per_core_out_matrix_height_ntiles = divup(divup(conv_out_2d_matrix_height, logical_grid_x), tile_size) - total_padded_height = per_core_out_matrix_height_ntiles * tile_size * logical_grid_x - assert ( - total_padded_height - conv_out_2d_matrix_height - ) <= per_core_out_matrix_height_ntiles * tile_size, f"total_padded_height({total_padded_height}) - original_height({conv_out_2d_matrix_height}) = {total_padded_height - conv_out_2d_matrix_height}, which exceeds the per-core shard shape height({per_core_out_matrix_height_ntiles * tile_size}). This will result in cores doing work on padded data only which is illegal. 
This is a result of choosing override num_cores_nhw({num_cores_nhw}) that cannot satisfy this height after tile padding." - if override is not None: - assert override % tile_size == 0, "per_core_out_matrix_height must be divisible by 32 (tile height)" - if (override // tile_size) != per_core_out_matrix_height_ntiles: - warnings.warn( - f"Overriding config: per_core_out_matrix_height from {per_core_out_matrix_height_ntiles * tile_size} to user provided config={override}" - ) - per_core_out_matrix_height_ntiles = override // tile_size - return per_core_out_matrix_height_ntiles - - def calculate_per_core_out_matrix_width_ntiles(logical_grid_y, override): - per_core_out_matrix_width_ntiles = conv_out_2d_matrix_width_ntiles // logical_grid_y - if override is not None: - assert override % 32 == 0, "per_core_weight_matrix_width must be divisible by 32 (tile width)" - if (override // 32) != per_core_out_matrix_width_ntiles: - warnings.warn( - f"Overriding config: per_core_weight_matrix_width from {per_core_out_matrix_width_ntiles * 32} to user provided config={override}" - ) - per_core_out_matrix_width_ntiles = override // 32 - return per_core_out_matrix_width_ntiles - - num_cores_nhw = calculate_num_cores_nhw(config_override.get("num_cores_nhw", None)) - grid_size = calculate_grid_size(num_cores_nhw, config_override.get("grid_size", None)) - logical_grid_x = num_cores_nhw if is_1d_systolic else (grid_size[0] if transpose_mcast else grid_size[1]) - logical_grid_y = 1 if is_1d_systolic else (grid_size[1] if transpose_mcast else grid_size[0]) - per_core_out_matrix_height_ntiles = calculate_per_core_out_matrix_height_ntiles( - logical_grid_x, config_override.get("per_core_out_matrix_height", None) - ) - per_core_out_matrix_width_ntiles = calculate_per_core_out_matrix_width_ntiles( - logical_grid_y, config_override.get("per_core_out_matrix_width", None) - ) - - logger.debug( - f"PARALLEL CONFIG :: {is_1d_systolic} :: {input_channels} :: {output_channels} :: {sliding_window_op_params} :: {config_override} -> {num_cores_nhw} :: {grid_size} :: {per_core_out_matrix_height_ntiles} :: {per_core_out_matrix_width_ntiles}" - ) - - return ttnn.operations.conv2d.OptimizedConvParallelizationConfig( - grid_size=grid_size, - num_cores_nhw=num_cores_nhw, - per_core_out_matrix_height_ntiles=per_core_out_matrix_height_ntiles, - per_core_out_matrix_width_ntiles=per_core_out_matrix_width_ntiles, - ) - - -def determine_per_core_block_config( - is_1d_systolic, - grid_size, - per_core_out_matrix_height_ntiles, - per_core_out_matrix_width_ntiles, - input_channels, - sliding_window_op_params, - use_shallow_conv_variant, - padded_input_channels, - config_override=None, - fp32_accum=False, - transpose_mcast=True, -): - if config_override is None: - config_override = {} - - act_block_h_override = 0 - if "act_block_h" in config_override: - act_block_h_override = config_override["act_block_h"] - assert act_block_h_override % 32 == 0, "act_block_h must be divisible by 32 (tile height)" - act_block_h_ntiles_override = act_block_h_override // 32 - act_block_h_ntiles = ( - act_block_h_ntiles_override if act_block_h_ntiles_override > 0 else per_core_out_matrix_height_ntiles - ) - act_block_w_ntiles = (int)( - ( - _nearest_32(padded_input_channels * sliding_window_op_params.window_w) - if is_1d_systolic - else padded_input_channels - ) - / 32 - ) - if is_1d_systolic: - act_c_num_blocks = 1 - else: - act_c_num_blocks = grid_size.y if transpose_mcast else grid_size.x - assert ( - padded_input_channels % act_c_num_blocks == 0 - ), 
f"Cannot parallelize conv as a 2d systolic array: Input channels {padded_input_channels} must be divisible by act_c_num_blocks {act_c_num_blocks}." - out_block_h_ntiles = per_core_out_matrix_height_ntiles - assert out_block_h_ntiles % act_block_h_ntiles == 0, "act_block_h must evenly divide out_block_h" - weight_block_w_ntiles = per_core_out_matrix_width_ntiles - out_subblock_h_ntiles, out_subblock_w_ntiles = determine_largest_subblock_size( - act_block_h_ntiles, weight_block_w_ntiles, fp32_accum - ) - if use_shallow_conv_variant and (act_block_h_ntiles // out_subblock_h_ntiles % 2 != 0): - assert is_1d_systolic - # TODO: fix this temporary hack for shallow conv - assert act_block_h_ntiles % 2 == 0 - out_subblock_h_ntiles = act_block_h_ntiles // 2 - assert out_subblock_h_ntiles * out_subblock_w_ntiles <= 8 - - if "act_block_w" in config_override: - act_block_w_override = config_override["act_block_w"] - assert act_block_w_override % 32 == 0, "act_block_w must be divisible by 32 (tile width)" - if (act_block_w_override // 32) != act_block_w_ntiles: - warnings.warn( - f"Overriding config: act_block_w from {act_block_w_ntiles * 32} to user provided config={act_block_w_override}" - ) - act_block_w_ntiles = act_block_w_override // 32 - if "out_subblock_h" in config_override: - assert ( - "out_subblock_w" in config_override - ), "out_subblock_w must also be provided as override config if out_subblock_h is provided" - out_subblock_h_override = config_override["out_subblock_h"] - assert out_subblock_h_override % 32 == 0, "out_subblock_h must be divisible by 32 (tile height)" - out_subblock_w_override = config_override["out_subblock_w"] - assert out_subblock_w_override % 32 == 0, "out_subblock_w must be divisible by 32 (tile width)" - if (out_subblock_h_override // 32) != out_subblock_h_ntiles: - warnings.warn( - f"Overriding config: out_subblock_h from {out_block_h_ntiles * 32} to user provided config={out_subblock_h_override}" - ) - if (out_subblock_w_override // 32) != out_subblock_w_ntiles: - warnings.warn( - f"Overriding config: out_subblock_w from {out_subblock_w_ntiles * 32} to user provided config={out_subblock_w_override}" - ) - if "out_subblock_w" in config_override: - assert ( - "out_subblock_h" in config_override - ), "out_subblock_h must also be provided as override config if out_subblock_w is provided" - conv_blocking_config = ttnn.operations.conv2d.OptimizedConvBlockConfig( - act_block_h_ntiles=act_block_h_ntiles, - act_block_w_ntiles=act_block_w_ntiles, - out_subblock_h_ntiles=out_subblock_h_ntiles, - out_subblock_w_ntiles=out_subblock_w_ntiles, - ) - return conv_blocking_config - - -def determine_1x1conv_as_matmul_config( - conv_parallelization_config, - conv_blocking_config, - use_1d_systolic_array, - fuse_relu, - transpose_mcast=True, -): - if use_1d_systolic_array: - matmul_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( - compute_with_storage_grid_size=conv_parallelization_config.grid_size, - in0_block_w=conv_blocking_config.act_block_w_ntiles, - out_subblock_h=conv_blocking_config.out_subblock_h_ntiles, - out_subblock_w=conv_blocking_config.out_subblock_w_ntiles, - per_core_M=conv_parallelization_config.per_core_out_matrix_height_ntiles, - per_core_N=conv_parallelization_config.per_core_out_matrix_width_ntiles, - fuse_batch=True, - fused_activation=ttnn.UnaryWithParam(ttnn.UnaryOpType.RELU) if fuse_relu else None, - mcast_in0=False, - ) - else: - grid_size_along_c = ( - conv_parallelization_config.grid_size.y if transpose_mcast else 
conv_parallelization_config.grid_size.x - ) - assert ( - conv_blocking_config.act_block_w_ntiles % grid_size_along_c == 0 - ), "Expected act block width to be divisible by act channel num blocks." - matmul_config = ttnn.MatmulMultiCoreReuseMultiCastProgramConfig( - compute_with_storage_grid_size=conv_parallelization_config.grid_size, - in0_block_w=conv_blocking_config.act_block_w_ntiles - // grid_size_along_c, ##conv_parallelization_config.grid_size.y, - out_subblock_h=conv_blocking_config.out_subblock_h_ntiles, - out_subblock_w=conv_blocking_config.out_subblock_w_ntiles, - per_core_M=conv_parallelization_config.per_core_out_matrix_height_ntiles, - per_core_N=conv_parallelization_config.per_core_out_matrix_width_ntiles, - transpose_mcast=transpose_mcast, - fused_activation=ttnn.UnaryWithParam(ttnn.UnaryOpType.RELU) if fuse_relu else None, - ) - return matmul_config diff --git a/ttnn/ttnn/operations/conv/tt_py_op.py b/ttnn/ttnn/operations/conv/tt_py_op.py deleted file mode 100644 index 9cbc892ca3a..00000000000 --- a/ttnn/ttnn/operations/conv/tt_py_op.py +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -from abc import ABC, abstractmethod -from typing import List - - -# Base class for OP -class TTPyOp(ABC): - # Generate op config variabes and tensors - def set_op_configs(self): - pass - - # Construct pytorch tensors for op weights and bias. Moves those tensors to device - def set_op_weights_biases(self, weight_tensor: List, bias_tensor: List): - pass - - # Return stats on op's L1 buffers - def get_l1_buffer_stats(self): - pass - - @abstractmethod - def __call__(self): - pass diff --git a/ttnn/ttnn/operations/conv/tt_py_untilize_with_halo.py b/ttnn/ttnn/operations/conv/tt_py_untilize_with_halo.py deleted file mode 100644 index fdfd7570b13..00000000000 --- a/ttnn/ttnn/operations/conv/tt_py_untilize_with_halo.py +++ /dev/null @@ -1,225 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -from typing import List -from ttnn.operations.conv.tt_py_op import TTPyOp -from ttnn.operations.conv.untilize_with_halo_config_generation_and_validation import ( - trace_conv_to_generate_data_top_left_indices_and_pad_metadata, - decompose_conv_into_shards_and_generate_tensor_metadata, - generate_untilize_with_halo_kernel_configs, -) -from tt_lib.utils import _nearest_y -from ttnn.operations.conv.sliding_window_op_utils import ( - SlidingWindowOpParamsWithParallelConfig, - get_hash_from_sliding_window_op_params, - get_sliding_window_op_output_shard_nhw_size, - calculate_shard_grid, -) -import ttnn -import torch -import struct - - -class TTPyUntilizeWithHalo(TTPyOp): - def __init__( - self, - device, - sliding_window_op_params: SlidingWindowOpParamsWithParallelConfig, - halo_reader_patterns_cache, - pad_val=0x0, - is_out_tiled=True, - transpose_mcast=True, - mesh_mapper=None, - ): - self.sliding_window_op_params = sliding_window_op_params - self.device = device - self.mesh_mapper = mesh_mapper - self.transpose_mcast = transpose_mcast - sliding_window_op_params_hash = get_hash_from_sliding_window_op_params(sliding_window_op_params) - self.set_op_configs( - device, sliding_window_op_params_hash, sliding_window_op_params, halo_reader_patterns_cache, is_out_tiled - ) - assert sliding_window_op_params_hash in halo_reader_patterns_cache - utwh_kernel_configs = halo_reader_patterns_cache[sliding_window_op_params_hash] - - def utwh_(activation): - return ttnn.untilize_with_halo_v2( - activation, - utwh_kernel_configs["padding_config"], - utwh_kernel_configs["local_config"], - utwh_kernel_configs["remote_config"], - pad_val=pad_val, - ncores_nhw=self.sliding_window_op_params.num_cores_nhw, - max_out_nsticks_per_core=utwh_kernel_configs["max_out_nsticks_per_core"], - memory_config=utwh_kernel_configs["out_mem_config"], - remote_read=utwh_kernel_configs["remote_read"], - transpose_mcast=self.transpose_mcast, - ) - - self.utwh = utwh_ - - # override abstract methods from base class TTPyOp - def set_op_configs( - self, - device, - sliding_window_op_params_hash, - sliding_window_op_params, - halo_reader_patterns_cache, - is_out_tiled=True, - ): - if sliding_window_op_params_hash not in halo_reader_patterns_cache: - stride_h = sliding_window_op_params.stride_h - stride_w = sliding_window_op_params.stride_w - pad_h = sliding_window_op_params.pad_h - pad_w = sliding_window_op_params.pad_w - window_h = sliding_window_op_params.window_h - window_w = sliding_window_op_params.window_w - input_n = sliding_window_op_params.batch_size - input_h = sliding_window_op_params.input_h - input_w = sliding_window_op_params.input_w - # TODO: Had to add this (should this be shard grid?) 
- num_cores_w = sliding_window_op_params.num_cores_w - num_cores_h = sliding_window_op_params.num_cores_h - num_cores_nhw = sliding_window_op_params.num_cores_nhw - act_reshard_num_cores_nhw = sliding_window_op_params.act_reshard_num_cores_nhw - assert num_cores_nhw > 0 - # TODO: send input_nhw_shape to generate functions (no need for C) - # output_channels, input_channels, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w, dilation, groups - sliding_window_op_all_params = [1, 1, window_h, window_w, stride_h, stride_w, pad_h, pad_w, 1, 1] - input_nchw_shape = [input_n, 1, input_h, input_w] - pad_metadata, data_top_left_indices = trace_conv_to_generate_data_top_left_indices_and_pad_metadata( - sliding_window_op_all_params, input_nchw_shape - ) - sliding_window_output_shard_nhw_size = get_sliding_window_op_output_shard_nhw_size( - num_cores_nhw, - input_n, - input_h, - input_w, - stride_h, - stride_w, - pad_h, - pad_w, - window_h, - window_w, - is_out_tiled, - ) - if is_out_tiled: - untilize_w_halo_input_nhw_size_to_shard_evenly = _nearest_y( - input_n * input_h * input_w, num_cores_nhw * 32 - ) - else: - untilize_w_halo_input_nhw_size_to_shard_evenly = _nearest_y(input_n * input_h * input_w, num_cores_nhw) - untilize_with_halo_input_shard_nhw_size = (int)( - untilize_w_halo_input_nhw_size_to_shard_evenly / num_cores_nhw - ) - req_conv_input_shard_start_end, tensor_metadata = decompose_conv_into_shards_and_generate_tensor_metadata( - data_top_left_indices, - pad_metadata, - input_w + (2 * pad_w), - sliding_window_output_shard_nhw_size, - untilize_with_halo_input_shard_nhw_size, - num_cores_nhw, - window_h, - window_w, - act_reshard_num_cores=act_reshard_num_cores_nhw, - input_nhw_height=input_n * input_h * input_w, - ) - - shard_grid, shard_layout = calculate_shard_grid( - (num_cores_w, num_cores_h), num_cores_nhw, self.transpose_mcast - ) - block_sharding = shard_layout == ttnn.TensorMemoryLayout.BLOCK_SHARDED - - def get_memory_config(shard_shape, buffer_type=ttnn.BufferType.L1_SMALL): - shard_orientation = ( - ttnn.ShardOrientation.ROW_MAJOR - if not block_sharding - else (ttnn.ShardOrientation.COL_MAJOR if self.transpose_mcast else ttnn.ShardOrientation.ROW_MAJOR) - ) - shard_halo = False - shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, shard_orientation, shard_halo) - mem_layout = ( - ttnn.TensorMemoryLayout.BLOCK_SHARDED if block_sharding else ttnn.TensorMemoryLayout.HEIGHT_SHARDED - ) - mem_config = ttnn.MemoryConfig(mem_layout, buffer_type, shard_spec) - return mem_config - - def gen_per_core_gather_data_uint16_tensor(config: list): - assert type(config) is list - if block_sharding: - if self.transpose_mcast: - assert len(config) == num_cores_w, f"{len(config)} {num_cores_w}" - else: - assert len(config) == num_cores_h, f"{len(config)} {num_cores_h}" - else: - assert len(config) == num_cores_nhw, f"{len(config)} {num_cores_nhw}" - assert type(config[0]) is list - assert len(config[0]) > 0 - - torch_tensor = torch.tensor(config, dtype=torch.short) - shard_shape = [1, torch_tensor.shape[-1]] - - if block_sharding: - if self.transpose_mcast: - torch_tensor = torch_tensor.repeat(1, num_cores_h) - else: - torch_tensor = torch_tensor.repeat(1, num_cores_w) - - torch_tensor = torch_tensor.unsqueeze(0).unsqueeze(0) - tt_tensor = ttnn.from_torch( - torch_tensor, dtype=ttnn.DataType.UINT16, layout=ttnn.ROW_MAJOR_LAYOUT, mesh_mapper=self.mesh_mapper - ) - tt_tensor = tt_tensor.to(device, get_memory_config(shard_shape)) if device is not None else tt_tensor - return tt_tensor - - def 
core_id_to_physical_coord(core_id): - if block_sharding: - if self.transpose_mcast: - core_coord = ttnn.CoreCoord(core_id, 0) - else: - core_coord = ttnn.CoreCoord(0, core_id) - else: - core_coord = ttnn.CoreCoord(core_id % num_cores_w, core_id // num_cores_w) - - # HACK: Using first device which may have different harvesting than other chips. Logic should be pushed into op - if isinstance(device, ttnn.Device): - worker_core = device.worker_core_from_logical_core(core_coord) - else: - worker_core = device.get_device(0).worker_core_from_logical_core(core_coord) - return (worker_core.x, worker_core.y) - - remote_read = act_reshard_num_cores_nhw > 0 - - ( - padding_config, - local_config, - remote_config, - max_out_nsticks_per_core, - ) = generate_untilize_with_halo_kernel_configs( - tensor_metadata, - req_conv_input_shard_start_end, - core_id_to_physical_coord, - remote_read=remote_read, - ) - - padding_config_tensor = gen_per_core_gather_data_uint16_tensor(padding_config) - local_config_tensor = gen_per_core_gather_data_uint16_tensor(local_config) - remote_config_tensor = gen_per_core_gather_data_uint16_tensor(remote_config) - - # shard_shape[1] filled in with incoming activations in c++ code - out_shard_shape = [untilize_with_halo_input_shard_nhw_size, 0] - - halo_reader_patterns_cache[sliding_window_op_params_hash] = { - "max_out_nsticks_per_core": max_out_nsticks_per_core, - "padding_config": padding_config_tensor, - "local_config": local_config_tensor, - "remote_config": remote_config_tensor, - "out_mem_config": get_memory_config(out_shard_shape, buffer_type=ttnn.BufferType.L1), - "remote_read": remote_read, - } - - return - - def __call__(self, activation): - return self.utwh(activation) diff --git a/ttnn/ttnn/operations/conv/untilize_with_halo_config_generation_and_validation.py b/ttnn/ttnn/operations/conv/untilize_with_halo_config_generation_and_validation.py deleted file mode 100644 index a55eeb39170..00000000000 --- a/ttnn/ttnn/operations/conv/untilize_with_halo_config_generation_and_validation.py +++ /dev/null @@ -1,543 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import torch -import numpy as np -from loguru import logger - -from tt_lib._internal.comparison_funcs import comp_equal - - -def trace_conv_to_generate_data_top_left_indices_and_pad_metadata(conv_params, input_nchw_shape): - assert len(conv_params) == 10 - output_channels, input_channels, filter_h, filter_w, stride_h, stride_w, pad_h, pad_w, dilation, groups = [ - conv_params[i] for i in range(10) - ] - assert dilation == 1 and groups == 1 - assert len(input_nchw_shape) == 4 - input_n, input_c, input_h, input_w = [input_nchw_shape[i] for i in range(4)] - - # image 1 data - # 1 2 3 4 5 6 7 8 - # 9 10 11 12 13 14 15 16 - # 17 18 19 20 21 22 23 24 - # 25 26 27 28 29 30 31 32 - # image 2 data - # 33 34 35 36 37 38 39 40 - # 41 42 43 44 45 46 47 48 - # 49 50 51 52 53 54 55 56 - # 57 58 59 60 61 62 63 64 - - # Concatenated image data from above - # Inserted padding above and between and on the sides of the images (pad = 1) - # 0 0 0 0 0 0 0 0 0 0 - # 0 1 2 3 4 5 6 7 8 0 - # 0 9 10 11 12 13 14 15 16 0 - # 0 17 18 19 20 21 22 23 24 0 - # 0 25 26 27 28 29 30 31 32 0 - # 0 0 0 0 0 0 0 0 0 0 - # 0 0 0 0 0 0 0 0 0 0 - # 0 33 34 35 36 37 38 39 40 0 - # 0 41 42 43 44 45 46 47 48 0 - # 0 49 50 51 52 53 54 55 56 0 - # 0 57 58 59 60 61 62 63 64 0 - # 0 0 0 0 0 0 0 0 0 0 - - # We encode above shown padded tensor into pad_metadata (list of boolean - true if padding location) - # pad_meta_data: [true, true, ..., false, ...] - index = 0 - padded_input_h = input_h + (2 * pad_h) - padded_input_w = input_w + (2 * pad_w) - pad_metadata = np.full(input_n * padded_input_h * padded_input_w, False, dtype=bool) - for n in range(input_n): - for h in range(padded_input_h): - for w in range(padded_input_w): - if h < pad_h or h >= (input_h + pad_h) or w < pad_w or w >= (input_w + pad_w): - pad_metadata[index] = True - index += 1 - - # TODO: add support for dilation > 1 - output_h = ((int)((padded_input_h - filter_h) / stride_h)) + 1 - output_w = ((int)((padded_input_w - filter_w) / stride_w)) + 1 - # generate a list of input indices corresponding to the top left position of sliding window - # the index refers to the location in the padded tensor - # data_top_left_indices = [] - index = 0 - data_top_left_indices = np.full(input_n * output_h * output_w, 0, dtype=int) - for n in range(input_n): - for oh in range(output_h): - for ow in range(output_w): - ih = oh * stride_h - iw = ow * stride_w - channel_idx = (n * padded_input_h * padded_input_w) + (ih * padded_input_w) + iw - data_top_left_indices[index] = channel_idx - index += 1 - return pad_metadata.tolist(), data_top_left_indices.tolist() - - -def validate_input_padded_tensor_and_data_top_left_indices_and_pad_metadata( - input_padded_tensor, - input_nchw_shape, - pad_h, - pad_w, - filter_pyt_tensor, - out_golden_pyt_tensor, - pad_metadata, - data_top_left_indices, -): - input_n, input_c, input_h, input_w = input_nchw_shape - filter_k, filter_c, filter_h, filter_w = list(filter_pyt_tensor.size()) - assert input_c == filter_c - - # permute filter tensor to be channels last - kchw --> khwc - filter_pyt_tensor_khwc = torch.permute(filter_pyt_tensor, (0, 2, 3, 1)) - - input_padded_width = input_w + (2 * pad_w) - input_padded_height = input_h + (2 * pad_h) - input_padded_volume = input_n * input_c * input_padded_height * input_padded_width - assert len(input_padded_tensor) == input_padded_volume - input_padded_pyt_tensor_nhwc = torch.tensor(input_padded_tensor).reshape( - [input_n * input_padded_height, input_padded_width, input_c] - 
) - output_tensor = [] - # run conv over padded tensor using data_top_left_indices - for k in range(filter_k): - for i in data_top_left_indices: - i_bh = (int)(i / input_padded_width) - i_w = (int)(i % input_padded_width) - output_tensor.append( - torch.dot( - input_padded_pyt_tensor_nhwc[i_bh : i_bh + filter_h, i_w : i_w + filter_w, :].reshape(-1), - filter_pyt_tensor_khwc[k, :, :, :].reshape(-1), - ) - ) - - output_pyt_tensor = torch.tensor(output_tensor) - assert np.prod(output_pyt_tensor.size()) == np.prod(out_golden_pyt_tensor.size()) - # permute output golden pytorch tensor from nchw to cnhw shape - out_golden_pyt_tensor_cnhw = torch.permute(out_golden_pyt_tensor, (1, 0, 2, 3)) - # compare to pytorch - passing_pcc, output_pcc = comp_equal(out_golden_pyt_tensor_cnhw.reshape(-1), output_pyt_tensor.reshape(-1)) - logger.debug(f"Passing={passing_pcc}") - logger.debug(f"Output pcc={output_pcc}") - assert passing_pcc - - -def decompose_conv_into_shards_and_generate_tensor_metadata( - data_top_left_indices, - pad_metadata, - input_padded_w, - conv_output_shard_height, - unpadded_input_shard_height, - num_cores, - filter_h, - filter_w, - act_reshard_num_cores=0, - input_nhw_height=0, -): - req_conv_input_shard_start_end = [] # start and end indices refer to global padded input tensor - conv_output_start_stick = 0 - for core_id in range(num_cores): - if conv_output_start_stick >= len(data_top_left_indices): - print("core_id=", core_id) - print("conv_output_start_stick=", conv_output_start_stick) - print("len(data_top_left_indices)=", len(data_top_left_indices)) - print("conv_output_shard_height=", conv_output_shard_height) - assert conv_output_start_stick < len(data_top_left_indices) - req_conv_input_shard_start_stick = data_top_left_indices[conv_output_start_stick] - conv_output_end_stick = min(conv_output_start_stick + conv_output_shard_height, len(data_top_left_indices)) - 1 - req_conv_input_shard_end_stick = data_top_left_indices[conv_output_end_stick] - halo_with_pad_nsticks = ((filter_h - 1) * input_padded_w) + filter_w - 1 - req_conv_input_shard_end_stick += halo_with_pad_nsticks - req_conv_input_shard_start_end.append( - ( - (conv_output_start_stick, conv_output_end_stick), - (req_conv_input_shard_start_stick, req_conv_input_shard_end_stick), - ) - ) - conv_output_start_stick += conv_output_shard_height - - remap = lambda a, b: (a, b) - if act_reshard_num_cores != 0: - assert input_nhw_height != 0 - assert ( - input_nhw_height % act_reshard_num_cores == 0 - ), f"{input_nhw_height} {act_reshard_num_cores} {num_cores} {unpadded_input_shard_height}" - act_unpadded_input_shard_height = input_nhw_height // act_reshard_num_cores - - def _remap(cid, lid): - idx = cid * unpadded_input_shard_height + lid - return (idx // act_unpadded_input_shard_height, idx % act_unpadded_input_shard_height) - - remap = _remap - - tensor_metadata = [] - unpadded_input_shard_local_idx = 0 - core_id = 0 - for padded_input_tensor_idx in range(len(pad_metadata)): - pad_stick = pad_metadata[padded_input_tensor_idx] - if pad_stick: - tensor_metadata.append((True, 0, 0)) - else: - # sanity check - assert core_id < num_cores, f"{core_id} {num_cores}" - assert unpadded_input_shard_local_idx < unpadded_input_shard_height - tensor_metadata.append((False, *remap(core_id, unpadded_input_shard_local_idx))) - unpadded_input_shard_local_idx += 1 - if unpadded_input_shard_local_idx == unpadded_input_shard_height: - unpadded_input_shard_local_idx = 0 - core_id += 1 - assert len(tensor_metadata) == len(pad_metadata) - return 
req_conv_input_shard_start_end, tensor_metadata - - -def construct_utwh_output_shards( - # Padded input tensor - input_padded_tensor, - # Padded input tensor shape - input_nchw_padded_shape, - # config to construct shards - req_conv_input_shard_start_end, -): - # reshape input padded tensor to 2d shape - [nhw, c] - assert len(input_nchw_padded_shape) == 4 - input_n, input_c, input_padded_height, input_padded_width = [input_nchw_padded_shape[i] for i in range(4)] - input_2d_padded_tensor = np.reshape( - input_padded_tensor, (input_n * input_padded_height * input_padded_width, input_c) - ) - utwh_output_shards = [] - for item in req_conv_input_shard_start_end: - req_conv_input_shard_start, req_conv_input_shard_end = item[1] - req_conv_input_shard_size = req_conv_input_shard_end - req_conv_input_shard_start + 1 - assert req_conv_input_shard_size <= 65535 # max uint16 value - utwh_output_shards.append(input_2d_padded_tensor[req_conv_input_shard_start : req_conv_input_shard_end + 1, :]) - return utwh_output_shards - - -def validate_utwh_output_shards_and_req_conv_input_shard_start_end( - # Padded input tensor shape - input_nchw_padded_shape, - # Filter pytorch tensor - filter_pyt_tensor, - # Conv golden output tensor to compare against - out_golden_pyt_tensor, - # Input indices corresponding to top left position of sliding window. Used to perform conv operation. - data_top_left_indices, - # validate utwh output shards - utwh_output_shards, - # Validate this config - - req_conv_input_shard_start_end, -): - filter_k = filter_pyt_tensor.size()[0] - filter_c = filter_pyt_tensor.size()[1] - filter_h = filter_pyt_tensor.size()[2] - filter_w = filter_pyt_tensor.size()[3] - - output_n = out_golden_pyt_tensor.size()[0] - output_c = out_golden_pyt_tensor.size()[1] - output_h = out_golden_pyt_tensor.size()[2] - output_w = out_golden_pyt_tensor.size()[3] - assert len(data_top_left_indices) == output_n * output_h * output_w - assert len(input_nchw_padded_shape) == 4 - input_n, input_c, input_padded_height, input_padded_width = [input_nchw_padded_shape[i] for i in range(4)] - assert filter_c == input_c - assert output_n == input_n - assert output_c == filter_k - - # permute filter tensor to be channels last - kchw --> khwc - filter_pyt_tensor_khwc = torch.permute(filter_pyt_tensor, (0, 2, 3, 1)) - - # Perform conv on input shards one at a time, and compare against output. Use data_top_left_indices (global) to perform the conv operation. 
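# Worked example (illustrative numbers only) of the halo extent used when the conv input is
# decomposed into shards above: the last output stick of a shard still needs
# (filter_h - 1) * padded_row_width + filter_w - 1 additional input sticks past its own
# top-left index, i.e. the remaining window rows plus the remaining window columns.
filter_h, filter_w, input_padded_w = 3, 3, 10
halo_with_pad_nsticks = ((filter_h - 1) * input_padded_w) + filter_w - 1
assert halo_with_pad_nsticks == 22  # two full padded rows (20 sticks) plus two extra sticks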
- output_stick_global = 0 - for input_shard_idx, item in enumerate(req_conv_input_shard_start_end): - assert input_shard_idx < len(utwh_output_shards) - conv_output_shard_start, conv_output_shard_end = item[0] - req_conv_input_shard_start, req_conv_input_shard_end = item[1] - # sanity check that the first item in the shard is at the top left position of sliding window - assert output_stick_global < len(data_top_left_indices) - assert req_conv_input_shard_start == data_top_left_indices[output_stick_global] - output_shard = [] - output_shard_size = conv_output_shard_end - conv_output_shard_start + 1 - for k in range(filter_k): - output_stick = output_stick_global - for o in range(output_shard_size): - assert output_stick < len(data_top_left_indices) - input_top_left_position_stick = data_top_left_indices[output_stick] - assert input_top_left_position_stick >= req_conv_input_shard_start - input_shard_stick_local_idx = input_top_left_position_stick - req_conv_input_shard_start - conv_input_window = [] - for fh in range(filter_h): - for fw in range(filter_w): - assert input_shard_stick_local_idx + fw < len(utwh_output_shards[input_shard_idx]) - conv_input_window.append( - utwh_output_shards[input_shard_idx][input_shard_stick_local_idx + fw, :] - ) - input_shard_stick_local_idx += input_padded_width - output_val = np.dot( - np.array(conv_input_window).flatten(), filter_pyt_tensor_khwc[k, :, :, :].reshape(-1).tolist() - ) - output_shard.append(output_val) - output_stick += 1 - output_stick_global = output_stick - output_pyt_shard = torch.tensor(output_shard).reshape((filter_k, output_shard_size)) - # compare output shard with golden output pytorch tensor - # permute output golden pytorch tensor from nchw to cnhw shape - out_golden_pyt_tensor_cnhw = torch.permute(out_golden_pyt_tensor, (1, 0, 2, 3)) - # reshape cnhw to 2d shape = [c, nhw] - out_golden_pyt_tensor_cnhw = torch.reshape( - out_golden_pyt_tensor_cnhw, (output_c, output_n * output_h * output_w) - ) - assert ( - output_pyt_shard.size() - == out_golden_pyt_tensor_cnhw[:, conv_output_shard_start : conv_output_shard_end + 1].size() - ) - # print("out_golden_shard=", out_golden_pyt_tensor.reshape(-1)[conv_output_shard_start : conv_output_shard_end + 1]) - # print("out_shard=", output_pyt_shard) - passing_pcc, output_pcc = comp_equal( - out_golden_pyt_tensor_cnhw[:, conv_output_shard_start : conv_output_shard_end + 1], output_pyt_shard - ) - # print("Passing=", passing_pcc) - # print("Output pcc=", output_pcc) - assert passing_pcc - - return - - -def validate_tensor_metadata( - input_tensor, - input_nchw_shape, - input_shard_size, - tensor_metadata, - req_conv_input_shard_start_end, - golden_conv_input_shards, -): - # input tensor is unpadded - # Permute input tensor from nchw shape to nhwc shape and reshape to 2d shape - [nhw, c] - assert len(input_nchw_shape) == 4 - input_n, input_c, input_h, input_w = [input_nchw_shape[i] for i in range(4)] - input_nhw_size = input_n * input_h * input_w - input_tensor = np.reshape(input_tensor, input_nchw_shape) - input_tensor_nhwc = np.transpose(input_tensor, (0, 2, 3, 1)) - input_tensor_nhwc = np.reshape(input_tensor_nhwc, (input_n * input_h * input_w, input_c)) - # construct unpadded input tensor shards - unpadded_input_tensor_shards = [] - num_shards = len(req_conv_input_shard_start_end) - unpadded_input_tensor_shard_start = 0 - for i in range(num_shards): - unpadded_input_tensor_shard_end = min(unpadded_input_tensor_shard_start + input_shard_size, input_nhw_size) - assert 
unpadded_input_tensor_shard_start < len(input_tensor_nhwc) and unpadded_input_tensor_shard_end <= len( - input_tensor_nhwc - ) - unpadded_input_tensor_shards.append( - input_tensor_nhwc[unpadded_input_tensor_shard_start:unpadded_input_tensor_shard_end, :] - ) - unpadded_input_tensor_shard_start += input_shard_size - # Validate tensor_metadata - # Construct conv input shard using tensor_metadata and req_conv_input_shard_start_end indices. Then, compare against golden conv input shards - conv_input_shards = [] - assert len(req_conv_input_shard_start_end) == len(golden_conv_input_shards) - for shard_idx, item in enumerate(req_conv_input_shard_start_end): - conv_input_shard = [] - req_conv_input_shard_start = item[1][0] - req_conv_input_shard_end = item[1][1] - for idx in range(req_conv_input_shard_start, req_conv_input_shard_end + 1): - assert idx < len(tensor_metadata) - pad = tensor_metadata[idx][0] - if pad: - conv_input_shard.append([0] * input_c) - else: - core_id = tensor_metadata[idx][1] - core_local_idx = tensor_metadata[idx][2] - assert core_id < len(unpadded_input_tensor_shards) - assert core_local_idx < len(unpadded_input_tensor_shards[core_id]) - conv_input_shard.append(unpadded_input_tensor_shards[core_id][core_local_idx, :]) - assert (conv_input_shard == golden_conv_input_shards[shard_idx]).all() - return unpadded_input_tensor_shards - - -# Makes all sublists the same length, optionally tile aligns too -def align_up_2d_python_list(list2d: list, extend_value, align_granularity=0): - assert type(list2d) is list - if len(list2d) == 0: - return list2d - assert type(list2d[0]) is list - max_len = 0 - for l in list2d: - max_len = max(len(l), max_len) - if align_granularity > 0: - align_amount = max_len % align_granularity - if align_amount > 0: - max_len += align_granularity - align_amount - for l in list2d: - extend_amount = max_len - len(l) - if extend_amount > 0: - l.extend([extend_value] * extend_amount) - - -def generate_untilize_with_halo_kernel_configs( - tensor_metadata: list, - resharded_start_and_end: list, - core_id_to_physical_coord=lambda core_id: (0, core_id), - remote_read=False, -): - ncores = len(resharded_start_and_end) - - per_core_gather_data = {} - pad_local = 0xFFFF # uint16_t max index means pad - - def run_length_encode(l, src, dst, is_pad): - if len(l) > 0: - src_start, dst_start, length = l[-3], l[-2], l[-1] - # src index is always 0 if is_pad, so we only need to RLE the dst - if (src == (src_start + length) or is_pad) and dst == (dst_start + length): - l[-1] = length + 1 - return False - l.extend([src, dst, 1]) - return True - - ## NOTE: assuming the core_id's are contiguous - for core_id in np.arange(ncores): - dst_global_start_idx, dst_global_end_idx = resharded_start_and_end[core_id][1] - - for dst_global_idx in np.arange(dst_global_start_idx, dst_global_end_idx + 1): - dst_core_id = core_id - dst_local_idx = dst_global_idx - dst_global_start_idx - is_pad, src_core_id, src_local_idx = tensor_metadata[dst_global_idx] - if is_pad: - assert src_local_idx == 0 - src_core_id = pad_local - dst_core_id = core_id - if (src_core_id, dst_core_id) not in per_core_gather_data: - per_core_gather_data[(src_core_id, dst_core_id)] = [] - assert src_local_idx < 0xFFFF, "Index overflows uint16_t storage type" - assert dst_local_idx < 0xFFFF, "Index overflows uint16_t storage type" - run_length_encode(per_core_gather_data[(src_core_id, dst_core_id)], src_local_idx, dst_local_idx, is_pad) - - padding_config = [] - local_config = [] - remote_config = [] - - for core_id in 
range(ncores): - padding_config.append([]) - local_config.append([]) - remote_config.append([]) - - # print("per_core_gather_data", per_core_gather_data) - - for core_key, core_data in per_core_gather_data.items(): - src_core_id, dst_core_id = core_key - - # Padding Encoding: [dst_idx0, num_elems0, dst_idx1, num_elems1, ...] - # Local/Remote encoding: [dst_core_id0, num_elems0, ...G0..., dst_core_id1, num_elems1, ...G1..., ...] - is_padding = src_core_id == pad_local - is_local = dst_core_id == src_core_id - is_remote = not is_padding and not is_local - - if is_padding: - del core_data[0::3] - padding_config[dst_core_id].extend(core_data) - elif is_local: - noc_x, noc_y = core_id_to_physical_coord(dst_core_id) - local_config[src_core_id].extend([noc_x, noc_y, len(core_data)]) - local_config[src_core_id].extend(core_data) - elif remote_read: - assert is_remote - noc_x, noc_y = core_id_to_physical_coord(src_core_id) - remote_config[dst_core_id].extend([noc_x, noc_y, len(core_data)]) - remote_config[dst_core_id].extend(core_data) - else: - assert is_remote - noc_x, noc_y = core_id_to_physical_coord(dst_core_id) - remote_config[src_core_id].extend([noc_x, noc_y, len(core_data)]) - remote_config[src_core_id].extend(core_data) - - # NULL plug - for core_id in range(ncores): - padding_config[core_id].extend([0, 0]) - local_config[core_id].extend([0, 0, 0]) - remote_config[core_id].extend([0, 0, 0]) - - align_up_2d_python_list(padding_config, 0, align_granularity=2) - align_up_2d_python_list(local_config, 0, align_granularity=2) - align_up_2d_python_list(remote_config, 0, align_granularity=2) - - # print("padding_config", padding_config) - # print("local_config", local_config) - # print("remote_config", remote_config) - - max_out_nsticks_per_core = max( - [ - resharded_start_and_end[core_id][1][1] - resharded_start_and_end[core_id][1][0] + 1 - for core_id in range(ncores) - ] - ) - - return padding_config, local_config, remote_config, max_out_nsticks_per_core - - -def validate_untilize_with_halo_kernel_configs( - golden, - input_tensor_shards, - resharded_start_and_end, - padding_config, - local_config, - remote_config, - max_out_nsticks_per_core, - physical_coord_to_core_id=lambda x, y: y, -): - ## using the kernel configs, construct the resulting resharding for each core - ncores = len(resharded_start_and_end) - assert len(input_tensor_shards) == ncores - assert len(golden) == ncores - input_c = len(golden[0][0]) - max_size = 0 - for _, dst in resharded_start_and_end: - start = dst[0] - end = dst[1] - size = end - start + 1 - max_size = size if max_size < size else max_size - pad_val = 0 - - def copy_sticks(reshards, input_tensor_shards, config, src_core_id): - i = 0 - length = 1 - while length > 0: - noc_x = config[i + 0] - noc_y = config[i + 1] - length = config[i + 2] - assert noc_x == 0, "Validation assumes noc_x is always 0" - dst_core_id = physical_coord_to_core_id(noc_x, noc_y) - i += 3 - for j in range(0, length, 3): - src_local_idx = config[i + j + 0] - dst_local_idx = config[i + j + 1] - nsticks = config[i + j + 2] - for k in range(nsticks): - reshards[dst_core_id][dst_local_idx + k] = input_tensor_shards[src_core_id][src_local_idx + k] - i += length - - reshards = {} - for core in np.arange(ncores): - dst_range = resharded_start_and_end[core][1] - curr_size = dst_range[1] - dst_range[0] + 1 - reshards[core] = np.zeros([curr_size, input_c], dtype=int) - - for core in np.arange(ncores): - core_padding_config = padding_config[core] - core_local_config = local_config[core] - 
core_remote_config = remote_config[core] - - for base_dst_idx, nsticks in zip(core_padding_config[0::2], core_padding_config[1::2]): - for dst_idx in range(base_dst_idx, base_dst_idx + nsticks): - reshards[core][dst_idx] = [pad_val] * input_c - dst_idx += 1 - - copy_sticks(reshards, input_tensor_shards, core_local_config, core) - copy_sticks(reshards, input_tensor_shards, core_remote_config, core) - - assert max_out_nsticks_per_core == max([len(golden[core]) for core in range(ncores)]) - for core in np.arange(ncores): - # print(f'OUTPUT CORE {core}: {reshards[core]}') - # print(f'GOLDEN CORE {core}: {golden[core]}') - assert (reshards[core] == golden[core]).all() diff --git a/ttnn/ttnn/operations/conv2d.py b/ttnn/ttnn/operations/conv2d.py index cc54476115e..4d015bd9529 100644 --- a/ttnn/ttnn/operations/conv2d.py +++ b/ttnn/ttnn/operations/conv2d.py @@ -9,17 +9,6 @@ import warnings import math import ttnn -from ttnn.operations.conv.sliding_window_op_utils import ( - calculate_shard_grid, - roundup, - get_output_dim as get_conv_output_dim, -) -from ttnn.operations.conv.tt_py_composite_conv import ( - SlidingWindowOpParams, - find_closest_common_largest_divisor, - find_closest_largest_divisor, - find_closest_largest_divisor_with_num_padding, -) from ttnn.device import ( is_grayskull, is_wormhole_b0, @@ -37,6 +26,13 @@ def _nearest_32(x): OptimizedConvBlockConfig = ttnn._ttnn.operations.conv2d.OptimizedConvBlockConfig +def get_conv_output_dim(input, window, stride=1, pad=0, dilation=1): + """ + Returns the output dimension of a convolution operation. + """ + return (input + (2 * pad) - dilation * (window - 1) - 1) // stride + 1 + + def convert_conv_weight_tensor_to_tiled_layout(conv_weight_tensor, in1_block_h, in1_block_w, output_dtype=None): """ Converts convolution weights to 2d matrix tiled layout on host @@ -87,206 +83,6 @@ def convert_conv_weight_tensor_to_grouped_layout(conv_weight_tensor, num_groups, ) -# internal. not user facing -class ParallelConfig: - def __init__( - self, - num_cores_y: int, - num_cores_x: int, - num_cores_nhw: int, - shard_scheme: ttnn.TensorMemoryLayout, - shard_orientation: ttnn.ShardOrientation, - ): - # TODO: using core range set would be better - self.grid_size = ttnn.CoreCoord(num_cores_x, num_cores_y) - self.num_cores_nhw = num_cores_nhw - self.shard_scheme = shard_scheme - self.shard_orientation = shard_orientation - - def __eq__(self, other): - if not isinstance(other, ParallelConfig): - return NotImplemented - - return ( - self.grid_size.y == other.grid_size.y - and self.grid_size.x == other.grid_size.x - and self.num_cores_nhw == other.num_cores_nhw - and self.shard_scheme == other.shard_scheme - and self.shard_orientation == other.shard_orientation - ) - - def __ne__(self, other): - if not isinstance(other, ParallelConfig): - return NotImplemented - return not (self == other) - - -# internal helper function. not exposed to user. 
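# Quick numeric check (illustrative values) of the output-size formula behind the
# get_conv_output_dim helper added above; _out_dim is a local stand-in, not a ttnn symbol.
def _out_dim(inp, window, stride=1, pad=0, dilation=1):
    return (inp + (2 * pad) - dilation * (window - 1) - 1) // stride + 1

assert _out_dim(8, 3, stride=1, pad=1) == 8  # 3x3 window, stride 1, pad 1 preserves the dim
assert _out_dim(8, 3, stride=2, pad=1) == 4  # stride 2 halves it (floor division)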
-def get_shard_grid_from_core_grid(core_grid): - shard_grid = None - if isinstance(core_grid, ttnn.CoreGrid): - grid_coord = ttnn.CoreCoord(core_grid.x - 1, core_grid.y - 1) - shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)}) - elif isinstance(core_grid, (list, tuple)): - if len(core_grid) != 2: - raise RuntimeError("Invalid core_grid") - if not isinstance(core_grid[0], ttnn.CoreGrid): - raise RuntimeError("Invalid core_grid type") - if not isinstance(core_grid[1], ttnn.CoreGrid): - raise RuntimeError("Invalid core_grid type") - - grid_coord_1 = ttnn.CoreCoord(core_grid[0].x - 1, core_grid[0].y - 1) - grid_coord_2 = ttnn.CoreCoord(core_grid[1].x - 1, core_grid[0].y) - shard_grid = ttnn.CoreRangeSet( - { - ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord_1), - ttnn.CoreRange(ttnn.CoreCoord(0, core_grid[0].y), grid_coord_2), - } - ) - elif isinstance(core_grid, ttnn.CoreRangeSet): - shard_grid = core_grid - else: - raise RuntimeError("Invalid core_grid type") - return shard_grid - - -# internal helper function. not exposed to user. -def determine_parallel_config( - is_1d_systolic, - batch_size, - input_channels, - output_height, - output_width, - output_channels, - device, - config_override=None, - is_out_tiled=True, -): - if config_override is None: - config_override = {} - for k in config_override.keys(): - assert k == "grid_size" or k == "num_cores_nhw" - - conv_out_2d_matrix_height = batch_size * output_height * output_width - # pad height to 32 - conv_out_2d_matrix_height = _nearest_32(conv_out_2d_matrix_height) - - if is_out_tiled: - conv_out_2d_matrix_height_ntiles = (int)(conv_out_2d_matrix_height / 32) - conv_out_2d_matrix_width_ntiles = (int)(_nearest_32(output_channels) / 32) - else: - conv_out_2d_matrix_height_ntiles = conv_out_2d_matrix_height - conv_out_2d_matrix_width_ntiles = output_channels - - compute_with_storage_grid_size = device.compute_with_storage_grid_size() - device_grid_size = (compute_with_storage_grid_size.x, compute_with_storage_grid_size.y) - max_num_cores = device_grid_size[0] * device_grid_size[1] - - def calculate_num_cores_nhw(override): - num_cores_nhw = ( - find_closest_largest_divisor(conv_out_2d_matrix_height_ntiles, max_num_cores) - if is_1d_systolic - else find_closest_largest_divisor_with_num_padding(conv_out_2d_matrix_height_ntiles, device_grid_size[0]) - ) - if override is not None and num_cores_nhw != override: - warnings.warn(f"Overriding config: num_cores_nhw from {num_cores_nhw} to user provided config={override}") - num_cores_nhw = override - return num_cores_nhw - - def calculate_grid_size(num_cores_nhw, override): - if is_1d_systolic: - grid_size = [ - device_grid_size[0] if num_cores_nhw >= device_grid_size[0] else num_cores_nhw, - math.ceil(num_cores_nhw / device_grid_size[0]), - ] # for 1d systolic array, grid size is the tightest bound of num_cores_nhw as a rectangle (x,y) - assert ( - num_cores_nhw <= grid_size[0] * grid_size[1] - ), "Error: For 1d systolic conv, num_cores_nhw must be <= grid size" - else: - grid_size = [ - num_cores_nhw, - find_closest_common_largest_divisor( - conv_out_2d_matrix_width_ntiles, _nearest_32(input_channels) // 32, device_grid_size[1] - ), - ] - assert ( - num_cores_nhw == grid_size[0] - ), "Error: For 2d systolic conv, num_cores_nhw must be == # of cols in grid size" - - if override is not None and grid_size != override: - warnings.warn(f"Overriding config: grid_size from {grid_size} to user provided config={override}") - grid_size = override - return grid_size - - 
num_cores_nhw = calculate_num_cores_nhw(config_override.get("num_cores_nhw", None)) - grid_size = calculate_grid_size(num_cores_nhw, config_override.get("grid_size", None)) - shard_scheme = ttnn.TensorMemoryLayout.HEIGHT_SHARDED if is_1d_systolic else ttnn.TensorMemoryLayout.BLOCK_SHARDED - shard_orientation = ttnn.ShardOrientation.ROW_MAJOR if is_1d_systolic else ttnn.ShardOrientation.COL_MAJOR - return ParallelConfig(grid_size[1], grid_size[0], num_cores_nhw, shard_scheme, shard_orientation) - - -# internal helper function. not exposed to user. -def get_grid_size_and_num_cores_nhw_from_core_grid(core_grid, height_sharded): - if isinstance(core_grid, ttnn.CoreGrid): - if height_sharded: - num_cores_nhw = core_grid.x * core_grid.y - else: - num_cores_nhw = core_grid.x - grid_size = core_grid - elif isinstance(core_grid, (list, tuple)): - if len(core_grid) != 2: - raise RuntimeError("Invalid core_grid") - if not isinstance(core_grid[0], ttnn.CoreGrid): - raise RuntimeError("Invalid core_grid type") - if not isinstance(core_grid[1], ttnn.CoreGrid): - raise RuntimeError("Invalid core_grid type") - assert height_sharded - num_cores_nhw = (core_grid[0].x * core_grid[0].y) + core_grid[1].x - elif isinstance(core_grid, ttnn.CoreRangeSet): - grid_size = core_grid.bounding_box().grid_size() - num_cores = core_grid.num_cores() - if height_sharded: - num_cores_nhw = num_cores - else: - num_cores_nhw = grid_size.x - else: - raise RuntimeError("Invalid core_grid type") - return grid_size, num_cores_nhw - - -# internal helper function. not exposed to user. -def create_sharded_memory_config_from_parallel_config(tensor_shape, parallel_config, tile_size): - logger.debug( - f"py create_sharded_memory_config_from_parallel_config: {tensor_shape}, {parallel_config.num_cores_nhw} {parallel_config.grid_size}, {tile_size}" - ) - # tensor_shape is [N, H, W, C] - assert len(tensor_shape) == 4 - assert tensor_shape[0] == 1 and tensor_shape[1] == 1 # todo: add support for generic non-2d shapes - channels = tensor_shape[3] - channels_padded = roundup(channels, tile_size) - num_cores_nhw = parallel_config.num_cores_nhw - num_cores_x = parallel_config.grid_size.x - num_cores_y = parallel_config.grid_size.y - shard_scheme = parallel_config.shard_scheme - shard_orientation = parallel_config.shard_orientation - is_1d_systolic = shard_scheme == ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if is_1d_systolic: - logical_grid_size = (num_cores_nhw, 1) - else: - logical_grid_size = (num_cores_x, num_cores_y) - - shard_grid, shard_layout = calculate_shard_grid((num_cores_x, num_cores_y), num_cores_nhw) - assert shard_layout == shard_scheme - nhw_shape = tensor_shape[0] * tensor_shape[1] * tensor_shape[2] - nhw_padded = roundup(nhw_shape, num_cores_nhw * tile_size) - nhw_shard = nhw_padded // num_cores_nhw - assert channels_padded % logical_grid_size[1] == 0 - shard_shape = [nhw_shard, channels_padded // logical_grid_size[1]] - shard_halo = False - shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, shard_orientation, shard_halo) - return ttnn.MemoryConfig(shard_scheme, ttnn.BufferType.L1, shard_spec) - - @ttnn.register_python_operation(name="ttnn.conv2d") def conv2d( *,