#0: Add multi-device support in grouped conv2d weight preprocessing
esmalTT committed Nov 9, 2024
1 parent 2f69888 commit f93e866
Showing 1 changed file with 73 additions and 65 deletions.
138 changes: 73 additions & 65 deletions ttnn/cpp/ttnn/tensor/tensor_utils.cpp
@@ -4,9 +4,9 @@

#include "ttnn/tensor/tensor_utils.hpp"

#include "ttnn/distributed/api.hpp"
#include "ttnn/tensor/host_buffer/functions.hpp"
#include "ttnn/tensor/host_buffer/types.hpp"
#include "ttnn/distributed/api.hpp"

namespace tt {

@@ -91,7 +91,8 @@ Tensor to_weight_special_padding_tile_layout(
conv_weight_tensor.get_storage());
};

return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) : convert_tensor(conv_weight_tensor);
return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor)
: convert_tensor(conv_weight_tensor);
}
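
The recurring change in this commit is visible above: each weight-conversion routine now dispatches on ttnn::distributed::is_multi_device_tensor, converting a single-device tensor directly and otherwise mapping the conversion over the per-device shards via transform. Below is a minimal, self-contained sketch of that idiom; the Tensor struct, transform, and preprocess here are hypothetical stand-ins for illustration, not the real ttnn types or API.

#include <functional>
#include <vector>

// Stand-in for ttnn's Tensor, reduced to what the dispatch idiom needs.
struct Tensor {
    std::vector<Tensor> shards;  // one entry per device when multi-device
    int payload = 0;             // placeholder for the actual tensor data
};

bool is_multi_device_tensor(const Tensor& t) { return !t.shards.empty(); }

// Map `fn` over every per-device shard and reassemble the result.
Tensor transform(const Tensor& t, const std::function<Tensor(const Tensor&)>& fn) {
    Tensor out;
    for (const Tensor& shard : t.shards) {
        out.shards.push_back(fn(shard));
    }
    return out;
}

Tensor preprocess(const Tensor& weight, const std::function<Tensor(const Tensor&)>& convert_tensor) {
    // Convert directly for a single-device tensor, shard-by-shard otherwise.
    return is_multi_device_tensor(weight) ? transform(weight, convert_tensor)
                                          : convert_tensor(weight);
}

Because convert_tensor only ever sees a single host tensor, the conversion logic itself is untouched; only the dispatch wrapper is new.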

template <typename T>
@@ -175,7 +176,8 @@ Tensor to_weight_tile_layout(
},
conv_weight_tensor.get_storage());
};
return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) : convert_tensor(conv_weight_tensor);
return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor)
: convert_tensor(conv_weight_tensor);
}

// Converts convolution weights to tilized 2d matrix layout.
@@ -243,45 +245,61 @@ Helper function to aid in converting grouped weight tensor to ungrouped weight t
*/
template <typename T>
static Tensor conv_group_weight_zero_pad_helper(
Tensor& conv_weight_tensor,
const Tensor& weight,
const ttnn::SimpleShape& original_weight_shape,
const ttnn::SimpleShape& output_weight_shape,
uint32_t num_groups,
DataType output_dtype) {
owned_buffer::Buffer<T> output_buffer = owned_buffer::create<T>(output_weight_shape.volume());
auto conv_weight_tensor_buffer = borrowed_buffer::get_as<T>(conv_weight_tensor);

for (int curr_batch_idx = 0; curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) {
int new_batch_idx = curr_batch_idx;

// Find which group_id the filter belongs to - through this, we can compute the offset where the padding should
// be applied
auto group_size = original_weight_shape[0] / num_groups;
auto group_index = curr_batch_idx / group_size;
auto group_id = std::min(group_index, num_groups - 1);
int new_channel_start_idx = group_id * original_weight_shape[1];

for (int j = 0; j < original_weight_shape[1]; j++) {
for (int k = 0; k < original_weight_shape[2]; k++) {
for (int m = 0; m < original_weight_shape[3]; m++) {
// Get value from original weight tensor
auto value_flat_input_index =
compute_flat_indices(ttnn::SmallVector<int>{curr_batch_idx, j, k, m}, compute_strides(original_weight_shape));
auto value = conv_weight_tensor_buffer[value_flat_input_index];

// Copy value to output tensor at the adjusted position
auto new_channel_idx = new_channel_start_idx + j;
auto output_flat_input_index = compute_flat_indices(
ttnn::SmallVector<int>{new_batch_idx, new_channel_idx, k, m}, compute_strides(output_weight_shape));
output_buffer[output_flat_input_index] = value;
auto pad_weight = [&original_weight_shape, &output_weight_shape, &num_groups, &output_dtype](
const auto& conv_weight_tensor_buffer) {
owned_buffer::Buffer<T> output_buffer = owned_buffer::create<T>(output_weight_shape.volume());
for (int curr_batch_idx = 0; curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) {
int new_batch_idx = curr_batch_idx;

// Find which group_id the filter belongs to - through this, we can compute the offset where the padding
// should be applied
auto group_size = original_weight_shape[0] / num_groups;
auto group_index = curr_batch_idx / group_size;
auto group_id = std::min(group_index, num_groups - 1);
int new_channel_start_idx = group_id * original_weight_shape[1];

for (int j = 0; j < original_weight_shape[1]; j++) {
for (int k = 0; k < original_weight_shape[2]; k++) {
for (int m = 0; m < original_weight_shape[3]; m++) {
// Get value from original weight tensor
auto value_flat_input_index = compute_flat_indices(
ttnn::SmallVector<int>{curr_batch_idx, j, k, m}, compute_strides(original_weight_shape));
auto value = conv_weight_tensor_buffer[value_flat_input_index];

// Copy value to output tensor at the adjusted position
auto new_channel_idx = new_channel_start_idx + j;
auto output_flat_input_index = compute_flat_indices(
ttnn::SmallVector<int>{new_batch_idx, new_channel_idx, k, m},
compute_strides(output_weight_shape));
output_buffer[output_flat_input_index] = value;
}
}
}
}
}
return Tensor(
std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR);
};

auto output_tensor =
Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR);
return output_tensor;
auto f = [&](const auto& tensor) {
return std::visit(
[&](auto&& storage) -> Tensor {
using StorageType = std::decay_t<decltype(storage)>;
if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
return pad_weight(owned_buffer::get_as<T>(storage.buffer));
} else if constexpr (std::is_same_v<StorageType, BorrowedStorage>) {
return pad_weight(borrowed_buffer::get_as<T>(storage.buffer));
} else {
TT_THROW("Unsupported storage type");
}
},
tensor.get_storage());
};
return ttnn::distributed::is_multi_device_tensor(weight) ? transform(weight, f) : f(weight);
}
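
For intuition about the index arithmetic inside pad_weight: output filter o belongs to group o / group_size, and its input channels are shifted into that group's channel block, with zeros everywhere else. A self-contained sketch with hypothetical sizes (plain std::vector instead of ttnn buffers, H = W = 1): a grouped weight of shape [4, 2, 1, 1] with num_groups = 2 expands to [4, 4, 1, 1].

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const int O = 4, IC = 2, groups = 2;   // original shape [4, 2, 1, 1]
    std::vector<float> in(O * IC);
    for (std::size_t i = 0; i < in.size(); ++i) in[i] = static_cast<float>(i + 1);

    const int out_channels = IC * groups;  // padded shape [4, 4, 1, 1]
    std::vector<float> out(O * out_channels, 0.0f);

    const int group_size = O / groups;
    for (int o = 0; o < O; ++o) {
        const int group_id = std::min(o / group_size, groups - 1);
        const int channel_start = group_id * IC;  // this group's channel block
        for (int c = 0; c < IC; ++c) {
            out[o * out_channels + channel_start + c] = in[o * IC + c];
        }
    }
    // Prints: filters 0-1 occupy channels 0-1, filters 2-3 occupy channels 2-3:
    //   1 2 0 0 / 3 4 0 0 / 0 0 5 6 / 0 0 7 8
    for (int o = 0; o < O; ++o) {
        for (int c = 0; c < out_channels; ++c) std::printf("%4.0f", out[o * out_channels + c]);
        std::printf("\n");
    }
    return 0;
}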

/*
@@ -300,10 +318,11 @@ static Tensor conv_depthwise_weight_bcast_helper(
for (int j = 0; j < output_weight_shape[1]; j++) {
for (int k = 0; k < output_weight_shape[2]; k++) {
for (int l = 0; l < output_weight_shape[3]; l++) {
auto value_flat_input_index =
compute_flat_indices(ttnn::SmallVector<int>{i, 0, k, l}, compute_strides(original_weight_shape));
auto value_flat_input_index = compute_flat_indices(
ttnn::SmallVector<int>{i, 0, k, l}, compute_strides(original_weight_shape));
auto value = conv_weight_tensor_buffer[value_flat_input_index];
auto output_flat_input_index = compute_flat_indices(ttnn::SmallVector<int>{i, j, k, l}, compute_strides(output_weight_shape));
auto output_flat_input_index =
compute_flat_indices(ttnn::SmallVector<int>{i, j, k, l}, compute_strides(output_weight_shape));
output_buffer[output_flat_input_index] = value;
}
}
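
The depthwise helper above performs the complementary transformation: a depthwise weight of shape [O, 1, H, W] has its single input channel repeated across the channel dimension to produce [O, C, H, W]. A minimal sketch on plain buffers with hypothetical sizes, mirroring the loop structure shown:

#include <cstdio>
#include <vector>

int main() {
    const int O = 2, C = 3, H = 1, W = 2;
    std::vector<float> in = {1, 2, 3, 4};         // shape [2, 1, 1, 2]
    std::vector<float> out(O * C * H * W, 0.0f);  // shape [2, 3, 1, 2]

    for (int i = 0; i < O; ++i)
        for (int j = 0; j < C; ++j)
            for (int k = 0; k < H; ++k)
                for (int l = 0; l < W; ++l)
                    // Read channel 0 of the input; write it to every channel j.
                    out[((i * C + j) * H + k) * W + l] = in[(i * H + k) * W + l];

    for (float v : out) std::printf("%.0f ", v);  // 1 2 1 2 1 2 3 4 3 4 3 4
    return 0;
}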
@@ -417,40 +436,22 @@ Tensor convert_conv_weight_tensor_to_depthwise_layout(
// Create newly allocated buffer all initialized to 0 depending on the datatype of the weight tensor
if (output_dtype == DataType::INT32) {
return conv_depthwise_weight_bcast_helper<int32_t>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
output_dtype);
conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
} else if (output_dtype == DataType::FLOAT32) {
return conv_depthwise_weight_bcast_helper<float>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
output_dtype);
conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
} else if (output_dtype == DataType::BFLOAT16) {
return conv_depthwise_weight_bcast_helper<bfloat16>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
output_dtype);
conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
} else if (output_dtype == DataType::UINT16) {
return conv_depthwise_weight_bcast_helper<uint16_t>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
output_dtype);
conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
} else if (output_dtype == DataType::BFLOAT8_B) {
return conv_depthwise_weight_bcast_helper<float>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
DataType::FLOAT32);
conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, DataType::FLOAT32);
} else {
return conv_depthwise_weight_bcast_helper<float>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
DataType::FLOAT32);
conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, DataType::FLOAT32);
}

TT_THROW("Unsupported weight data type given when trying to add zero padding to weight tensor");
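
The chain above is a routine runtime-dtype-to-template dispatch. Two things worth noting: BFLOAT8_B weights are preprocessed as FLOAT32 (the block-float packing presumably happens later, when the tensor is tilized), and because the final else also returns, the trailing TT_THROW is unreachable. A condensed sketch of the same dispatch shape; the enum and helper below are illustrative stand-ins, not the ttnn API.

#include <cstddef>
#include <cstdint>

enum class DataType { INT32, FLOAT32, BFLOAT16, UINT16, BFLOAT8_B };

// Placeholder for a helper templated on the host representation of the dtype.
template <typename T>
std::size_t helper(DataType output_dtype) { return sizeof(T); }

std::size_t dispatch(DataType dtype) {
    switch (dtype) {
        case DataType::INT32:   return helper<std::int32_t>(dtype);
        case DataType::UINT16:  return helper<std::uint16_t>(dtype);
        case DataType::FLOAT32: return helper<float>(dtype);
        // BFLOAT16 omitted here: it needs ttnn's bfloat16 host type.
        default:                return helper<float>(DataType::FLOAT32);
    }
}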
Expand All @@ -464,7 +465,7 @@ const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, tt::stl::Sp
if (shape[index] == -1) {
if (index_of_negative_1 != -1) {
std::string error_msg = "Shape cannot have more than 1 elements that is set to -1! Shape used: (";
for(auto & s: shape) {
for (auto& s : shape) {
error_msg += std::to_string(s) + ",";
}
error_msg += ")";
@@ -519,7 +520,10 @@ void apply(const Tensor& tensor, std::function<void(const Tensor&)> callable) {
std::vector<Device*> get_devices(const Tensor& tensor) {
std::vector<Device*> devices;
if (tensor.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE) {
TT_ASSERT(std::holds_alternative<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage()), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(tensor.get_storage()));
TT_ASSERT(
std::holds_alternative<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage()),
"Unexpected type {}",
tt::stl::get_active_type_name_in_variant(tensor.get_storage()));
const auto& tensor_storage = std::get<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage());
for (int i = 0; i < tensor_storage.ordered_device_ids.size(); ++i) {
auto device_id = tensor_storage.ordered_device_ids[i];
@@ -630,8 +634,12 @@ Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor)
[&owned_tensor, &tensor](auto&& buffer) {
using BorrowedStorageType = std::vector<std::decay_t<decltype(*(buffer.begin()))>>;
auto owned_buf = owned_buffer::create(BorrowedStorageType(buffer.begin(), buffer.end()));
owned_tensor =
Tensor(OwnedStorage{owned_buf}, tensor.get_shape(), tensor.get_dtype(), tensor.get_layout(), tensor.get_tile());
owned_tensor = Tensor(
OwnedStorage{owned_buf},
tensor.get_shape(),
tensor.get_dtype(),
tensor.get_layout(),
tensor.get_tile());
},
borrowed_buffer);
return owned_tensor;
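
copy_borrowed_tensor_in_async_mode exists because a borrowed buffer merely views caller-owned memory; in async mode the worker may outlive the caller's scope, so the data must be deep-copied into owned storage first. A minimal model of that borrowed-to-owned visit, with hypothetical stand-ins for ttnn's storage types:

#include <cstddef>
#include <type_traits>
#include <variant>
#include <vector>

template <typename T>
struct Borrowed { const T* data; std::size_t size; };  // non-owning view
template <typename T>
using Owned = std::vector<T>;                          // owning copy

using Storage = std::variant<Borrowed<float>, Owned<float>>;

Storage ensure_owned(const Storage& s) {
    return std::visit(
        [](const auto& buf) -> Storage {
            using B = std::decay_t<decltype(buf)>;
            if constexpr (std::is_same_v<B, Borrowed<float>>) {
                // Deep-copy the borrowed range into freshly owned memory.
                return Owned<float>(buf.data, buf.data + buf.size);
            } else {
                return buf;  // already owned: pass through unchanged
            }
        },
        s);
}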
