Add support for multi-device tensors in grouped convolution weight preprocessing #14914

Open · wants to merge 1 commit into base: main
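This PR updates the conv2d unit tests with a multi-device variant (`test_conv_features_multi_device`) that shards the activation batch across a `ttnn.MeshDevice` while replicating the grouped weights to every device, and it reworks `conv_group_weight_zero_pad_helper` in `tensor_utils.cpp` so the grouped-weight zero padding is applied to each shard of a multi-device weight tensor. Below is a minimal sketch of the host-side tensor preparation the new test exercises, assuming an already-opened `ttnn.MeshDevice` (e.g. the `mesh_device` pytest fixture); the helper name `prepare_multi_device_conv_tensors` and the bfloat16 dtypes are illustrative and not part of this PR.

```python
import ttnn


def prepare_multi_device_conv_tensors(mesh_device, torch_input_nhwc, torch_weight, torch_bias):
    # Shard the activation batch dimension across the devices in the mesh...
    input_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0)
    # ...and give every device a full copy of the grouped weights and bias.
    weight_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device)
    # Per-device outputs are later stitched back together along the batch dimension.
    output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)

    tt_input = ttnn.from_torch(torch_input_nhwc, ttnn.bfloat16, mesh_mapper=input_mesh_mapper)
    tt_weight = ttnn.from_torch(torch_weight, ttnn.bfloat16, mesh_mapper=weight_mesh_mapper)
    tt_bias = ttnn.from_torch(torch_bias, ttnn.bfloat16, mesh_mapper=weight_mesh_mapper)
    return tt_input, tt_weight, tt_bias, output_mesh_composer
```

After `ttnn.conv2d` runs with `groups > 1`, the host result is recovered with `ttnn.to_torch(tt_output, mesh_composer=output_mesh_composer)` and reshaped to `(num_devices * batch_size, out_height, out_width, channels)`, which is what the updated `run_conv` helper does before comparing against the torch golden output.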
117 changes: 106 additions & 11 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -73,19 +73,28 @@ def run_conv(
shard_layout=None,
auto_shard=False,
memory_config=None,
input_mesh_mapper=None,
weight_mesh_mapper=None,
output_mesh_composer=None,
):
if isinstance(device, ttnn.MeshDevice):
assert input_mesh_mapper is not None, "Expected mesh mapper for input tensor when using device mesh"
assert weight_mesh_mapper is not None, "Expected mesh mapper for weight tensors when using device mesh"
assert output_mesh_composer is not None, "Expected mesh composer for output tensor when using device mesh"
num_devices = len(device.get_device_ids())
total_batch_size = num_devices * batch_size # Batch size across all devices
logger.info(f"Using {num_devices} devices for this test")
else:
total_batch_size = batch_size

torch.manual_seed(0)
conv_input_shape = [batch_size, input_channels, input_height, input_width]
conv_input_shape = [total_batch_size, input_channels, input_height, input_width]
conv_weight_shape = [output_channels, input_channels // groups, filter_height, filter_width]
conv_bias_shape = [1, 1, 1, output_channels]
torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float()
# torch_input_tensor_nchw = torch.ones(conv_input_shape, dtype=torch.bfloat16).float()
# torch_input_tensor_nchw = torch.tensor(range(input_height * input_width)).reshape([1,1,input_height,input_width]).float()
# torch_input_tensor_nchw = torch_input_tensor_nchw.broadcast_to(conv_input_shape).float()

torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))
torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float()
# torch_weight_tensor = torch.ones(conv_weight_shape, dtype=torch.bfloat16).float()

torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() if has_bias else None
torch_out_golden_tensor = torch.nn.functional.conv2d(
@@ -107,15 +116,19 @@ def run_conv(
reader_patterns_cache = {}

tt_weight_tensor = ttnn.from_torch(
torch_weight_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
torch_weight_tensor,
weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32,
mesh_mapper=weight_mesh_mapper,
)
tt_bias_tensor = None
if has_bias:
tt_bias_tensor = ttnn.from_torch(
torch_bias_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
torch_bias_tensor,
weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32,
mesh_mapper=weight_mesh_mapper,
)

tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)
tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16, mesh_mapper=input_mesh_mapper)

if shard_layout is None and not auto_shard:
shard_layout = (
@@ -171,11 +184,13 @@
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)
torch_output_tensor = ttnn.to_torch(tt_output_tensor, mesh_composer=output_mesh_composer)

# torch_output_tensor is in row major layout and NHWC shape
# NHWC to NCHW
torch_output_tensor = torch_output_tensor.reshape(batch_size, out_height, out_width, torch_output_tensor.shape[-1])
torch_output_tensor = torch_output_tensor.reshape(
total_batch_size, out_height, out_width, torch_output_tensor.shape[-1]
)
torch_output_tensor = torch_output_tensor[:, :, :, :output_channels]

torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2))
Expand Down Expand Up @@ -288,7 +303,7 @@ def run_conv_with_split(
torch_bias_zeroes_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
)
torch_input_tensor = torch.permute(split_input_tensors[i], (0, 2, 3, 1))
tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)
tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16, device=device)
# tt_input_tensor_on_device = convs[i].copy_input_to_device(tt_input_tensor)
# tt_output_tensor_on_device = convs[i](tt_input_tensor_on_device)
[tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
@@ -417,6 +432,86 @@ def test_conv_features(
)


@skip_for_grayskull()
@skip_for_blackhole()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 2 * 16384}], indirect=True)
@pytest.mark.parametrize("groups", [1, 2])
@pytest.mark.parametrize("stride", [2])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_channels, input_channels, input_height, input_width, shard_layout, config",
(
(256, 256, 8, 8, ttnn.TensorMemoryLayout.WIDTH_SHARDED, None),
(128, 128, 32, 32, ttnn.TensorMemoryLayout.BLOCK_SHARDED, None),
(16, 16, 256, 256, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, {"act_block_h": 32}),
),
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat8_b, ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
[ttnn.bfloat8_b, ttnn.bfloat16],
)
@pytest.mark.parametrize(
"filter, pad",
[
[3, 1],
],
)
@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi])
@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
def test_conv_features_multi_device(
mesh_device,
use_program_cache,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
shard_layout,
config,
filter,
stride,
pad,
output_layout,
groups,
):
if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b:
pytest.skip("Row major layout not compatible with bfloat8_b")

run_conv(
mesh_device,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter,
filter,
stride,
stride,
pad,
pad,
True,
config,
shard_layout=shard_layout,
output_layout=output_layout,
has_bias=True,
input_mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0),
weight_mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
output_mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0),
groups=groups,
)


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize("stride", [1, 2])
138 changes: 73 additions & 65 deletions ttnn/cpp/ttnn/tensor/tensor_utils.cpp
@@ -4,9 +4,9 @@

#include "ttnn/tensor/tensor_utils.hpp"

#include "ttnn/distributed/api.hpp"
#include "ttnn/tensor/host_buffer/functions.hpp"
#include "ttnn/tensor/host_buffer/types.hpp"
#include "ttnn/distributed/api.hpp"

namespace tt {

@@ -91,7 +91,8 @@ Tensor to_weight_special_padding_tile_layout(
conv_weight_tensor.get_storage());
};

return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) : convert_tensor(conv_weight_tensor);
return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor)
: convert_tensor(conv_weight_tensor);
}

template <typename T>
@@ -175,7 +176,8 @@ Tensor to_weight_tile_layout(
},
conv_weight_tensor.get_storage());
};
return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor) : convert_tensor(conv_weight_tensor);
return ttnn::distributed::is_multi_device_tensor(conv_weight_tensor) ? transform(conv_weight_tensor, convert_tensor)
: convert_tensor(conv_weight_tensor);
}

// Converts convolution weights to tilized 2d matrix layout.
@@ -243,45 +245,61 @@ Helper function to aid in converting grouped weight tensor to ungrouped weight t
*/
template <typename T>
static Tensor conv_group_weight_zero_pad_helper(
Tensor& conv_weight_tensor,
const Tensor& weight,
const ttnn::SimpleShape& original_weight_shape,
const ttnn::SimpleShape& output_weight_shape,
uint32_t num_groups,
DataType output_dtype) {
owned_buffer::Buffer<T> output_buffer = owned_buffer::create<T>(output_weight_shape.volume());
auto conv_weight_tensor_buffer = borrowed_buffer::get_as<T>(conv_weight_tensor);

for (int curr_batch_idx = 0; curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) {
int new_batch_idx = curr_batch_idx;

// Find which group_id the filter belongs to - through this, we can compute the offset where the padding should
// be applied
auto group_size = original_weight_shape[0] / num_groups;
auto group_index = curr_batch_idx / group_size;
auto group_id = std::min(group_index, num_groups - 1);
int new_channel_start_idx = group_id * original_weight_shape[1];

for (int j = 0; j < original_weight_shape[1]; j++) {
for (int k = 0; k < original_weight_shape[2]; k++) {
for (int m = 0; m < original_weight_shape[3]; m++) {
// Get value from original weight tensor
auto value_flat_input_index =
compute_flat_indices(ttnn::SmallVector<int>{curr_batch_idx, j, k, m}, compute_strides(original_weight_shape));
auto value = conv_weight_tensor_buffer[value_flat_input_index];

// Copy value to output tensor at the adjusted position
auto new_channel_idx = new_channel_start_idx + j;
auto output_flat_input_index = compute_flat_indices(
ttnn::SmallVector<int>{new_batch_idx, new_channel_idx, k, m}, compute_strides(output_weight_shape));
output_buffer[output_flat_input_index] = value;
auto pad_weight = [&original_weight_shape, &output_weight_shape, &num_groups, &output_dtype](
const auto& conv_weight_tensor_buffer) {
owned_buffer::Buffer<T> output_buffer = owned_buffer::create<T>(output_weight_shape.volume());
for (int curr_batch_idx = 0; curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) {
int new_batch_idx = curr_batch_idx;

// Find which group_id the filter belongs to - through this, we can compute the offset where the padding
// should be applied
auto group_size = original_weight_shape[0] / num_groups;
auto group_index = curr_batch_idx / group_size;
auto group_id = std::min(group_index, num_groups - 1);
int new_channel_start_idx = group_id * original_weight_shape[1];

for (int j = 0; j < original_weight_shape[1]; j++) {
for (int k = 0; k < original_weight_shape[2]; k++) {
for (int m = 0; m < original_weight_shape[3]; m++) {
// Get value from original weight tensor
auto value_flat_input_index = compute_flat_indices(
ttnn::SmallVector<int>{curr_batch_idx, j, k, m}, compute_strides(original_weight_shape));
auto value = conv_weight_tensor_buffer[value_flat_input_index];

// Copy value to output tensor at the adjusted position
auto new_channel_idx = new_channel_start_idx + j;
auto output_flat_input_index = compute_flat_indices(
ttnn::SmallVector<int>{new_batch_idx, new_channel_idx, k, m},
compute_strides(output_weight_shape));
output_buffer[output_flat_input_index] = value;
}
}
}
}
}
return Tensor(
std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR);
};

auto output_tensor =
Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR);
return output_tensor;
auto f = [&](const auto& tensor) {
return std::visit(
[&](auto&& storage) -> Tensor {
using StorageType = std::decay_t<decltype(storage)>;
if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
return pad_weight(owned_buffer::get_as<T>(storage.buffer));
} else if constexpr (std::is_same_v<StorageType, BorrowedStorage>) {
return pad_weight(borrowed_buffer::get_as<T>(storage.buffer));
} else {
TT_THROW("Unsupported storage type");
}
},
tensor.get_storage());
};
return ttnn::distributed::is_multi_device_tensor(weight) ? transform(weight, f) : f(weight);
}

/*
@@ -300,10 +318,11 @@ static Tensor conv_depthwise_weight_bcast_helper(
for (int j = 0; j < output_weight_shape[1]; j++) {
for (int k = 0; k < output_weight_shape[2]; k++) {
for (int l = 0; l < output_weight_shape[3]; l++) {
auto value_flat_input_index =
compute_flat_indices(ttnn::SmallVector<int>{i, 0, k, l}, compute_strides(original_weight_shape));
auto value_flat_input_index = compute_flat_indices(
ttnn::SmallVector<int>{i, 0, k, l}, compute_strides(original_weight_shape));
auto value = conv_weight_tensor_buffer[value_flat_input_index];
auto output_flat_input_index = compute_flat_indices(ttnn::SmallVector<int>{i, j, k, l}, compute_strides(output_weight_shape));
auto output_flat_input_index =
compute_flat_indices(ttnn::SmallVector<int>{i, j, k, l}, compute_strides(output_weight_shape));
output_buffer[output_flat_input_index] = value;
}
}
@@ -417,40 +436,22 @@ Tensor convert_conv_weight_tensor_to_depthwise_layout(
// Create newly allocated buffer all initialized to 0 depending on the datatype of the weight tensor
if (output_dtype == DataType::INT32) {
return conv_depthwise_weight_bcast_helper<int32_t>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
output_dtype);
conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
} else if (output_dtype == DataType::FLOAT32) {
return conv_depthwise_weight_bcast_helper<float>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
output_dtype);
conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
} else if (output_dtype == DataType::BFLOAT16) {
return conv_depthwise_weight_bcast_helper<bfloat16>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
output_dtype);
conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
} else if (output_dtype == DataType::UINT16) {
return conv_depthwise_weight_bcast_helper<uint16_t>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
output_dtype);
conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, output_dtype);
} else if (output_dtype == DataType::BFLOAT8_B) {
return conv_depthwise_weight_bcast_helper<float>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
DataType::FLOAT32);
conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, DataType::FLOAT32);
} else {
return conv_depthwise_weight_bcast_helper<float>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
DataType::FLOAT32);
conv_weight_tensor, original_conv_weight_tensor_shape, output_conv_weight_tensor_shape, DataType::FLOAT32);
}

TT_THROW("Unsupported weight data type given when trying to add zero padding to weight tensor");
@@ -464,7 +465,7 @@ const ttnn::SimpleShape infer_dims_for_reshape(const Tensor& tensor, tt::stl::Sp
if (shape[index] == -1) {
if (index_of_negative_1 != -1) {
std::string error_msg = "Shape cannot have more than 1 elements that is set to -1! Shape used: (";
for(auto & s: shape) {
for (auto& s : shape) {
error_msg += std::to_string(s) + ",";
}
error_msg += ")";
Expand Down Expand Up @@ -519,7 +520,10 @@ void apply(const Tensor& tensor, std::function<void(const Tensor&)> callable) {
std::vector<Device*> get_devices(const Tensor& tensor) {
std::vector<Device*> devices;
if (tensor.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE) {
TT_ASSERT(std::holds_alternative<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage()), "Unexpected type {}", tt::stl::get_active_type_name_in_variant(tensor.get_storage()));
TT_ASSERT(
std::holds_alternative<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage()),
"Unexpected type {}",
tt::stl::get_active_type_name_in_variant(tensor.get_storage()));
const auto& tensor_storage = std::get<tt::tt_metal::MultiDeviceStorage>(tensor.get_storage());
for (int i = 0; i < tensor_storage.ordered_device_ids.size(); ++i) {
auto device_id = tensor_storage.ordered_device_ids[i];
@@ -630,8 +634,12 @@ Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor)
[&owned_tensor, &tensor](auto&& buffer) {
using BorrowedStorageType = std::vector<std::decay_t<decltype(*(buffer.begin()))>>;
auto owned_buf = owned_buffer::create(BorrowedStorageType(buffer.begin(), buffer.end()));
owned_tensor =
Tensor(OwnedStorage{owned_buf}, tensor.get_shape(), tensor.get_dtype(), tensor.get_layout(), tensor.get_tile());
owned_tensor = Tensor(
OwnedStorage{owned_buf},
tensor.get_shape(),
tensor.get_dtype(),
tensor.get_layout(),
tensor.get_tile());
},
borrowed_buffer);
return owned_tensor;