diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index f760b407b85..10e1ab31f5a 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -73,19 +73,28 @@ def run_conv( shard_layout=None, auto_shard=False, memory_config=None, + input_mesh_mapper=None, + weight_mesh_mapper=None, + output_mesh_composer=None, ): + if isinstance(device, ttnn.MeshDevice): + assert input_mesh_mapper is not None, "Expected mesh mapper for input tensor when using device mesh" + assert weight_mesh_mapper is not None, "Expected mesh mapper for weight tensors when using device mesh" + assert output_mesh_composer is not None, "Expected mesh composer for output tensor when using device mesh" + num_devices = len(device.get_device_ids()) + total_batch_size = num_devices * batch_size # Batch size across all devices + logger.info(f"Using {num_devices} devices for this test") + else: + total_batch_size = batch_size + torch.manual_seed(0) - conv_input_shape = [batch_size, input_channels, input_height, input_width] + conv_input_shape = [total_batch_size, input_channels, input_height, input_width] conv_weight_shape = [output_channels, input_channels // groups, filter_height, filter_width] conv_bias_shape = [1, 1, 1, output_channels] torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float() - # torch_input_tensor_nchw = torch.ones(conv_input_shape, dtype=torch.bfloat16).float() - # torch_input_tensor_nchw = torch.tensor(range(input_height * input_width)).reshape([1,1,input_height,input_width]).float() - # torch_input_tensor_nchw = torch_input_tensor_nchw.broadcast_to(conv_input_shape).float() torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1)) torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float() - # torch_weight_tensor = torch.ones(conv_weight_shape, dtype=torch.bfloat16).float() torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() if has_bias else None torch_out_golden_tensor = torch.nn.functional.conv2d( @@ -107,15 +116,19 @@ def run_conv( reader_patterns_cache = {} tt_weight_tensor = ttnn.from_torch( - torch_weight_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32 + torch_weight_tensor, + weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32, + mesh_mapper=weight_mesh_mapper, ) tt_bias_tensor = None if has_bias: tt_bias_tensor = ttnn.from_torch( - torch_bias_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32 + torch_bias_tensor, + weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32, + mesh_mapper=weight_mesh_mapper, ) - tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) + tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16, mesh_mapper=input_mesh_mapper) if shard_layout is None and not auto_shard: shard_layout = ( @@ -171,11 +184,13 @@ def run_conv( ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) - torch_output_tensor = ttnn.to_torch(tt_output_tensor) + torch_output_tensor = ttnn.to_torch(tt_output_tensor, mesh_composer=output_mesh_composer) # torch_output_tensor is in row major layout and NHWC shape # NHWC to NCHW - torch_output_tensor = torch_output_tensor.reshape(batch_size, out_height, out_width, torch_output_tensor.shape[-1]) + torch_output_tensor = torch_output_tensor.reshape( + total_batch_size, out_height, out_width, torch_output_tensor.shape[-1] + ) torch_output_tensor = torch_output_tensor[:, :, :, :output_channels] torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2)) @@ -288,7 +303,7 @@ def run_conv_with_split( torch_bias_zeroes_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32 ) torch_input_tensor = torch.permute(split_input_tensors[i], (0, 2, 3, 1)) - tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) + tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16, device=device) # tt_input_tensor_on_device = convs[i].copy_input_to_device(tt_input_tensor) # tt_output_tensor_on_device = convs[i](tt_input_tensor_on_device) [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( @@ -417,6 +432,86 @@ def test_conv_features( ) +@skip_for_grayskull() +@skip_for_blackhole() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 2 * 16384}], indirect=True) +@pytest.mark.parametrize("groups", [1, 2]) +@pytest.mark.parametrize("stride", [2]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_channels, input_channels, input_height, input_width, shard_layout, config", + ( + (256, 256, 8, 8, ttnn.TensorMemoryLayout.WIDTH_SHARDED, None), + (128, 128, 32, 32, ttnn.TensorMemoryLayout.BLOCK_SHARDED, None), + (16, 16, 256, 256, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, {"act_block_h": 32}), + ), +) +@pytest.mark.parametrize( + "weights_dtype", + [ttnn.bfloat8_b, ttnn.bfloat16], +) +@pytest.mark.parametrize( + "activations_dtype", + [ttnn.bfloat8_b, ttnn.bfloat16], +) +@pytest.mark.parametrize( + "filter, pad", + [ + [3, 1], + ], +) +@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) +@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) +def test_conv_features_multi_device( + mesh_device, + use_program_cache, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + shard_layout, + config, + filter, + stride, + pad, + output_layout, + groups, +): + if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b: + pytest.skip("Row major layout not compatible with bfloat8_b") + + run_conv( + mesh_device, + math_fidelity, + activations_dtype, + weights_dtype, + batch_size, + output_channels, + input_channels, + input_height, + input_width, + filter, + filter, + stride, + stride, + pad, + pad, + True, + config, + shard_layout=shard_layout, + output_layout=output_layout, + has_bias=True, + input_mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0), + weight_mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + output_mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0), + groups=groups, + ) + + @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize("stride", [1, 2]) diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp index 0d585ba7b93..b11be6e09ae 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp @@ -4,9 +4,9 @@ #include "ttnn/tensor/tensor_utils.hpp" +#include "ttnn/distributed/api.hpp" #include "ttnn/tensor/host_buffer/functions.hpp" #include "ttnn/tensor/host_buffer/types.hpp" -#include "ttnn/distributed/api.hpp" namespace tt { @@ -91,7 +91,8 @@ Tensor to_weight_special_padding_tile_layout( conv_weight_tensor.get_storage()); }; - return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) : convert_tensor(conv_weight_tensor); + return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) + : convert_tensor(conv_weight_tensor); } template @@ -175,7 +176,8 @@ Tensor to_weight_tile_layout( }, conv_weight_tensor.get_storage()); }; - return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) : convert_tensor(conv_weight_tensor); + return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) + : convert_tensor(conv_weight_tensor); } // Converts convolution weights to tilized 2d matrix layout. @@ -243,45 +245,61 @@ Helper function to aid in converting grouped weight tensor to ungrouped weight t */ template static Tensor conv_group_weight_zero_pad_helper( - Tensor& conv_weight_tensor, + const Tensor& weight, const ttnn::SimpleShape& original_weight_shape, const ttnn::SimpleShape& output_weight_shape, uint32_t num_groups, DataType output_dtype) { - owned_buffer::Buffer output_buffer = owned_buffer::create(output_weight_shape.volume()); - auto conv_weight_tensor_buffer = borrowed_buffer::get_as(conv_weight_tensor); - - for (int curr_batch_idx = 0; curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) { - int new_batch_idx = curr_batch_idx; - - // Find which group_id the filter belongs to - through this, we can compute the offset where the padding should - // be applied - auto group_size = original_weight_shape[0] / num_groups; - auto group_index = curr_batch_idx / group_size; - auto group_id = std::min(group_index, num_groups - 1); - int new_channel_start_idx = group_id * original_weight_shape[1]; - - for (int j = 0; j < original_weight_shape[1]; j++) { - for (int k = 0; k < original_weight_shape[2]; k++) { - for (int m = 0; m < original_weight_shape[3]; m++) { - // Get value from original weight tensor - auto value_flat_input_index = - compute_flat_indices(ttnn::SmallVector{curr_batch_idx, j, k, m}, compute_strides(original_weight_shape)); - auto value = conv_weight_tensor_buffer[value_flat_input_index]; - - // Copy value to output tensor at the adjusted position - auto new_channel_idx = new_channel_start_idx + j; - auto output_flat_input_index = compute_flat_indices( - ttnn::SmallVector{new_batch_idx, new_channel_idx, k, m}, compute_strides(output_weight_shape)); - output_buffer[output_flat_input_index] = value; + auto pad_weight = [&original_weight_shape, &output_weight_shape, &num_groups, &output_dtype]( + const auto& conv_weight_tensor_buffer) { + owned_buffer::Buffer output_buffer = owned_buffer::create(output_weight_shape.volume()); + for (int curr_batch_idx = 0; curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) { + int new_batch_idx = curr_batch_idx; + + // Find which group_id the filter belongs to - through this, we can compute the offset where the padding + // should be applied + auto group_size = original_weight_shape[0] / num_groups; + auto group_index = curr_batch_idx / group_size; + auto group_id = std::min(group_index, num_groups - 1); + int new_channel_start_idx = group_id * original_weight_shape[1]; + + for (int j = 0; j < original_weight_shape[1]; j++) { + for (int k = 0; k < original_weight_shape[2]; k++) { + for (int m = 0; m < original_weight_shape[3]; m++) { + // Get value from original weight tensor + auto value_flat_input_index = compute_flat_indices( + ttnn::SmallVector{curr_batch_idx, j, k, m}, compute_strides(original_weight_shape)); + auto value = conv_weight_tensor_buffer[value_flat_input_index]; + + // Copy value to output tensor at the adjusted position + auto new_channel_idx = new_channel_start_idx + j; + auto output_flat_input_index = compute_flat_indices( + ttnn::SmallVector{new_batch_idx, new_channel_idx, k, m}, + compute_strides(output_weight_shape)); + output_buffer[output_flat_input_index] = value; + } } } } - } + return Tensor( + std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR); + }; - auto output_tensor = - Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR); - return output_tensor; + auto f = [&](const auto& tensor) { + return std::visit( + [&](auto&& storage) -> Tensor { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + return pad_weight(owned_buffer::get_as(storage.buffer)); + } else if constexpr (std::is_same_v) { + return pad_weight(borrowed_buffer::get_as(storage.buffer)); + } else { + TT_THROW("Unsupported storage type"); + } + }, + tensor.get_storage()); + }; + return ttnn::distributed::is_multi_device_tensor(weight) ? transform(weight, f) : f(weight); } /* @@ -300,10 +318,11 @@ static Tensor conv_depthwise_weight_bcast_helper( for (int j = 0; j < output_weight_shape[1]; j++) { for (int k = 0; k < output_weight_shape[2]; k++) { for (int l = 0; l < output_weight_shape[3]; l++) { - auto value_flat_input_index = - compute_flat_indices(ttnn::SmallVector{i, 0, k, l}, compute_strides(original_weight_shape)); + auto value_flat_input_index = compute_flat_indices( + ttnn::SmallVector{i, 0, k, l}, compute_strides(original_weight_shape)); auto value = conv_weight_tensor_buffer[value_flat_input_index]; - auto output_flat_input_index = compute_flat_indices(ttnn::SmallVector{i, j, k, l}, compute_strides(output_weight_shape)); + auto output_flat_input_index = + compute_flat_indices(ttnn::SmallVector{i, j, k, l}, compute_strides(output_weight_shape)); output_buffer[output_flat_input_index] = value; } } @@ -417,40 +436,22 @@ Tensor convert_conv_weight_tensor_to_depthwise_layout( // Create newly allocated buffer all initialized to 0 depending on the datatype of the weight tensor if (output_dtype == DataType::INT32) { return conv_depthwise_weight_bcast_helper( - conv_weight_tensor, - original_conv_weight_tensor_shape, - output_conv_weight_tensor_shape, - output_dtype); + conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype); } else if (output_dtype == DataType::FLOAT32) { return conv_depthwise_weight_bcast_helper( - conv_weight_tensor, - original_conv_weight_tensor_shape, - output_conv_weight_tensor_shape, - output_dtype); + conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype); } else if (output_dtype == DataType::BFLOAT16) { return conv_depthwise_weight_bcast_helper( - conv_weight_tensor, - original_conv_weight_tensor_shape, - output_conv_weight_tensor_shape, - output_dtype); + conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype); } else if (output_dtype == DataType::UINT16) { return conv_depthwise_weight_bcast_helper( - conv_weight_tensor, - original_conv_weight_tensor_shape, - output_conv_weight_tensor_shape, - output_dtype); + conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype); } else if (output_dtype == DataType::BFLOAT8_B) { return conv_depthwise_weight_bcast_helper( - conv_weight_tensor, - original_conv_weight_tensor_shape, - output_conv_weight_tensor_shape, - DataType::FLOAT32); + conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, DataType::FLOAT32); } else { return conv_depthwise_weight_bcast_helper( - conv_weight_tensor, - original_conv_weight_tensor_shape, - output_conv_weight_tensor_shape, - DataType::FLOAT32); + conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, DataType::FLOAT32); } TT_THROW("Unsupported weight data type given when trying to add zero padding to weight tensor"); @@ -464,7 +465,7 @@ const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, tt::stl::Sp if (shape[index] == -1) { if (index_of_negative_1 != -1) { std::string error_msg = "Shape cannot have more than 1 elements that is set to -1! Shape used: ("; - for(auto & s: shape) { + for (auto& s : shape) { error_msg += std::to_string(s) + ","; } error_msg += ")"; @@ -519,7 +520,10 @@ void apply(const Tensor& tensor, std::function callable) { std::vector get_devices(const Tensor& tensor) { std::vector devices; if (tensor.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE) { - TT_ASSERT(std::holds_alternative(tensor.get_storage()), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(tensor.get_storage())); + TT_ASSERT( + std::holds_alternative(tensor.get_storage()), + "Unexpected type {}", + tt::stl::get_active_type_name_in_variant(tensor.get_storage())); const auto& tensor_storage = std::get(tensor.get_storage()); for (int i = 0; i < tensor_storage.ordered_device_ids.size(); ++i) { auto device_id = tensor_storage.ordered_device_ids[i]; @@ -630,8 +634,12 @@ Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor) [&owned_tensor, &tensor](auto&& buffer) { using BorrowedStorageType = std::vector>; auto owned_buf = owned_buffer::create(BorrowedStorageType(buffer.begin(), buffer.end())); - owned_tensor = - Tensor(OwnedStorage{owned_buf}, tensor.get_shape(), tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); + owned_tensor = Tensor( + OwnedStorage{owned_buf}, + tensor.get_shape(), + tensor.get_dtype(), + tensor.get_layout(), + tensor.get_tile()); }, borrowed_buffer); return owned_tensor;