diff --git a/tests/ttnn/unit_tests/operations/test_tilizer.py b/tests/ttnn/unit_tests/operations/test_tilizer.py
new file mode 100644
index 00000000000..ef60c6b1236
--- /dev/null
+++ b/tests/ttnn/unit_tests/operations/test_tilizer.py
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import ttnn
+from loguru import logger
+
+
+def test_device_tilize(device):
+    """Benchmark the host tilizer vs. the device tilizer for converting a torch tensor to a tilized tensor."""
+    import time
+
+    torch_tensor = torch.randn((4544, 18176), dtype=torch.bfloat16)
+    output_dtype = ttnn.bfloat8_b
+
+    start = time.time()
+    tensor = ttnn.from_torch(torch_tensor, dtype=output_dtype, layout=ttnn.TILE_LAYOUT)
+    end = time.time()
+    logger.info(f"Time taken to convert to tensor using host-tilizer: {end - start} seconds")
+
+    start = time.time()
+    tensor = ttnn.from_torch(
+        torch_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
+    )
+    tensor = ttnn.to_layout(tensor, ttnn.TILE_LAYOUT, dtype=output_dtype, device=device)
+    end = time.time()
+    logger.info(f"Time taken to convert to tensor using device-tilizer: {end - start} seconds")
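The test above relies on pytest's `device` fixture. For reference, the same host-vs-device comparison can be reproduced as a standalone script. This is a minimal sketch, assuming a single-device setup where `ttnn.open_device(device_id=0)` is available; it uses `time.perf_counter` (a monotonic clock) rather than `time.time`:

```python
import time

import torch
import ttnn

device = ttnn.open_device(device_id=0)  # assumes device 0 is present

torch_tensor = torch.randn((4544, 18176), dtype=torch.bfloat16)

# Host tilizer: tilize and convert to bfloat8_b entirely on the host CPU.
start = time.perf_counter()
tensor = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT)
print(f"host tilizer: {time.perf_counter() - start:.4f} s")

# Device tilizer: move the row-major bfloat16 tensor to DRAM, then tilize on device.
start = time.perf_counter()
tensor = ttnn.from_torch(
    torch_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
)
tensor = ttnn.to_layout(tensor, ttnn.TILE_LAYOUT, dtype=ttnn.bfloat8_b, device=device)
print(f"device tilizer: {time.perf_counter() - start:.4f} s")

ttnn.close_device(device)
```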
diff --git a/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.cpp b/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.cpp
index 2a34d8609a5..a23ed9f8f05 100644
--- a/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.cpp
+++ b/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.cpp
@@ -16,6 +16,24 @@ namespace operations {
 namespace core {
 
 namespace detail {
+
+// Issue #8617: Limitations on tensor width for multicore device tilize
+inline bool use_multicore_device_tilize(
+    const Tensor& input, const std::optional<ttnn::DataType>& output_dtype) {
+    tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype());
+    uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format);
+
+    uint32_t output_single_tile_size =
+        output_dtype.has_value()
+            ? tt::tt_metal::detail::TileSize(tt::tt_metal::datatype_to_dataformat_converter(output_dtype.value()))
+            : input_single_tile_size;
+
+    uint32_t num_tiles_in_row = input.get_shape()[-1] / TILE_WIDTH;
+    uint32_t max_l1_size = input.device()->l1_size_per_core() / 2 - L1_UNRESERVED_BASE;
+    uint32_t max_tiles = max_l1_size / (input_single_tile_size + output_single_tile_size);  // 2 CBs: input + output
+
+    return num_tiles_in_row <= max_tiles;
+}
 
 template <typename T>
 Tensor execute(
     const ttnn::Tensor& tensor_arg,
@@ -85,12 +103,13 @@ Tensor execute(
         memory_config.value_or(ttnn::get_memory_config(tensor).value_or(ttnn::DRAM_MEMORY_CONFIG));
 
     if (ttnn::is_tensor_on_device_or_multidevice(tensor_arg)) {
-        bool use_multicore = true;
+        bool use_multicore_untilize = true;
+        bool use_multicore_tilize = use_multicore_device_tilize(tensor, dtype);
         if (not requires_padding_change(layout, tensor.get_shape())) {
             if (layout == ttnn::ROW_MAJOR_LAYOUT) {
                 TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!");
-                return tt::tt_metal::untilize(tensor, output_memory_config, use_multicore);
+                return tt::tt_metal::untilize(tensor, output_memory_config, use_multicore_untilize);
             } else if (layout == ttnn::TILE_LAYOUT) {
                 if (tensor.is_sharded()) {
                     const auto shard_shape = get_memory_config(tensor).value().shard_spec.value().shape;
@@ -100,7 +119,7 @@ Tensor execute(
                         "TILE_SIZE!");
                 }
             }
-            return tt::tt_metal::tilize(tensor, output_memory_config, dtype, use_multicore);
+            return tt::tt_metal::tilize(tensor, output_memory_config, dtype, use_multicore_tilize);
         } else {
             throw runtime_error("ttnn::to_layout: Unsupported layout!");
         }
@@ -130,7 +149,7 @@ Tensor execute(
             padded_4D_output_shape.push_back(ttnn::pad_to_multiple_of_tile_size(tensor.get_shape()[-2]));
             padded_4D_output_shape.push_back(ttnn::pad_to_multiple_of_tile_size(tensor.get_shape()[-1]));
             tensor = tt::tt_metal::tilize_with_val_padding(
-                tensor, padded_4D_output_shape, 0, output_memory_config, dtype, use_multicore);
+                tensor, padded_4D_output_shape, 0, output_memory_config, dtype, use_multicore_tilize);
 
             return reshape(tensor, ttnn::Shape(tt::tt_metal::Shape{output_shape, padded_output_shape}));
 
         } else {
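For intuition, `use_multicore_device_tilize` checks whether one row of input tiles plus one row of output tiles fits in half of a core's usable L1 (the input and output circular buffers share that budget). Below is a minimal Python sketch of the same arithmetic with illustrative numbers; a 32×32 bfloat16 tile is 2048 bytes, while the bfloat8_b tile size and the L1 figures are assumptions for the example, not values queried from hardware:

```python
TILE_WIDTH = 32

def fits_multicore_tilize(row_width, in_tile_bytes, out_tile_bytes, l1_per_core, l1_unreserved_base):
    """Mirror of use_multicore_device_tilize: a full row of input tiles plus a full
    row of output tiles must fit within half of a core's L1, minus the reserved region."""
    num_tiles_in_row = row_width // TILE_WIDTH
    max_l1 = l1_per_core // 2 - l1_unreserved_base
    max_tiles = max_l1 // (in_tile_bytes + out_tile_bytes)  # budget shared by the 2 CBs
    return num_tiles_in_row <= max_tiles

# Illustrative numbers: 2048-byte bfloat16 tiles, 1088-byte bfloat8_b tiles,
# 1 MiB of L1 per core, 100 KiB reserved. The 18176-wide row from the test is
# 568 tiles, which exceeds the ~134-tile budget under these assumptions, so
# the op would fall back to single-core tilize.
print(fits_multicore_tilize(18176, 2048, 1088, 1024 * 1024, 100 * 1024))  # False
```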
diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py
index 6c9ff6e79bd..1b8bae19c75 100644
--- a/ttnn/ttnn/operations/core.py
+++ b/ttnn/ttnn/operations/core.py
@@ -596,6 +596,7 @@ def as_tensor(
     cache_file_name: Optional[Union[str, pathlib.Path]] = None,
     preprocess: Optional[Callable[[ttnn.Tensor], ttnn.Tensor]] = None,
     mesh_mapper: Optional[ttnn.TensorToMesh] = None,
+    use_device_tilizer: bool = False,
 ) -> ttnn.Tensor:
     """
     as_tensor(tensor: Union[torch.Tensor], dtype: Optional[ttnn.DataType] = None, layout: Optional[ttnn.Layout] = ROW_MAJOR_LAYOUT, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None, cache_file_name: Optional[str | pathlib.Path] = None) -> ttnn.Tensor
@@ -611,6 +612,7 @@ def as_tensor(
         * :attr:`cache_file_name`: the optional cache file name.
         * :attr:`preprocess`: the optional function to preprocess the tensor before serializing/converting to ttnn.
         * :attr:`mesh_mapper`: the optional TensorToMesh to define the mapping from torch to multi-device.
+        * :attr:`use_device_tilizer`: the optional flag to use the device tilizer instead of the host tilizer.
 
     Example::
 
@@ -636,7 +638,19 @@ def as_tensor(
     def from_torch_and_dump(tensor, dtype, layout, cache_file_name):
         if preprocess:
             tensor = preprocess(tensor)
-        tensor = ttnn.from_torch(tensor, dtype=dtype, layout=layout, mesh_mapper=mesh_mapper)
+        if use_device_tilizer and device and layout == ttnn.TILE_LAYOUT:
+            # To use the device tilizer, first move the tensor to the device:
+            # the on-device tilizer operates on a bfloat16 tensor that is already on the device.
+            tensor = ttnn.from_torch(
+                tensor,
+                layout=ttnn.ROW_MAJOR_LAYOUT,
+                mesh_mapper=mesh_mapper,
+                device=device,
+                memory_config=ttnn.DRAM_MEMORY_CONFIG,
+            )
+            tensor = ttnn.to_layout(tensor, layout, dtype=dtype, memory_config=memory_config, device=device)
+        else:
+            tensor = ttnn.from_torch(tensor, dtype=dtype, layout=layout, mesh_mapper=mesh_mapper)
         logger.debug(
             f"Generating cache for {cache_file_name} of shape {tensor.shape}, dtype {dtype_name}, layout {layout_name}"
         )
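Note that the new flag is consulted inside `from_torch_and_dump`, i.e. on the path that generates the weight cache, and only when a device is given and the target layout is `TILE_LAYOUT`. A hypothetical usage sketch follows; the cache path and tensor shape are placeholders, not values from this change:

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_weight = torch.randn((4544, 18176), dtype=torch.bfloat16)

weight = ttnn.as_tensor(
    torch_weight,
    dtype=ttnn.bfloat8_b,
    layout=ttnn.TILE_LAYOUT,  # use_device_tilizer only applies to TILE_LAYOUT...
    device=device,            # ...and only when a device is provided
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
    cache_file_name="model_cache/dense_weight",  # placeholder path
    use_device_tilizer=True,
)

ttnn.close_device(device)
```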