diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index 0a3b75d4b51..7f4247fb5d1 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -6,12 +6,19 @@ import torch import pytest -from models.utility_functions import skip_for_wormhole_b0, skip_for_grayskull, is_grayskull, is_wormhole_b0 +from models.utility_functions import ( + skip_for_wormhole_b0, + skip_for_grayskull, + is_grayskull, + is_wormhole_b0, + is_x2_harvested, +) from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc, check_with_pcc_without_tensor_printout import ttnn import tt_lib import math import os +import torch.nn as nn # def plot_diff(vals, fid, nsticks, stick_len): @@ -55,12 +62,13 @@ def run_conv( output_layout=ttnn.TILE_LAYOUT, deallocate_activation=False, debug=False, + groups=1, ): # has_bias = False has_bias = True torch.manual_seed(0) conv_input_shape = [batch_size, input_channels, input_height, input_width] - conv_weight_shape = [output_channels, input_channels, filter_height, filter_width] + conv_weight_shape = [output_channels, input_channels // groups, filter_height, filter_width] conv_bias_shape = [1, 1, 1, output_channels] torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float() torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1)) @@ -72,6 +80,7 @@ def run_conv( bias=torch_bias_tensor.reshape(-1) if has_bias else None, stride=(stride_h, stride_w), padding=(pad_h, pad_w), + groups=groups, ) output_shape_nhwc = [ torch_out_golden_tensor.shape[0], @@ -123,6 +132,7 @@ def run_conv( conv_op_cache=reader_patterns_cache, reshard_if_not_optimal=False, debug=debug, + groups=groups, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) @@ -1239,3 +1249,210 @@ def test_conv_core_nondivis( use_1d_systolic_array, config_override, ) + + +# The following test takes various shape 
sizes from resnet50, unet and stable diffusion and tests for different number of groups - all the way to num_groups = num_in_channels (depthwise conv) +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "batch_size, input_channels, output_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, groups, use_1d_systolic_array, config_override, use_shallow_conv_variant", + ( + (1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, 2, True, None, False), + (1, 64, 64, 32, 32, 3, 3, 1, 1, 1, 1, 64, True, None, False), + (2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 1, True, None, False), + (2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 2, True, None, False), + (2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 8, True, None, False), + (1, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 1, True, None, False), + (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 64, True, None, False), + (4, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 128, True, None, False), + (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 128, True, None, False), + # (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 256, False, None, False), circular buffer error + # (16, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 256, False, None, False), # doesn't fit with bfloat16 weights + # (32, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 512, False, None, False), # doesn't fit with bfloat16 weights + (32, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, 40, False, None, False), + (32, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, 10, False, None, False), + (1, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 8, True, None, False), + (1, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 16, True, None, False), + (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, 32, True, None, False), + (8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 2, False, None, False), + (8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 4, False, None, False), + (1, 320, 320, 32, 32, 3, 3, 1, 1, 1, 1, 2, False, None, False), + (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, 320, False, None, False), + # (1, 
1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, 1, False, None, False), # doesn't fit with bfloat16 weights + (2, 64, 32, 66, 10, 3, 3, 1, 1, 1, 1, 32, True, None, False), + (2, 32, 96, 132, 20, 3, 3, 1, 1, 1, 1, 2, True, None, False), + ), +) +@pytest.mark.parametrize( + "weights_dtype", + [ttnn.bfloat16], +) +@pytest.mark.parametrize( + "activations_dtype", + [ttnn.bfloat8_b, ttnn.bfloat16], +) +@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) +@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT]) +def test_conv_groups( + device, + use_program_cache, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + config_override, + use_shallow_conv_variant, + groups, + output_layout, +): + run_conv( + device, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + config_override, + use_shallow_conv_variant=use_shallow_conv_variant, + groups=groups, + output_layout=output_layout, + ) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +@pytest.mark.parametrize( + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override, use_shallow_conv_variant, groups", + ( + # yolov4 convs with batch size 1 + # unique convs in yolov4 (complete list) # groups: number + # (1, 32, 32, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 32), # groups: 32 + # (1, 32, 32, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 32), # groups: 32 + # (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64 + # (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, 
None, False, 64), # groups: 64 + # (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64 + # (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 
120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 + (1, 128, 128, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 2), # groups: 512 + ), +) +@pytest.mark.parametrize( + "weights_dtype", + [ttnn.bfloat16], +) +@pytest.mark.parametrize( + "activations_dtype", + # [ttnn.bfloat8_b, ttnn.bfloat16], + [ttnn.bfloat8_b], +) +@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) +# @pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT]) +def 
test_yolov4_conv_groups_larger_than_one( + device, + use_program_cache, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + config_override, + use_shallow_conv_variant, + groups, + output_layout, +): + if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b: + pytest.skip("Row major layout not compatible with bfloat8_b") + if output_layout == ttnn.ROW_MAJOR_LAYOUT and input_height >= 1056: + pytest.skip("OOM") + run_conv( + device, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter_height, + filter_width, + stride_h, + stride_w, + pad_h, + pad_w, + use_1d_systolic_array, + config_override, + use_shallow_conv_variant=use_shallow_conv_variant, + groups=groups, + padded_input_channels=16 if input_channels == 3 else None, + output_layout=output_layout, + ) diff --git a/tt_eager/tensor/tensor_utils.cpp b/tt_eager/tensor/tensor_utils.cpp index c54861b4ad2..d85efa6c9f8 100644 --- a/tt_eager/tensor/tensor_utils.cpp +++ b/tt_eager/tensor/tensor_utils.cpp @@ -220,6 +220,120 @@ Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype())); } +/* +Helper function to aid in converting grouped weight tensor to ungrouped weight tensor with padded zero channels +*/ +template +static Tensor conv_group_weight_zero_pad_helper( + Tensor& conv_weight_tensor, + Shape& original_weight_shape, + Shape& output_weight_shape, + uint32_t num_groups, + DataType output_dtype) { + owned_buffer::Buffer output_buffer = owned_buffer::create(compute_volume(output_weight_shape)); + auto conv_weight_tensor_buffer = borrowed_buffer::get_as(conv_weight_tensor); + + for (int curr_batch_idx = 0; 
curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) { + int new_batch_idx = curr_batch_idx; + + // Find which group_id the filter belongs to - through this, we can compute the offset where the padding should + // be applied + auto group_size = original_weight_shape[0] / num_groups; + auto group_index = curr_batch_idx / group_size; + auto group_id = std::min(group_index, num_groups - 1); + int new_channel_start_idx = group_id * original_weight_shape[1]; + + for (int j = 0; j < original_weight_shape[1]; j++) { + for (int k = 0; k < original_weight_shape[2]; k++) { + for (int m = 0; m < original_weight_shape[3]; m++) { + // Get value from original weight tensor + auto value_flat_input_index = + compute_flat_indices({curr_batch_idx, j, k, m}, compute_strides(original_weight_shape)); + auto value = conv_weight_tensor_buffer[value_flat_input_index]; + + // Copy value to output tensor at the adjusted position + auto new_channel_idx = new_channel_start_idx + j; + auto output_flat_input_index = compute_flat_indices( + {new_batch_idx, new_channel_idx, k, m}, compute_strides(output_weight_shape)); + output_buffer[output_flat_input_index] = value; + } + } + } + } + + auto output_tensor = + Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR); + return output_tensor; +} + +/* +Converts convolution weights to grouped layout with padded zeros +This function will take in a weight tensor with shape [out_channels, in_channels // groups, H, W] and return a newly +allocated output tensor with shape [out_channels, in_channels, H, W]. The extra channels in shape[1] will be padded with +0 - then the entire weight tensor is convolved with the input tensor - equivalent to convolution if the input tensor was +divided into num_groups for each grouped filter +*/ +Tensor convert_conv_weight_tensor_to_grouped_layout( + Tensor conv_weight_tensor, uint32_t num_groups, DataType output_dtype) { + TT_ASSERT( +
conv_weight_tensor.get_layout() == Layout::ROW_MAJOR && + "Convolution weights should be in row major layout for adding the required padding"); + + // Define output tensor shape. This is going to be channel dimension of weight tensor * num_groups - this value + // should match number of input channels being convolved with the weight tensor + auto original_conv_weight_tensor_shape_test = conv_weight_tensor.get_shape(); + Shape original_conv_weight_tensor_shape = { + original_conv_weight_tensor_shape_test[0], + original_conv_weight_tensor_shape_test[1], + original_conv_weight_tensor_shape_test[2], + original_conv_weight_tensor_shape_test[3]}; + Shape output_conv_weight_tensor_shape = { + original_conv_weight_tensor_shape[0], + original_conv_weight_tensor_shape[1] * num_groups, + original_conv_weight_tensor_shape[2], + original_conv_weight_tensor_shape[3]}; + + // Create newly allocated buffer all initialized to 0 depending on the datatype of the weight tensor + if (output_dtype == DataType::INT32) { + return conv_group_weight_zero_pad_helper( + conv_weight_tensor, + original_conv_weight_tensor_shape, + output_conv_weight_tensor_shape, + num_groups, + output_dtype); + } else if (output_dtype == DataType::FLOAT32) { + return conv_group_weight_zero_pad_helper( + conv_weight_tensor, + original_conv_weight_tensor_shape, + output_conv_weight_tensor_shape, + num_groups, + output_dtype); + } else if (output_dtype == DataType::BFLOAT16) { + return conv_group_weight_zero_pad_helper( + conv_weight_tensor, + original_conv_weight_tensor_shape, + output_conv_weight_tensor_shape, + num_groups, + output_dtype); + } else if (output_dtype == DataType::UINT16) { + return conv_group_weight_zero_pad_helper( + conv_weight_tensor, + original_conv_weight_tensor_shape, + output_conv_weight_tensor_shape, + num_groups, + output_dtype); + } else { + return conv_group_weight_zero_pad_helper( + conv_weight_tensor, + original_conv_weight_tensor_shape, + output_conv_weight_tensor_shape, + 
num_groups, + output_dtype); + } + + TT_THROW("Unsupported weight data type given when trying to add zero padding to weight tensor"); +} + const Shape infer_dims_for_reshape(int N, int C, int H, int W, uint32_t old_volume) { vector ns{N, C, H, W}; int neg_idx = -1; diff --git a/tt_eager/tensor/tensor_utils.hpp b/tt_eager/tensor/tensor_utils.hpp index f6b9b740060..406b52e4139 100644 --- a/tt_eager/tensor/tensor_utils.hpp +++ b/tt_eager/tensor/tensor_utils.hpp @@ -27,6 +27,9 @@ Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( uint32_t in1_block_w, std::optional output_dtype = std::nullopt); +// Converts convolution weights to grouped layout with padded zeros +Tensor convert_conv_weight_tensor_to_grouped_layout(Tensor conv_weight_tensor, uint32_t num_groups, DataType output_dtype); + const Shape infer_dims_for_reshape(int N, int C, int H, int W, uint32_t old_volume); const Shape infer_dims_for_reshape_RM(int N, int C, int H, int W, uint32_t old_volume); @@ -40,6 +43,24 @@ static std::size_t compute_volume(const T& shape) { return volume; } +static std::vector compute_strides(Shape shape) { + auto num_elements = compute_volume(shape); + std::vector strides; + for (std::int32_t index = 0; index < shape.rank(); index++) { + num_elements /= shape[index]; + strides.push_back(num_elements); + } + return strides; +} + +static int compute_flat_indices(vector indices, vector strides) { + int flat_index = 0; + for (auto i = 0; i < indices.size(); i++) { + flat_index += indices[i] * strides[i]; + } + return flat_index; +}; + template static std::size_t compute_buffer_size(const T& shape, DataType data_type) { const auto volume = compute_volume(shape); diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp index 32f3837667a..01f17290f3b 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor.cpp @@ -883,6 +883,23 @@ void TensorModule(py::module& 
m_tensor) { +----------+----------------------+-----------+-------------+----------+ )doc"); + m_tensor.def( + "convert_conv_weight_tensor_to_grouped_layout", + &convert_conv_weight_tensor_to_grouped_layout, + py::arg("conv_weight_tensor").noconvert(), + py::arg("num_groups"), + py::arg("output_dtype").noconvert() = std::nullopt, + R"doc( + Converts convolution weights to grouped layout with padded zeros + Returns a new tensor with the converted layout. + + +----------+----------------------+-----------+-------------+----------+ + | Argument | Description | Data type | Valid range | Required | + +==========+======================+===========+=============+==========+ + | a | Input tensor | Tensor | | Yes | + +----------+----------------------+-----------+-------------+----------+ + )doc"); + m_tensor.def( "format_input_tensor", &AutoFormat::format_input_tensor, diff --git a/ttnn/cpp/ttnn/operations/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv2d.cpp index 8fac471ddde..b5a5c992817 100644 --- a/ttnn/cpp/ttnn/operations/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv2d.cpp @@ -446,11 +446,19 @@ std::pair> prepare_conv_weights_biases uint32_t weight_block_h_ntiles, uint32_t weight_block_w_ntiles, const ParallelConfig& parallel_config, - Device& device) { + Device& device, + uint32_t groups) { validate_weight_and_bias_tensors(weight_tensor, bias_tensor); ttnn::Tensor weight_tensor_; // tensor to return ttnn::Tensor bias_tensor_; - auto weights_shape = weight_tensor.get_shape(); + + // Convert weight tensor to 0 padded shape if groups > 1 + weight_tensor_ = weight_tensor; + if (groups > 1) { + weight_tensor_ = convert_conv_weight_tensor_to_grouped_layout(weight_tensor_, groups, weights_bias_dtype); + } + + auto weights_shape = weight_tensor_.get_shape(); uint32_t out_channels = weights_shape[0]; uint32_t in_channels = weights_shape[1]; uint32_t window_h = weights_shape[2]; @@ -459,19 +467,19 @@ std::pair> prepare_conv_weights_biases {round_up(out_channels, 32), 
round_up(in_channels, input_channels_alignment), window_h, window_w})); if (weights_bias_dtype == DataType::BFLOAT8_B) { - TT_ASSERT(weight_tensor.get_dtype() == DataType::FLOAT32); + TT_ASSERT(weight_tensor_.get_dtype() == DataType::FLOAT32); if (bias_tensor.has_value()) { TT_ASSERT(bias_tensor.value().get_dtype() == DataType::FLOAT32); } } else { // TODO: fix the need to check this. We should be able to accept any datatype and convert - TT_ASSERT(weight_tensor.get_dtype() == weights_bias_dtype); + TT_ASSERT(weight_tensor_.get_dtype() == weights_bias_dtype); if (bias_tensor.has_value()) { TT_ASSERT(bias_tensor.value().get_dtype() == weights_bias_dtype); } } - weight_tensor_ = tt::tt_metal::pad_on_host(weight_tensor, weights_channels_padded_shape, {0, 0, 0, 0}, 0); + weight_tensor_ = tt::tt_metal::pad_on_host(weight_tensor_, weights_channels_padded_shape, {0, 0, 0, 0}, 0); // for conv op, pad the weights to block shape if (parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) { @@ -596,7 +604,8 @@ std::tuple