From 52e9593fd43fe0368f836b88b28d5f5a9bf3ed4c Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Thu, 30 Nov 2023 05:46:17 +0000 Subject: [PATCH] #4059: Add PADDED, BINARY modes to embeddings for optimized cache lookups on repeated tokens Update RM embeddings to handle new input format and embed types. Remove split_weights arg --- .../metal_BERT_large_11/tt/embeddings.py | 51 ++- .../unit_testing/test_embedding.py | 18 +- ...ick_layout_blocks_interleaved_start_id.cpp | 2 +- ...ick_layout_blocks_interleaved_start_id.cpp | 2 +- ...nary_stick_layout_interleaved_start_id.cpp | 2 +- .../op_library/embeddings/embeddings_op.cpp | 382 ++++++++++-------- .../op_library/embeddings/embeddings_op.hpp | 50 +-- .../kernels/dataflow/embeddings.cpp | 197 +++++---- .../kernels/dataflow/embeddings_tilize.cpp | 64 ++- ...ut_sharded_blocks_interleaved_start_id.cpp | 2 +- .../multi_core/sharded_op_multi_core.cpp | 4 +- ...nary_stick_layout_interleaved_start_id.cpp | 2 +- .../tt_lib/csrc/tt_lib_bindings_tensor.cpp | 8 +- ttnn/core.py | 5 +- 14 files changed, 481 insertions(+), 308 deletions(-) diff --git a/models/demos/metal_BERT_large_11/tt/embeddings.py b/models/demos/metal_BERT_large_11/tt/embeddings.py index 74b7b02ea70..c945ecb83e2 100644 --- a/models/demos/metal_BERT_large_11/tt/embeddings.py +++ b/models/demos/metal_BERT_large_11/tt/embeddings.py @@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Union import torch import tt_lib as ttl -from tt_lib.utils import pad_weight +from models.utility_functions import torch2tt_tensor class TtEmbeddings: @@ -15,6 +15,7 @@ def __init__(self, hugging_face_reference_model, device, model_config, tt_cache_ config = hugging_face_reference_model.config state_dict = hugging_face_reference_model.state_dict() self.embedding_dim = config.hidden_size + self.pad_token = config.pad_token_id base_address = "bert.embeddings" if tt_cache_path is not None: @@ -49,30 +50,45 @@ def __init__(self, hugging_face_reference_model, device, model_config, tt_cache_ ) ).to(device, self.model_config["EMBEDDINGS_LAYERNORM_BETA_MEMCFG"]) else: - self.word_embeddings_weight = ttl.tensor.Tensor( - pad_weight(state_dict[f"{base_address}.word_embeddings.weight"]), + self.word_embeddings_weight = torch2tt_tensor( + state_dict[f"{base_address}.word_embeddings.weight"], + device, + ttl.tensor.Layout.ROW_MAJOR, + model_config["INPUT_EMBEDDINGS_WEIGHTS_MEMCFG"], model_config["INPUT_EMBEDDINGS_WEIGHTS_DTYPE"], - ).to(device, model_config["INPUT_EMBEDDINGS_WEIGHTS_MEMCFG"]) + ) - self.position_embeddings_weight = ttl.tensor.Tensor( - pad_weight(state_dict[f"{base_address}.position_embeddings.weight"]), + self.position_embeddings_weight = torch2tt_tensor( + state_dict[f"{base_address}.position_embeddings.weight"], + device, + ttl.tensor.Layout.ROW_MAJOR, + model_config["INPUT_EMBEDDINGS_WEIGHTS_MEMCFG"], model_config["INPUT_EMBEDDINGS_WEIGHTS_DTYPE"], - ).to(device, model_config["INPUT_EMBEDDINGS_WEIGHTS_MEMCFG"]) + ) - self.token_type_embeddings_weight = ttl.tensor.Tensor( - pad_weight(state_dict[f"{base_address}.token_type_embeddings.weight"]), + self.token_type_embeddings_weight = torch2tt_tensor( + state_dict[f"{base_address}.token_type_embeddings.weight"], + device, + ttl.tensor.Layout.ROW_MAJOR, + model_config["INPUT_EMBEDDINGS_WEIGHTS_MEMCFG"], model_config["INPUT_EMBEDDINGS_WEIGHTS_DTYPE"], - ).to(device, model_config["INPUT_EMBEDDINGS_WEIGHTS_MEMCFG"]) + ) - self.layerNorm_gamma = ttl.tensor.Tensor( + self.layerNorm_gamma = torch2tt_tensor( state_dict[f"{base_address}.LayerNorm.weight"].reshape([1, 1, -1, 32]), + device, + ttl.tensor.Layout.ROW_MAJOR, + model_config["EMBEDDINGS_LAYERNORM_GAMMA_MEMCFG"], model_config["EMBEDDINGS_LAYERNORM_GAMMA_DTYPE"], - ).to(device, model_config["EMBEDDINGS_LAYERNORM_GAMMA_MEMCFG"]) + ) - self.layerNorm_beta = ttl.tensor.Tensor( + self.layerNorm_beta = torch2tt_tensor( state_dict[f"{base_address}.LayerNorm.bias"].reshape([1, 1, -1, 32]), + device, + ttl.tensor.Layout.ROW_MAJOR, + model_config["EMBEDDINGS_LAYERNORM_BETA_MEMCFG"], model_config["EMBEDDINGS_LAYERNORM_BETA_DTYPE"], - ).to(device, model_config["EMBEDDINGS_LAYERNORM_BETA_MEMCFG"]) + ) self.layerNorm_eps = config.layer_norm_eps @@ -120,8 +136,9 @@ def __call__( inputs_embeds = ttl.tensor.embeddings( input_ids, self.word_embeddings_weight, - split_weights=False, tilized=True, + embeddings_type=ttl.tensor.EmbeddingsType.PADDED, + pad_token=self.pad_token, output_mem_config=self.model_config["OUTPUT_EMBEDDINGS_MEMCFG"], ) input_ids.deallocate() @@ -129,8 +146,8 @@ def __call__( token_type_embeddings = ttl.tensor.embeddings( token_type_ids, self.token_type_embeddings_weight, - split_weights=False, tilized=True, + embeddings_type=ttl.tensor.EmbeddingsType.BINARY, output_mem_config=self.model_config["OUTPUT_EMBEDDINGS_MEMCFG"], ) token_type_ids.deallocate() @@ -146,8 +163,8 @@ def __call__( position_embeddings_tt_tensor = ttl.tensor.embeddings( position_ids, self.position_embeddings_weight, - split_weights=False, tilized=True, + embeddings_type=ttl.tensor.EmbeddingsType.GENERIC, output_mem_config=self.model_config["OUTPUT_EMBEDDINGS_MEMCFG"], ) # Deallocate inputs_embeds and token_type_embeddings here to avoid having to move final output diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_embedding.py b/tests/tt_eager/python_api_testing/unit_testing/test_embedding.py index 78a9b89b9e3..7e170b47660 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/test_embedding.py +++ b/tests/tt_eager/python_api_testing/unit_testing/test_embedding.py @@ -2,18 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 -import math -from pathlib import Path -import sys -import time -import os +import pytest import torch import tt_lib as ttl -from tt_lib.utils import is_close -from models.utility_functions import is_wormhole_b0, skip_for_wormhole_b0 +from models.utility_functions import skip_for_wormhole_b0 from loguru import logger from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_equal, @@ -21,7 +16,7 @@ def run_embeddings_tests( - batch_size, num_embeddings, embedding_dim, num_rows, dtype, in0_mem_config, out_mem_config, device, fused=False + batch_size, num_embeddings, embedding_dim, num_rows, dtype, in0_mem_config, out_mem_config, device, tilized=False ): torch.manual_seed(1234) @@ -38,9 +33,9 @@ def run_embeddings_tests( input_tensor = tensor.Tensor(input_rows_torch, ttl.tensor.DataType.UINT32).to(dev, in0_mem_config) weights_tensor = tensor.Tensor(weights_torch, dtype).to(dev, in0_mem_config) - ttz = tensor.embeddings(input_tensor, weights_tensor, False, fused, out_mem_config) + ttz = tensor.embeddings(input_tensor, weights_tensor, tilized, output_mem_config=out_mem_config) - if fused: + if tilized: tt_data = ttz.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() else: tt_data = ttz.cpu().to_torch() @@ -57,9 +52,6 @@ def run_embeddings_tests( assert passing_pcc -import pytest - - @skip_for_wormhole_b0() @pytest.mark.parametrize( "out_mem_config", diff --git a/tt_eager/tt_dnn/kernels/dataflow/reader_unary_stick_layout_blocks_interleaved_start_id.cpp b/tt_eager/tt_dnn/kernels/dataflow/reader_unary_stick_layout_blocks_interleaved_start_id.cpp index 7a6a9aa2f3f..d1ff3a6409e 100644 --- a/tt_eager/tt_dnn/kernels/dataflow/reader_unary_stick_layout_blocks_interleaved_start_id.cpp +++ b/tt_eager/tt_dnn/kernels/dataflow/reader_unary_stick_layout_blocks_interleaved_start_id.cpp @@ -13,7 +13,7 @@ void kernel_main() { uint32_t num_sticks_per_block = get_arg_val(3); uint32_t start_id = get_arg_val(4); - constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0) == 1; + constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0); constexpr bool src0_is_dram = get_compile_time_arg_val(1) == 1; #define src_stick_size_is_pow2 get_compile_time_arg_val(2) == 1 diff --git a/tt_eager/tt_dnn/kernels/dataflow/writer_unary_stick_layout_blocks_interleaved_start_id.cpp b/tt_eager/tt_dnn/kernels/dataflow/writer_unary_stick_layout_blocks_interleaved_start_id.cpp index 2e57229946e..0ee49d0ef49 100644 --- a/tt_eager/tt_dnn/kernels/dataflow/writer_unary_stick_layout_blocks_interleaved_start_id.cpp +++ b/tt_eager/tt_dnn/kernels/dataflow/writer_unary_stick_layout_blocks_interleaved_start_id.cpp @@ -13,7 +13,7 @@ void kernel_main() { uint32_t num_sticks_per_block = get_arg_val(3); uint32_t start_id = get_arg_val(4); - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0) == 1; + constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0); constexpr bool dst0_is_dram = get_compile_time_arg_val(1) == 1; #define dst_stick_size_is_pow2 get_compile_time_arg_val(2) == 1 diff --git a/tt_eager/tt_dnn/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp b/tt_eager/tt_dnn/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp index 623d3fb2777..a2478ef66f1 100644 --- a/tt_eager/tt_dnn/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp +++ b/tt_eager/tt_dnn/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp @@ -13,7 +13,7 @@ void kernel_main() { uint32_t num_sticks = get_arg_val(2); uint32_t start_id = get_arg_val(3); - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0) == 1; + constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0); constexpr bool dst0_is_dram = get_compile_time_arg_val(1) == 1; #define dst_stick_size_is_pow2 get_compile_time_arg_val(2) == 1 diff --git a/tt_eager/tt_dnn/op_library/embeddings/embeddings_op.cpp b/tt_eager/tt_dnn/op_library/embeddings/embeddings_op.cpp index 9b871393e3e..1d51d51a311 100644 --- a/tt_eager/tt_dnn/op_library/embeddings/embeddings_op.cpp +++ b/tt_eager/tt_dnn/op_library/embeddings/embeddings_op.cpp @@ -21,7 +21,11 @@ namespace tt { namespace tt_metal { operation::ProgramWithCallbacks embeddings_tilized( - const Tensor &a, const Tensor &weights, Tensor &output, bool split_weights) { + const Tensor &a, + const Tensor &weights, + Tensor &output, + EmbeddingsType embeddings_type, + std::optional pad_token) { //////////////////////////////////////////////////////////////////////////// // Buffer Setup //////////////////////////////////////////////////////////////////////////// @@ -68,12 +72,7 @@ operation::ProgramWithCallbacks embeddings_tilized( uint32_t start_core_x = 0; uint32_t start_core_y = 0; - uint32_t problem_size; - if (split_weights) { - problem_size = num_embeddings; - } else { - problem_size = num_blocks; - } + uint32_t problem_size = num_blocks; auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); uint32_t num_cores_x = compute_with_storage_grid_size.x; @@ -92,10 +91,12 @@ operation::ProgramWithCallbacks embeddings_tilized( tt::DataFormat output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.dtype()); uint32_t output_single_tile_size = tt_metal::detail::TileSize(output_cb_data_format); + uint32_t buffering = weights.shape()[-1] > 2048 ? 1 : 2; + uint32_t src0_cb_index = 0; tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig( - 2 * num_tiles_per_block * weights_single_tile_size, {{src0_cb_index, weights_cb_data_format}}) + buffering * num_tiles_per_block * weights_single_tile_size, {{src0_cb_index, weights_cb_data_format}}) .set_page_size(src0_cb_index, weights_single_tile_size); auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); @@ -105,10 +106,26 @@ operation::ProgramWithCallbacks embeddings_tilized( .set_page_size(src1_cb_index, TILE_HEIGHT * input_element_size_bytes); auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src1_config); + if (embeddings_type == EmbeddingsType::PADDED) { + uint32_t src2_cb_index = 2; + uint32_t cache_page_size = round_up_to_mul32(weight_page_size); + tt_metal::CircularBufferConfig cb_src2_config = + tt_metal::CircularBufferConfig(cache_page_size, {{src2_cb_index, weights_cb_data_format}}) + .set_page_size(src2_cb_index, cache_page_size); + auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src2_config); + } else if (embeddings_type == EmbeddingsType::BINARY) { + uint32_t src2_cb_index = 2; + uint32_t cache_page_size = round_up_to_mul32(weight_page_size); + tt_metal::CircularBufferConfig cb_src2_config = + tt_metal::CircularBufferConfig(2 * cache_page_size, {{src2_cb_index, weights_cb_data_format}}) + .set_page_size(src2_cb_index, cache_page_size); + auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src2_config); + } + uint32_t output_cb_index = 16; // output operands start at index 16 tt_metal::CircularBufferConfig cb_output_config = tt_metal::CircularBufferConfig( - 2 * num_tiles_per_block * output_single_tile_size, {{output_cb_index, output_cb_data_format}}) + buffering * num_tiles_per_block * output_single_tile_size, {{output_cb_index, output_cb_data_format}}) .set_page_size(output_cb_index, output_single_tile_size); auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config); @@ -131,12 +148,17 @@ operation::ProgramWithCallbacks embeddings_tilized( (std::uint32_t)num_tiles_per_block, (std::uint32_t)TILE_HEIGHT * input_element_size_bytes}; + std::map embedding_defines = {{magic_enum::enum_name(embeddings_type).data(), "1"}}; + auto reader_kernel_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/embeddings/kernels/dataflow/embeddings_tilize.cpp", all_cores, tt_metal::DataMovementConfig{ - .processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default, .compile_args = embedding_compile_time_args}); + .processor = DataMovementProcessor::RISCV_1, + .noc = NOC::RISCV_1_default, + .compile_args = embedding_compile_time_args, + .defines = embedding_defines}); if (num_blocks_per_core_group_1 > 0) { vector compute_args_1 = { @@ -170,55 +192,52 @@ operation::ProgramWithCallbacks embeddings_tilized( "tt_eager/tt_dnn/kernels/dataflow/writer_unary_interleaved_start_id.cpp", all_cores, tt_metal::DataMovementConfig{ - .processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default, .compile_args = writer_compile_time_args}); + .processor = DataMovementProcessor::RISCV_0, + .noc = NOC::RISCV_0_default, + .compile_args = writer_compile_time_args}); uint32_t input_offset = 0; uint32_t weight_offset = 0; uint32_t tile_offset = 0; auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, false); + std::vector reader_runtime_args = { + (std::uint32_t)a.buffer()->address(), + (std::uint32_t)weights.buffer()->address(), + (std::uint32_t)0, + (std::uint32_t)0, + (std::uint32_t)0, + }; + if (embeddings_type == EmbeddingsType::PADDED) { + reader_runtime_args.push_back(pad_token.value()); + } + + std::vector writer_runtime_args = { + (std::uint32_t)output.buffer()->address(), (std::uint32_t)0, (std::uint32_t)0}; + for (uint32_t i = 0; i < cores.size(); ++i) { const CoreCoord &core = cores[i]; - uint32_t local_num_blocks; uint32_t local_input_offset = input_offset; - uint32_t local_problem_size = num_blocks_per_core_group_1; - if (i >= g1_numcores) { - local_problem_size = num_blocks_per_core_group_2; - } - if (split_weights) { - local_input_offset = input_offset; - local_num_blocks = num_blocks; - - } else { - local_input_offset = input_offset; - local_num_blocks = local_problem_size; - } + uint32_t local_num_blocks = i < g1_numcores ? num_blocks_per_core_group_1 : num_blocks_per_core_group_2; // Reader { - std::vector runtime_args = { - (std::uint32_t)input_offset / num_blocks_per_batch, - (std::uint32_t)input_offset % num_blocks_per_batch * TILE_HEIGHT * input_element_size_bytes, - (std::uint32_t)(local_num_blocks), - (std::uint32_t)a.buffer()->address(), - (std::uint32_t)weights.buffer()->address()}; - tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); + reader_runtime_args[2] = input_offset / num_blocks_per_batch; + reader_runtime_args[3] = input_offset % num_blocks_per_batch * TILE_HEIGHT * input_element_size_bytes; + reader_runtime_args[4] = local_num_blocks; + tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); } // Writer { - std::vector runtime_args = { - output.buffer()->address(), (uint32_t)num_tiles_per_block * local_num_blocks, tile_offset}; + writer_runtime_args[1] = num_tiles_per_block * local_num_blocks; + writer_runtime_args[2] = tile_offset; tile_offset += local_num_blocks * num_tiles_per_block; - tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); + tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, writer_runtime_args); } - if (split_weights) { - weight_offset += local_problem_size; - } else { - input_offset += local_problem_size; - } + input_offset += local_num_blocks; } auto override_runtime_args_callback = [num_cores_x, num_cores_y, reader_kernel_id, writer_kernel_id, cores, device]( @@ -232,8 +251,8 @@ operation::ProgramWithCallbacks embeddings_tilized( for (const auto &core : cores) { { auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[3] = input_dram_buffer->address(); - runtime_args[4] = weights_dram_buffer->address(); + runtime_args[0] = input_dram_buffer->address(); + runtime_args[1] = weights_dram_buffer->address(); } { @@ -247,7 +266,11 @@ operation::ProgramWithCallbacks embeddings_tilized( } operation::ProgramWithCallbacks embeddings_rm( - const Tensor &a, const Tensor &weights, Tensor &output, bool split_weights) { + const Tensor &a, + const Tensor &weights, + Tensor &output, + EmbeddingsType embeddings_type, + std::optional pad_token) { //////////////////////////////////////////////////////////////////////////// // Buffer Setup //////////////////////////////////////////////////////////////////////////// @@ -268,40 +291,37 @@ operation::ProgramWithCallbacks embeddings_rm( //////////////////////////////////////////////////////////////////////////// Program program{}; - uint32_t cb_id = 0; - uint32_t num_tiles_per_cb = 1; bool in0_is_dram = a.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; bool weights_is_dram = weights.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool weights_dtype_is_bfloat16 = weights.dtype() == tt::tt_metal::DataType::BFLOAT16; bool out_is_dram = output.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - uint32_t last_dim = 3; - uint32_t element_size_in_bytes = weights.element_size(); + uint32_t input_element_size_bytes = a.element_size(); + uint32_t weights_element_size_bytes = weights.element_size(); + uint32_t output_element_size_bytes = output.element_size(); // row major, page size is last dim - uint32_t single_page_size = weights.shape()[last_dim] * element_size_in_bytes; + uint32_t input_page_size = a.shape()[-1] * input_element_size_bytes; + uint32_t weight_page_size = weights.shape()[-1] * weights_element_size_bytes; + uint32_t output_page_size = output.shape()[-1] * output_element_size_bytes; // weights shape is [1, 1, num_embeddings, num_dim] - uint32_t num_embeddings = weights.shape()[last_dim - 1]; + uint32_t num_embeddings = weights.shape()[-2]; uint32_t batch_size = a.shape()[0]; - uint32_t num_output_rows_per_batch = a.shape()[last_dim - 1]; + uint32_t num_output_rows_per_batch = a.shape()[-1]; uint32_t num_output_rows = num_output_rows_per_batch * batch_size; + constexpr uint32_t alignment = 32; + uint32_t block_height = (alignment / input_element_size_bytes); + uint32_t num_blocks = num_output_rows; + uint32_t num_blocks_per_batch = num_output_rows_per_batch; + + auto num_embedding_dims = weights.shape()[-1]; // setup problem and grid size uint32_t start_core_x = 0; uint32_t start_core_y = 0; - uint32_t problem_size; - if (split_weights) { - problem_size = num_embeddings; - } else { - problem_size = num_output_rows; - } - - // if tilized, then we will use one risc core per tensix for data movement of embedding and the other to read out - // from the tilized kernel else both risc cores will be used for lookup of the embedding table - uint32_t embedding_risc_cores_per_tensix = RISC_CORES_PER_TENSIX; + uint32_t problem_size = num_blocks; auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); uint32_t num_cores_x = compute_with_storage_grid_size.x; @@ -311,121 +331,154 @@ operation::ProgramWithCallbacks embeddings_rm( uint32_t g1_numcores = core_group_1.num_cores(); uint32_t g2_numcores = core_group_2.num_cores(); - // Create Kernels + // Create Buffers + tt::DataFormat input_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.dtype()); + + tt::DataFormat weights_cb_data_format = tt_metal::datatype_to_dataformat_converter(weights.dtype()); + tt::DataFormat output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.dtype()); - DataFormat weights_df = tt_metal::datatype_to_dataformat_converter(weights.dtype()); uint32_t src0_cb_index = 0; + uint32_t rounded_weight_page_size = round_up_to_mul32(weight_page_size); tt_metal::CircularBufferConfig cb_src0_config = - tt_metal::CircularBufferConfig(2 * single_page_size, {{src0_cb_index, weights_df}}) - .set_page_size(src0_cb_index, single_page_size); + tt_metal::CircularBufferConfig(2 * rounded_weight_page_size, {{src0_cb_index, weights_cb_data_format}}) + .set_page_size(src0_cb_index, rounded_weight_page_size); auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); + uint32_t src1_cb_index = 1; + uint32_t index_page_size = round_up_to_mul32(input_element_size_bytes); tt_metal::CircularBufferConfig cb_src1_config = - tt_metal::CircularBufferConfig(2 * single_page_size, {{src1_cb_index, weights_df}}) - .set_page_size(src1_cb_index, single_page_size); + tt_metal::CircularBufferConfig(block_height * index_page_size, {{src1_cb_index, input_cb_data_format}}) + .set_page_size(src1_cb_index, block_height * index_page_size); auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src1_config); - uint32_t src2_cb_index = 2; - tt::DataFormat index_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.dtype()); - uint32_t index_size = a.element_size(); - tt_metal::CircularBufferConfig cb_src2_config = - tt_metal::CircularBufferConfig(round_up_to_mul32(index_size), {{src2_cb_index, index_cb_data_format}}) - .set_page_size(src2_cb_index, index_size); - auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src2_config); - - uint32_t src3_cb_index = 3; - tt_metal::CircularBufferConfig cb_src3_config = - tt_metal::CircularBufferConfig(round_up_to_mul32(index_size), {{src3_cb_index, index_cb_data_format}}) - .set_page_size(src3_cb_index, index_size); - auto cb_src3 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src3_config); - - std::vector> compile_time_args(embedding_risc_cores_per_tensix); - std::vector risc_procs = { - tt_metal::DataMovementProcessor::RISCV_0, tt_metal::DataMovementProcessor::RISCV_1}; - std::vector noc_ports = {tt_metal::NOC::RISCV_0_default, tt_metal::NOC::RISCV_1_default}; - - std::vector kernIds(RISC_CORES_PER_TENSIX); - - for (int risc_id = 0; risc_id < embedding_risc_cores_per_tensix; risc_id++) { - std::vector embedding_compile_time_args = { - (std::uint32_t)in0_is_dram, - (std::uint32_t)weights_is_dram, - (std::uint32_t)out_is_dram, - (std::uint32_t)single_page_size, - (std::uint32_t)risc_id, - (std::uint32_t)risc_id + 2}; - - kernIds[risc_id] = tt_metal::CreateKernel( - program, - "tt_eager/tt_dnn/op_library/embeddings/kernels/dataflow/embeddings.cpp", - all_cores, - tt_metal::DataMovementConfig{ - .processor = risc_procs[risc_id], - .noc = noc_ports[risc_id], - .compile_args = embedding_compile_time_args}); + if (embeddings_type == EmbeddingsType::PADDED) { + uint32_t src2_cb_index = 2; + uint32_t cache_page_size = round_up_to_mul32(weight_page_size); + tt_metal::CircularBufferConfig cb_src2_config = + tt_metal::CircularBufferConfig(cache_page_size, {{src2_cb_index, weights_cb_data_format}}) + .set_page_size(src2_cb_index, cache_page_size); + auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src2_config); + } else if (embeddings_type == EmbeddingsType::BINARY) { + uint32_t src2_cb_index = 2; + uint32_t cache_page_size = round_up_to_mul32(weight_page_size); + tt_metal::CircularBufferConfig cb_src2_config = + tt_metal::CircularBufferConfig(2 * cache_page_size, {{src2_cb_index, weights_cb_data_format}}) + .set_page_size(src2_cb_index, cache_page_size); + auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src2_config); } + uint32_t output_cb_index = src0_cb_index; + + bool input_stick_size_is_power_of_two = is_power_of_two_at_least_32(input_page_size); + uint32_t input_log2_stick_size = input_stick_size_is_power_of_two ? (std::uint32_t)log2(input_page_size) : 0; + bool weight_stick_size_is_power_of_two = is_power_of_two_at_least_32(weight_page_size); + uint32_t weight_log2_stick_size = weight_stick_size_is_power_of_two ? (std::uint32_t)log2(weight_page_size) : 0; + + // Create Kernels + // reader + std::vector embedding_compile_time_args = { + (std::uint32_t)in0_is_dram, + (std::uint32_t)input_stick_size_is_power_of_two, + (std::uint32_t)input_page_size, + (std::uint32_t)input_log2_stick_size, + (std::uint32_t)weights_is_dram, + (std::uint32_t)weight_stick_size_is_power_of_two, + (std::uint32_t)weight_page_size, + (std::uint32_t)weight_log2_stick_size, + (std::uint32_t)block_height, + (std::uint32_t)block_height * input_element_size_bytes}; + + std::map embedding_defines = {{magic_enum::enum_name(embeddings_type).data(), "1"}}; + + auto reader_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/embeddings/kernels/dataflow/embeddings.cpp", + all_cores, + tt_metal::DataMovementConfig{ + .processor = DataMovementProcessor::RISCV_1, + .noc = NOC::RISCV_1_default, + .compile_args = embedding_compile_time_args, + .defines = embedding_defines}); + + bool output_stick_size_is_power_of_two = is_power_of_two_at_least_32(output_page_size); + uint32_t output_log2_stick_size = output_stick_size_is_power_of_two ? (std::uint32_t)log2(output_page_size) : 0; + std::vector writer_compile_time_args = { + (std::uint32_t)output_cb_index, + (std::uint32_t)out_is_dram, + (std::uint32_t)output_stick_size_is_power_of_two, + (std::uint32_t)output_log2_stick_size}; + + // Tilized writer + auto writer_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp", + all_cores, + tt_metal::DataMovementConfig{ + .processor = DataMovementProcessor::RISCV_0, + .noc = NOC::RISCV_0_default, + .compile_args = writer_compile_time_args}); + uint32_t input_offset = 0; uint32_t weight_offset = 0; - uint32_t tile_offset = 0; - const auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, false); + auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, false); + std::vector reader_runtime_args = { + (std::uint32_t)a.buffer()->address(), + (std::uint32_t)weights.buffer()->address(), + (std::uint32_t)0, + (std::uint32_t)0, + (std::uint32_t)0, + (std::uint32_t)0, + }; + if (embeddings_type == EmbeddingsType::PADDED) { + reader_runtime_args.push_back(pad_token.value()); + } + std::vector writer_runtime_args = { + (std::uint32_t)output.buffer()->address(), (std::uint32_t)output_page_size, (std::uint32_t)0, (std::uint32_t)0}; + for (uint32_t i = 0; i < cores.size(); ++i) { const CoreCoord &core = cores[i]; - uint32_t local_num_output_rows; - for (uint32_t idc = 0; idc < embedding_risc_cores_per_tensix; idc++) { - uint32_t local_num_embeddings; - uint32_t local_weight_offset; - uint32_t local_input_offset = input_offset; - uint32_t local_problem_size = num_blocks_per_core_group_1; - bool core_0 = true; - if (i >= g1_numcores) { - local_problem_size = num_blocks_per_core_group_2; - core_0 = false; - } - if (split_weights) { - local_weight_offset = weight_offset; - weight_offset += local_problem_size; - local_input_offset = input_offset; - local_num_output_rows = num_output_rows; - local_num_embeddings = local_problem_size; - } else { - local_input_offset = input_offset; - input_offset += local_problem_size; - local_weight_offset = weight_offset; - local_num_embeddings = num_embeddings; - local_num_output_rows = local_problem_size; - } - std::vector runtime_args; - - runtime_args = { - (std::uint32_t)local_input_offset, - (std::uint32_t)local_weight_offset, - (std::uint32_t)local_num_embeddings, - (std::uint32_t)local_num_output_rows, - (std::uint32_t)a.buffer()->address(), - (std::uint32_t)weights.buffer()->address(), - (std::uint32_t)output.buffer()->address()}; - tt_metal::SetRuntimeArgs(program, kernIds[idc], core, runtime_args); + uint32_t local_num_blocks = i < g1_numcores ? num_blocks_per_core_group_1 : num_blocks_per_core_group_2; + + // Reader + { + reader_runtime_args[2] = input_offset / num_blocks_per_batch; + reader_runtime_args[3] = + round_down(input_offset % num_blocks_per_batch, block_height) * input_element_size_bytes; + reader_runtime_args[4] = local_num_blocks; + reader_runtime_args[5] = input_offset % num_blocks_per_batch % block_height; + tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); } + + // Writer + { + writer_runtime_args[2] = local_num_blocks; + writer_runtime_args[3] = input_offset; + tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, writer_runtime_args); + } + + input_offset += local_num_blocks; } - auto override_runtime_args_callback = [num_cores_x, num_cores_y, kernIds, embedding_risc_cores_per_tensix, cores]( + auto override_runtime_args_callback = [num_cores_x, num_cores_y, reader_kernel_id, writer_kernel_id, cores, device]( const Program &program, const std::vector &input_buffers, const std::vector &output_buffers) { + auto output_dram_buffer = output_buffers.at(0); + auto input_dram_buffer = input_buffers.at(0); + auto weights_dram_buffer = input_buffers.at(1); + for (const auto &core : cores) { - for (uint32_t idc = 0; idc < embedding_risc_cores_per_tensix; idc++) { - auto input_dram_buffer = input_buffers.at(0); - auto weights_dram_buffer = input_buffers.at(1); - auto output_dram_buffer = output_buffers.at(0); - { - auto &runtime_args = GetRuntimeArgs(program, kernIds[idc], core); - runtime_args[4] = input_dram_buffer->address(); - runtime_args[5] = weights_dram_buffer->address(); - runtime_args[6] = output_dram_buffer->address(); - } + { + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + runtime_args[0] = input_dram_buffer->address(); + runtime_args[1] = weights_dram_buffer->address(); + } + + { + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + runtime_args[0] = output_dram_buffer->address(); } } }; @@ -434,11 +487,16 @@ operation::ProgramWithCallbacks embeddings_rm( } operation::ProgramWithCallbacks embeddings_( - const Tensor &a, const Tensor &weights, Tensor &output, bool split_weights, bool tilized) { + const Tensor &a, + const Tensor &weights, + Tensor &output, + bool tilized, + EmbeddingsType embeddings_type, + std::optional pad_token) { if (tilized) { - return embeddings_tilized(a, weights, output, split_weights); + return embeddings_tilized(a, weights, output, embeddings_type, pad_token); } else { - return embeddings_rm(a, weights, output, split_weights); + return embeddings_rm(a, weights, output, embeddings_type, pad_token); } } @@ -453,12 +511,17 @@ void Embeddings::validate(const std::vector &input_tensors) const { TT_FATAL(weights.shape()[0] == 1 && weights.shape()[1] == 1, "First two dimensions for the weights must be 1"); if (this->tilized) { - TT_FATAL(a.shape()[3] % TILE_HEIGHT == 0); - TT_FATAL(weights.shape()[3] % TILE_WIDTH == 0, "Number of columns in table must be factor of tile width"); + TT_FATAL(a.shape()[-1] % TILE_HEIGHT == 0); + TT_FATAL(weights.shape()[-1] % TILE_WIDTH == 0, "Number of columns in table must be factor of tile width"); } else { TT_FATAL(this->output_dtype != DataType::BFLOAT8_B); } TT_FATAL(a.shape()[1] == 1 && a.shape()[2] == 1, "Only dim 0 && 3 for the input can be non 1"); + switch (this->embeddings_type) { + case EmbeddingsType::PADDED: TT_FATAL(this->pad_token.has_value()); break; + case EmbeddingsType::BINARY: TT_FATAL(weights.shape()[-2] == 2); + default: TT_FATAL(!this->pad_token.has_value()); + } } std::vector Embeddings::compute_output_shapes(const std::vector &input_tensors) const { @@ -488,14 +551,15 @@ operation::ProgramWithCallbacks Embeddings::create_program( const auto &a = input_tensors.at(0); const auto &weights = input_tensors.at(1); auto &output_tensor = output_tensors.at(0); - return embeddings_(a, weights, output_tensor, this->split_weights, this->tilized); + return embeddings_(a, weights, output_tensor, this->tilized, this->embeddings_type, this->pad_token); } tt::stl::reflection::Attributes Embeddings::attributes() const { return { {"output_mem_config", this->output_mem_config}, - {"split_weights", this->split_weights}, {"tilized", this->tilized}, + {"embeddings_type", this->embeddings_type}, + {"pad_token", this->pad_token}, {"output_dtype", this->output_dtype}}; } diff --git a/tt_eager/tt_dnn/op_library/embeddings/embeddings_op.hpp b/tt_eager/tt_dnn/op_library/embeddings/embeddings_op.hpp index 7a0b546e91d..c73cc80b38c 100644 --- a/tt_eager/tt_dnn/op_library/embeddings/embeddings_op.hpp +++ b/tt_eager/tt_dnn/op_library/embeddings/embeddings_op.hpp @@ -6,48 +6,50 @@ #include -#include "tt_eager/tensor/tensor.hpp" - #include "tt_dnn/op_library/run_operation.hpp" +#include "tt_eager/tensor/tensor.hpp" using namespace tt::constants; - namespace tt { namespace tt_metal { -struct Embeddings { +enum class EmbeddingsType { GENERIC = 0, PADDED = 1, BINARY = 2 }; +struct Embeddings { const MemoryConfig output_mem_config; - const bool split_weights; const bool tilized; + const EmbeddingsType embeddings_type; + const std::optional pad_token; const DataType output_dtype; + void validate(const std::vector &input_tensors) const; std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program( - const std::vector& input_tensors, - std::vector &output_tensors - ) const; + const std::vector &input_tensors, std::vector &output_tensors) const; tt::stl::reflection::Attributes attributes() const; - }; -inline Tensor embeddings(const Tensor &input_tensor, const Tensor &weights, - bool splitWeights = true, - bool tilized = true, - const MemoryConfig& mem_config= operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - std::optional output_dtype=std::nullopt){ - return operation::run_without_autoformat(Embeddings{ - .output_mem_config=mem_config, - .split_weights= splitWeights, - .tilized = tilized, - .output_dtype = output_dtype.value_or(weights.dtype())}, - {input_tensor, weights}).at(0); - +inline Tensor embeddings( + const Tensor &input_tensor, + const Tensor &weights, + bool tilized = true, + EmbeddingsType embeddings_type = EmbeddingsType::GENERIC, + std::optional pad_token = std::nullopt, + const MemoryConfig &mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_dtype = std::nullopt) { + return operation::run_without_autoformat( + Embeddings{ + .output_mem_config = mem_config, + .tilized = tilized, + .embeddings_type = embeddings_type, + .pad_token = pad_token, + .output_dtype = output_dtype.value_or(weights.dtype())}, + {input_tensor, weights}) + .at(0); } - -} -} // namespace tt::tt_metal +} // namespace tt_metal +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/embeddings/kernels/dataflow/embeddings.cpp b/tt_eager/tt_dnn/op_library/embeddings/kernels/dataflow/embeddings.cpp index be4e13593f8..3f849ea01e9 100644 --- a/tt_eager/tt_dnn/op_library/embeddings/kernels/dataflow/embeddings.cpp +++ b/tt_eager/tt_dnn/op_library/embeddings/kernels/dataflow/embeddings.cpp @@ -4,85 +4,136 @@ #include "dataflow_api.h" +void kernel_main() { + const std::uint32_t input_dram_buffer_src_addr = get_arg_val(0); + const std::uint32_t weights_dram_buffer_src_addr = get_arg_val(1); + const std::uint32_t batch_offset = get_arg_val(2); + const std::uint32_t weights_offset = get_arg_val(3); + const std::uint32_t num_blocks = get_arg_val(4); -template -inline __attribute__((always_inline)) -void embeddings_( - const uint32_t embeddings_per_core, - const uint32_t input_offset, - const uint32_t weight_offset, - const uint32_t num_output_rows, - const uint32_t page_size, - uint32_t input_l1_addr, - uint32_t weight_l1_addr, - const InterleavedAddrGen& input, - const InterleavedAddrGen& weights, - const InterleavedAddrGen& out - -) { + const std::uint32_t index_idx = get_arg_val(5); - for (uint32_t i = 0; i < num_output_rows; i++) { - auto noc_input_src_addr = get_noc_addr(i+input_offset, input); - noc_async_read(noc_input_src_addr, input_l1_addr, sizeof(uint32_t)); + #define in_is_dram get_compile_time_arg_val(0) == 1 + #define in_stick_size_is_power_of_two get_compile_time_arg_val(1) == 1 + constexpr uint32_t input_page_size = get_compile_time_arg_val(2); + #if (in_stick_size_is_power_of_two) + constexpr uint32_t log_base_2_of_input_page_size = get_compile_time_arg_val(3); + const InterleavedPow2AddrGen input = { + .bank_base_address = input_dram_buffer_src_addr, + .log_base_2_of_page_size = log_base_2_of_input_page_size // TODO(AP): refactor + }; + #else + const InterleavedAddrGen input = { + .bank_base_address = input_dram_buffer_src_addr, + .page_size = input_page_size + }; + #endif + + #define weights_is_dram get_compile_time_arg_val(4) == 1 + #define weight_stick_size_is_power_of_two get_compile_time_arg_val(5) == 1 + constexpr uint32_t weight_stick_size = get_compile_time_arg_val(6); + #if (weight_stick_size_is_power_of_two) + constexpr uint32_t log_base_2_of_weights_page_size = get_compile_time_arg_val(7); + const InterleavedPow2AddrGen weights = { + .bank_base_address = weights_dram_buffer_src_addr, + .log_base_2_of_page_size = log_base_2_of_weights_page_size // TODO(AP): refactor + }; + #else + const InterleavedAddrGen weights = { + .bank_base_address = weights_dram_buffer_src_addr, + .page_size = weight_stick_size + }; + #endif + + constexpr uint32_t rows_per_block = get_compile_time_arg_val(8); + constexpr uint32_t input_block_size_bytes = get_compile_time_arg_val(9); + + constexpr uint32_t cb_id_in0 = 0; + constexpr uint32_t cb_id_in1 = 1; + constexpr uint32_t cb_id_in2 = 2; + + constexpr uint32_t tile_height = 32; + + #if defined PADDED + const std::uint32_t pad_token = get_arg_val(6); + uint64_t pad_noc_addr; + { + cb_reserve_back(cb_id_in2, 1); + uint32_t local_pad_addr = get_write_ptr(cb_id_in2); + uint64_t src_noc_addr = get_noc_addr(pad_token, weights); + noc_async_read(src_noc_addr, local_pad_addr, weight_stick_size); noc_async_read_barrier(); - uint32_t row = ((uint32_t *)input_l1_addr)[0]; + pad_noc_addr = get_noc_addr(local_pad_addr); + } + #elif defined BINARY + uint64_t zero_noc_addr, one_noc_addr; + { + cb_reserve_back(cb_id_in2, 2); + uint32_t local_write_addr = get_write_ptr(cb_id_in2); + uint64_t src_noc_addr = get_noc_addr(0, weights); + noc_async_read(src_noc_addr, local_write_addr, weight_stick_size); + zero_noc_addr = get_noc_addr(local_write_addr); + + local_write_addr += weight_stick_size; + src_noc_addr = get_noc_addr(1, weights); + noc_async_read(src_noc_addr, local_write_addr, weight_stick_size); + one_noc_addr = get_noc_addr(local_write_addr); - if(row>=weight_offset && row < (weight_offset+embeddings_per_core)){ - auto noc_src_addr = get_noc_addr(row, weights); - auto noc_dst_addr = get_noc_addr(i+input_offset, out); - noc_async_read(noc_src_addr, weight_l1_addr, page_size); + noc_async_read_barrier(); + } + #endif + + cb_reserve_back(cb_id_in1, 1); + uint32_t input_l1_addr = get_write_ptr(cb_id_in1); + volatile tt_l1_ptr uint32_t* input_l1_ptr = reinterpret_cast(input_l1_addr); + + auto read_block = [&] (const uint32_t& token_idx, const uint32_t& width_size) { + cb_reserve_back(cb_id_in0, 1); + uint32_t l1_write_addr = get_write_ptr(cb_id_in0); + uint64_t src_noc_addr; + uint32_t token = input_l1_ptr[token_idx]; + #if defined PADDED + if (token == pad_token) { + src_noc_addr = pad_noc_addr; + } else { + src_noc_addr = get_noc_addr(token, weights); + } + #elif defined BINARY + if (token == 0) { + src_noc_addr = zero_noc_addr; + } else { + src_noc_addr = one_noc_addr; + } + #else + src_noc_addr = get_noc_addr(token, weights); + #endif + noc_async_read(src_noc_addr, l1_write_addr, width_size); + noc_async_read_barrier(); + cb_push_back(cb_id_in0, 1); + }; + + uint32_t curr_row = batch_offset; + uint32_t offset = weights_offset; + uint32_t index = index_idx; + bool read_indices = true; + for (uint32_t i = 0; i < num_blocks; ++i) { + if (read_indices) { + uint64_t noc_input_src_addr = get_noc_addr(curr_row, input) + offset; + noc_async_read(noc_input_src_addr, input_l1_addr, input_block_size_bytes); noc_async_read_barrier(); - noc_async_write(weight_l1_addr, noc_dst_addr, page_size); - noc_async_write_barrier(); + read_indices = false; + } + read_block(index, weight_stick_size); + index++; + if (index == rows_per_block) { + index = 0; + read_indices = true; + offset += input_block_size_bytes; + if (offset == input_page_size) { + offset = 0; + curr_row++; + } } } - - - -} - -void kernel_main() { - std::uint32_t input_offset = get_arg_val(0); - std::uint32_t weight_offset = get_arg_val(1); - std::uint32_t embeddings_per_core = get_arg_val(2); - std::uint32_t num_output_rows = get_arg_val(3); - std::uint32_t input_dram_buffer_src_addr = get_arg_val(4); - std::uint32_t weights_dram_buffer_src_addr = get_arg_val(5); - std::uint32_t output_dram_buffer_dst_addr = get_arg_val(6); - - - #define in_is_dram get_compile_time_arg_val(0) == 1 - #define weights_is_dram get_compile_time_arg_val(1) == 1 - #define out_is_dram get_compile_time_arg_val(2) == 1 - constexpr uint32_t page_size = get_compile_time_arg_val(3); - constexpr uint32_t cb_id_inter_weights = get_compile_time_arg_val(4); - constexpr uint32_t cb_id_inter_input = get_compile_time_arg_val(5); - - const InterleavedAddrGen input_rows_0 = { - .bank_base_address = input_dram_buffer_src_addr, .page_size = sizeof(uint32_t)}; - const InterleavedAddrGen weights_0 = { - .bank_base_address = weights_dram_buffer_src_addr , .page_size = page_size}; - const InterleavedAddrGen out_0 = { - .bank_base_address = output_dram_buffer_dst_addr , .page_size = page_size}; - cb_reserve_back(cb_id_inter_input, 1); - uint32_t inter_input_l1_addr = get_write_ptr(cb_id_inter_input); - cb_reserve_back(cb_id_inter_weights, 1); - uint32_t inter_weights_l1_addr = get_write_ptr(cb_id_inter_weights); - - embeddings_( - embeddings_per_core, - input_offset, - weight_offset, - num_output_rows, - page_size, - inter_input_l1_addr, - inter_weights_l1_addr, - input_rows_0, - weights_0, - out_0 - ); - - - } diff --git a/tt_eager/tt_dnn/op_library/embeddings/kernels/dataflow/embeddings_tilize.cpp b/tt_eager/tt_dnn/op_library/embeddings/kernels/dataflow/embeddings_tilize.cpp index 9e32f63508e..e17e3600816 100644 --- a/tt_eager/tt_dnn/op_library/embeddings/kernels/dataflow/embeddings_tilize.cpp +++ b/tt_eager/tt_dnn/op_library/embeddings/kernels/dataflow/embeddings_tilize.cpp @@ -5,12 +5,11 @@ #include "dataflow_api.h" void kernel_main() { - std::uint32_t batch_offset = get_arg_val(0); - std::uint32_t weights_offset = get_arg_val(1); - std::uint32_t num_blocks = get_arg_val(2); - std::uint32_t input_dram_buffer_src_addr = get_arg_val(3); - std::uint32_t weights_dram_buffer_src_addr = get_arg_val(4); - + const std::uint32_t input_dram_buffer_src_addr = get_arg_val(0); + const std::uint32_t weights_dram_buffer_src_addr = get_arg_val(1); + const std::uint32_t batch_offset = get_arg_val(2); + const std::uint32_t weights_offset = get_arg_val(3); + const std::uint32_t num_blocks = get_arg_val(4); #define in_is_dram get_compile_time_arg_val(0) == 1 #define in_stick_size_is_power_of_two get_compile_time_arg_val(1) == 1 @@ -48,19 +47,64 @@ void kernel_main() { constexpr uint32_t cb_id_in0 = 0; constexpr uint32_t cb_id_in1 = 1; + constexpr uint32_t cb_id_in2 = 2; constexpr uint32_t tile_height = 32; - uint64_t base_src_noc_addr[tile_height]; + #if defined PADDED + const std::uint32_t pad_token = get_arg_val(5); + uint64_t pad_noc_addr; + { + cb_reserve_back(cb_id_in2, 1); + uint32_t local_pad_addr = get_write_ptr(cb_id_in2); + uint64_t src_noc_addr = get_noc_addr(pad_token, weights); + noc_async_read(src_noc_addr, local_pad_addr, weight_stick_size); + noc_async_read_barrier(); + pad_noc_addr = get_noc_addr(local_pad_addr); + } + #elif defined BINARY + uint64_t zero_noc_addr, one_noc_addr; + { + cb_reserve_back(cb_id_in2, 2); + uint32_t local_write_addr = get_write_ptr(cb_id_in2); + uint64_t src_noc_addr = get_noc_addr(0, weights); + noc_async_read(src_noc_addr, local_write_addr, weight_stick_size); + zero_noc_addr = get_noc_addr(local_write_addr); + + local_write_addr += weight_stick_size; + src_noc_addr = get_noc_addr(1, weights); + noc_async_read(src_noc_addr, local_write_addr, weight_stick_size); + one_noc_addr = get_noc_addr(local_write_addr); + + noc_async_read_barrier(); + } + #endif cb_reserve_back(cb_id_in1, 1); uint32_t input_l1_addr = get_write_ptr(cb_id_in1); volatile tt_l1_ptr uint32_t* input_l1_ptr = reinterpret_cast(input_l1_addr); - auto read_tiles = [&input_l1_ptr, &weights] (const uint32_t& num_tiles, const uint32_t& width_size) { + + auto read_tiles = [&] (const uint32_t& num_tiles, const uint32_t& width_size) { cb_reserve_back(cb_id_in0, num_tiles); uint32_t l1_write_addr = get_write_ptr(cb_id_in0); - for (uint32_t k = 0; k < tile_height; k++) { - uint64_t src_noc_addr = get_noc_addr(input_l1_ptr[k], weights); + for (uint32_t k = 0; k < tile_height; ++k) { + uint64_t src_noc_addr; + uint32_t token = input_l1_ptr[k]; + #if defined PADDED + if (token == pad_token) { + src_noc_addr = pad_noc_addr; + } else { + src_noc_addr = get_noc_addr(token, weights); + } + #elif defined BINARY + if (token == 0) { + src_noc_addr = zero_noc_addr; + } else { + src_noc_addr = one_noc_addr; + } + #else + src_noc_addr = get_noc_addr(token, weights); + #endif noc_async_read(src_noc_addr, l1_write_addr, width_size); l1_write_addr += width_size; } diff --git a/tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/writer_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp b/tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/writer_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp index 3c19bc7175b..d8d742d9cd1 100644 --- a/tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/writer_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp +++ b/tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/writer_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp @@ -14,7 +14,7 @@ void kernel_main() { const uint32_t input_width_offset_bytes = get_arg_val(4); const uint32_t start_id = get_arg_val(5); - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0) == 1; + constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0); constexpr bool dst0_is_dram = get_compile_time_arg_val(1) == 1; #define dst_stick_size_is_pow2 get_compile_time_arg_val(2) == 1 diff --git a/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp index d516de32d13..9bc0933c2b8 100644 --- a/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp @@ -339,8 +339,8 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core( tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, all_cores, {num_units_per_shard}); uint32_t curr_idx_h = 0, curr_idx_w = 0; - const auto cores = - grid_to_cores(num_cores, num_cores_x, num_cores_y, rm_orientation) for (const auto& core : cores) { + const auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, rm_orientation); + for (const auto& core : cores) { if (input.layout() == Layout::TILE) { uint32_t shard_height = num_units_per_shard_height; uint32_t shard_width = num_units_per_shard_width; diff --git a/tt_eager/tt_dnn/op_library/unpad/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp b/tt_eager/tt_dnn/op_library/unpad/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp index 2825973e2e9..13c9f1c2855 100644 --- a/tt_eager/tt_dnn/op_library/unpad/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp +++ b/tt_eager/tt_dnn/op_library/unpad/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp @@ -13,7 +13,7 @@ void kernel_main() { uint32_t num_sticks = get_arg_val(2); uint32_t start_id = get_arg_val(3); - constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0) == 1; + constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0); constexpr bool dst0_is_dram = get_compile_time_arg_val(1) == 1; diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp index 4348c545c8a..8a3677f04c0 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp @@ -121,6 +121,8 @@ void TensorModule(py::module &m_tensor) { py::implicitly_convertible, UnaryWithParam>(); py::implicitly_convertible, UnaryWithParam>(); + detail::export_enum(m_tensor); + auto py_core_coord = py::class_(m_tensor, "CoreCoord", R"doc( Class defining core coordinate )doc"); @@ -385,8 +387,9 @@ void TensorModule(py::module &m_tensor) { // input embeddings m_tensor.def("embeddings", &embeddings, py::arg("input").noconvert(), py::arg("weights").noconvert(), - py::arg("split_weights").noconvert() = false, py::arg("tilized").noconvert() = false, + py::arg("embeddings_type").noconvert() = EmbeddingsType::GENERIC, + py::arg("pad_token").noconvert() = std::nullopt, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_dtype").noconvert() = std::nullopt, R"doc( Returns specific indices of the embedding table specified by the input tensor @@ -396,8 +399,9 @@ void TensorModule(py::module &m_tensor) { "input", "Tensor containing rows we want", "UInt32 Tensor", "Each element greater than 0 and less than number of embeddings in table. Shape [batch_size, 1, num_rows, 1]", "Yes" "weights", "Entire embedding table", "Tensor", "Tensor shape is [1,1, num_embeddings, num_columns]. Num_columns must be divisible by 32.", "Yes" - "split_weights", "Parallelizing over weights (instead of input). Default is false", "Bool", "", "No" "tilized", "Enable fused tilize on output. Default is true.", "Bool", "", "No", + "embeddings_type", "Version of optimized embeddings to run. PADDED requires passing pad_token. BINARY expects the indices to only be 0, 1 and weights to have 2 rows", "EmbeddingsType", "GENERIC, PADDED, BINARY", "No" + "pad_token", "pad_token used in token ids", "uint32_t", "Default is None", "No" "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" "output_dtype", "DataType of output tensor", "DataType", "Default is weights dtype", "No" )doc"); diff --git a/ttnn/core.py b/ttnn/core.py index 63b3d55206a..37042098e38 100644 --- a/ttnn/core.py +++ b/ttnn/core.py @@ -1042,12 +1042,11 @@ def embedding( weights = _reshape_to_4D(weights) *_, batch_size, sentence_size = input_tensor.shape - input_tensor = reshape(input_tensor, shape=(batch_size, 1, sentence_size, 1)) + input_tensor = reshape(input_tensor, shape=(batch_size, 1, 1, sentence_size)) - split_weights = False tilized = layout == TILE_LAYOUT embeddings = Tensor( - ttl.tensor.embeddings(input_tensor._tensor, weights._tensor, split_weights, tilized, memory_config) + ttl.tensor.embeddings(input_tensor._tensor, weights._tensor, tilized, output_mem_config=memory_config) ) embeddings = reshape(embeddings, shape=(batch_size, sentence_size, hidden_embedding_dim))