diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
index 0d585ba7b93..b11be6e09ae 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
@@ -4,9 +4,9 @@
 
 #include "ttnn/tensor/tensor_utils.hpp"
 
+#include "ttnn/distributed/api.hpp"
 #include "ttnn/tensor/host_buffer/functions.hpp"
 #include "ttnn/tensor/host_buffer/types.hpp"
-#include "ttnn/distributed/api.hpp"
 
 namespace tt {
 
@@ -91,7 +91,8 @@ Tensor to_weight_special_padding_tile_layout(
         conv_weight_tensor.get_storage());
     };
 
-    return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) : convert_tensor(conv_weight_tensor);
+    return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor)
+                                                                         : convert_tensor(conv_weight_tensor);
 }
 
 template <typename T>
@@ -175,7 +176,8 @@ Tensor to_weight_tile_layout(
         },
         conv_weight_tensor.get_storage());
     };
-    return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) : convert_tensor(conv_weight_tensor);
+    return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor)
+                                                                         : convert_tensor(conv_weight_tensor);
 }
 
 // Converts convolution weights to tilized 2d matrix layout.
@@ -243,45 +245,61 @@ Helper function to aid in converting grouped weight tensor to ungrouped weight tensor
 */
 template <typename T>
 static Tensor conv_group_weight_zero_pad_helper(
-    Tensor& conv_weight_tensor,
+    const Tensor& weight,
     const ttnn::SimpleShape& original_weight_shape,
     const ttnn::SimpleShape& output_weight_shape,
     uint32_t num_groups,
     DataType output_dtype) {
-    owned_buffer::Buffer<T> output_buffer = owned_buffer::create<T>(output_weight_shape.volume());
-    auto conv_weight_tensor_buffer = borrowed_buffer::get_as<T>(conv_weight_tensor);
-
-    for (int curr_batch_idx = 0; curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) {
-        int new_batch_idx = curr_batch_idx;
-
-        // Find which group_id the filter belongs to - through this, we can compute the offset where the padding should
-        // be applied
-        auto group_size = original_weight_shape[0] / num_groups;
-        auto group_index = curr_batch_idx / group_size;
-        auto group_id = std::min(group_index, num_groups - 1);
-        int new_channel_start_idx = group_id * original_weight_shape[1];
-
-        for (int j = 0; j < original_weight_shape[1]; j++) {
-            for (int k = 0; k < original_weight_shape[2]; k++) {
-                for (int m = 0; m < original_weight_shape[3]; m++) {
-                    // Get value from original weight tensor
-                    auto value_flat_input_index =
-                        compute_flat_indices(ttnn::SmallVector<int>{curr_batch_idx, j, k, m}, compute_strides(original_weight_shape));
-                    auto value = conv_weight_tensor_buffer[value_flat_input_index];
-
-                    // Copy value to output tensor at the adjusted position
-                    auto new_channel_idx = new_channel_start_idx + j;
-                    auto output_flat_input_index = compute_flat_indices(
-                        ttnn::SmallVector<int>{new_batch_idx, new_channel_idx, k, m}, compute_strides(output_weight_shape));
-                    output_buffer[output_flat_input_index] = value;
+    auto pad_weight = [&original_weight_shape, &output_weight_shape, &num_groups, &output_dtype](
+                          const auto& conv_weight_tensor_buffer) {
+        owned_buffer::Buffer<T> output_buffer = owned_buffer::create<T>(output_weight_shape.volume());
+        for (int curr_batch_idx = 0; curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) {
+            int new_batch_idx = curr_batch_idx;
+
+            // Find which group_id the filter belongs to - through this, we can compute the offset where the padding
+            // should be applied
+            auto group_size = original_weight_shape[0] / num_groups;
+            auto group_index = curr_batch_idx / group_size;
+            auto group_id = std::min(group_index, num_groups - 1);
+            int new_channel_start_idx = group_id * original_weight_shape[1];
+
+            for (int j = 0; j < original_weight_shape[1]; j++) {
+                for (int k = 0; k < original_weight_shape[2]; k++) {
+                    for (int m = 0; m < original_weight_shape[3]; m++) {
+                        // Get value from original weight tensor
+                        auto value_flat_input_index = compute_flat_indices(
+                            ttnn::SmallVector<int>{curr_batch_idx, j, k, m}, compute_strides(original_weight_shape));
+                        auto value = conv_weight_tensor_buffer[value_flat_input_index];
+
+                        // Copy value to output tensor at the adjusted position
+                        auto new_channel_idx = new_channel_start_idx + j;
+                        auto output_flat_input_index = compute_flat_indices(
+                            ttnn::SmallVector<int>{new_batch_idx, new_channel_idx, k, m},
+                            compute_strides(output_weight_shape));
+                        output_buffer[output_flat_input_index] = value;
+                    }
                 }
             }
         }
-    }
+        return Tensor(
+            std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR);
+    };
 
-    auto output_tensor =
-        Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR);
-    return output_tensor;
+    auto f = [&](const auto& tensor) {
+        return std::visit(
+            [&](auto&& storage) -> Tensor {
+                using StorageType = std::decay_t<decltype(storage)>;
+                if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
+                    return pad_weight(owned_buffer::get_as<T>(storage.buffer));
+                } else if constexpr (std::is_same_v<StorageType, BorrowedStorage>) {
+                    return pad_weight(borrowed_buffer::get_as<T>(storage.buffer));
+                } else {
+                    TT_THROW("Unsupported storage type");
+                }
+            },
+            tensor.get_storage());
+    };
+    return ttnn::distributed::is_multi_device_tensor(weight) ? transform(weight, f) : f(weight);
 }
 
 /*
@@ -300,10 +318,11 @@ static Tensor conv_depthwise_weight_bcast_helper(
         for (int j = 0; j < output_weight_shape[1]; j++) {
             for (int k = 0; k < output_weight_shape[2]; k++) {
                 for (int l = 0; l < output_weight_shape[3]; l++) {
-                    auto value_flat_input_index =
-                        compute_flat_indices(ttnn::SmallVector<int>{i, 0, k, l}, compute_strides(original_weight_shape));
+                    auto value_flat_input_index = compute_flat_indices(
+                        ttnn::SmallVector<int>{i, 0, k, l}, compute_strides(original_weight_shape));
                     auto value = conv_weight_tensor_buffer[value_flat_input_index];
-                    auto output_flat_input_index = compute_flat_indices(ttnn::SmallVector<int>{i, j, k, l}, compute_strides(output_weight_shape));
+                    auto output_flat_input_index =
+                        compute_flat_indices(ttnn::SmallVector<int>{i, j, k, l}, compute_strides(output_weight_shape));
                     output_buffer[output_flat_input_index] = value;
                 }
             }
@@ -417,40 +436,22 @@ Tensor convert_conv_weight_tensor_to_depthwise_layout(
     // Create newly allocated buffer all initialized to 0 depending on the datatype of the weight tensor
     if (output_dtype == DataType::INT32) {
         return conv_depthwise_weight_bcast_helper<int32_t>(
-            conv_weight_tensor,
-            original_conv_weight_tensor_shape,
-            output_conv_weight_tensor_shape,
-            output_dtype);
+            conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
     } else if (output_dtype == DataType::FLOAT32) {
         return conv_depthwise_weight_bcast_helper<float>(
-            conv_weight_tensor,
-            original_conv_weight_tensor_shape,
-            output_conv_weight_tensor_shape,
-            output_dtype);
+            conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
     } else if (output_dtype == DataType::BFLOAT16) {
         return conv_depthwise_weight_bcast_helper<bfloat16>(
-            conv_weight_tensor,
-            original_conv_weight_tensor_shape,
-            output_conv_weight_tensor_shape,
-            output_dtype);
+            conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
    } else if (output_dtype == DataType::UINT16) {
         return conv_depthwise_weight_bcast_helper<uint16_t>(
-            conv_weight_tensor,
-            original_conv_weight_tensor_shape,
-            output_conv_weight_tensor_shape,
-            output_dtype);
+            conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
     } else if (output_dtype == DataType::BFLOAT8_B) {
         return conv_depthwise_weight_bcast_helper<float>(
-            conv_weight_tensor,
-            original_conv_weight_tensor_shape,
-            output_conv_weight_tensor_shape,
-            DataType::FLOAT32);
+            conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, DataType::FLOAT32);
     } else {
         return conv_depthwise_weight_bcast_helper<float>(
-            conv_weight_tensor,
-            original_conv_weight_tensor_shape,
-            output_conv_weight_tensor_shape,
-            DataType::FLOAT32);
+            conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, DataType::FLOAT32);
     }
 
     TT_THROW("Unsupported weight data type given when trying to add zero padding to weight tensor");
@@ -464,7 +465,7 @@ const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, tt::stl::Span<const int32_t> shape) {
         if (shape[index] == -1) {
             if (index_of_negative_1 != -1) {
                 std::string error_msg = "Shape cannot have more than 1 elements that is set to -1! Shape used: (";
-                for(auto & s: shape) {
+                for (auto& s : shape) {
                     error_msg += std::to_string(s) + ",";
                 }
                 error_msg += ")";
@@ -519,7 +520,10 @@ void apply(const Tensor& tensor, std::function<void(const Tensor&)> callable) {
 std::vector<Device*> get_devices(const Tensor& tensor) {
     std::vector<Device*> devices;
     if (tensor.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE) {
-        TT_ASSERT(std::holds_alternative<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage()), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(tensor.get_storage()));
+        TT_ASSERT(
+            std::holds_alternative<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage()),
+            "Unexpected type {}",
+            tt::stl::get_active_type_name_in_variant(tensor.get_storage()));
         const auto& tensor_storage = std::get<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage());
         for (int i = 0; i < tensor_storage.ordered_device_ids.size(); ++i) {
             auto device_id = tensor_storage.ordered_device_ids[i];
@@ -630,8 +634,12 @@ Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor)
             [&owned_tensor, &tensor](auto&& buffer) {
                 using BorrowedStorageType = std::vector<std::decay_t<decltype(*(buffer.begin()))>>;
                 auto owned_buf = owned_buffer::create(BorrowedStorageType(buffer.begin(), buffer.end()));
-                owned_tensor =
-                    Tensor(OwnedStorage{owned_buf}, tensor.get_shape(), tensor.get_dtype(), tensor.get_layout(), tensor.get_tile());
+                owned_tensor = Tensor(
+                    OwnedStorage{owned_buf},
+                    tensor.get_shape(),
+                    tensor.get_dtype(),
+                    tensor.get_layout(),
+                    tensor.get_tile());
             },
             borrowed_buffer);
         return owned_tensor;
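
Note on the main change above: conv_group_weight_zero_pad_helper previously assumed borrowed storage (borrowed_buffer::get_as) and mutated its argument; it now takes the weight tensor by const reference, writes the zero-padding loop once as a pad_weight lambda, selects the right buffer accessor by visiting the host storage variant, and routes multi-device tensors through transform so the same helper runs on each device shard. Below is a minimal, self-contained sketch of that storage-dispatch pattern. The OwnedStorage/BorrowedStorage structs, the Storage alias, and process() here are simplified stand-ins chosen for illustration, not the actual ttnn types or APIs.

// Sketch: one shared lambda dispatched over a storage variant via std::visit,
// mirroring how pad_weight is reused for owned and borrowed buffers above.
// Requires C++17 (std::variant, if constexpr).
#include <cstddef>
#include <cstdio>
#include <type_traits>
#include <variant>
#include <vector>

struct OwnedStorage {
    std::vector<float> buffer;  // storage owns its data
};

struct BorrowedStorage {
    const float* data;  // storage aliases caller-owned data
    std::size_t size;
};

using Storage = std::variant<OwnedStorage, BorrowedStorage>;

int main() {
    // The body is written once, like pad_weight: it only needs an indexable view.
    auto sum = [](const auto& buf, std::size_t n) {
        float total = 0.0f;
        for (std::size_t i = 0; i < n; ++i) {
            total += buf[i];
        }
        return total;
    };

    // The dispatcher, like the `f` lambda in the diff: pick the accessor per
    // storage alternative and hand the resulting view to the shared body.
    auto process = [&](const Storage& storage) {
        return std::visit(
            [&](auto&& s) -> float {
                using StorageType = std::decay_t<decltype(s)>;
                if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
                    return sum(s.buffer, s.buffer.size());
                } else {
                    static_assert(std::is_same_v<StorageType, BorrowedStorage>);
                    return sum(s.data, s.size);
                }
            },
            storage);
    };

    Storage owned = OwnedStorage{{1.0f, 2.0f, 3.0f}};
    float backing[] = {4.0f, 5.0f};
    Storage borrowed = BorrowedStorage{backing, 2};
    std::printf("%.1f %.1f\n", process(owned), process(borrowed));  // prints: 6.0 9.0
    return 0;
}

The design point carried over from the diff: because the element-copying logic lives in one generic lambda, adding support for another storage alternative only adds a branch to the dispatcher, and the multi-device case needs no new logic at all since transform applies the same single-device function per shard.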