#8340: Add functional grouped convolution support
tapspatel committed Jun 3, 2024
1 parent 7d68124 commit fd9972a
Showing 5 changed files with 386 additions and 8 deletions.
221 changes: 219 additions & 2 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -6,12 +6,19 @@

import torch
import pytest
from models.utility_functions import skip_for_wormhole_b0, skip_for_grayskull, is_grayskull, is_wormhole_b0
from models.utility_functions import (
skip_for_wormhole_b0,
skip_for_grayskull,
is_grayskull,
is_wormhole_b0,
is_x2_harvested,
)
from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc, check_with_pcc_without_tensor_printout
import ttnn
import tt_lib
import math
import os
import torch.nn as nn


# def plot_diff(vals, fid, nsticks, stick_len):
@@ -55,12 +62,13 @@ def run_conv(
output_layout=ttnn.TILE_LAYOUT,
deallocate_activation=False,
debug=False,
groups=1,
):
# has_bias = False
has_bias = True
torch.manual_seed(0)
conv_input_shape = [batch_size, input_channels, input_height, input_width]
conv_weight_shape = [output_channels, input_channels, filter_height, filter_width]
conv_weight_shape = [output_channels, input_channels // groups, filter_height, filter_width]
conv_bias_shape = [1, 1, 1, output_channels]
torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float()
torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))
@@ -72,6 +80,7 @@
bias=torch_bias_tensor.reshape(-1) if has_bias else None,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
groups=groups,
)
output_shape_nhwc = [
torch_out_golden_tensor.shape[0],
@@ -123,6 +132,7 @@ def run_conv(
conv_op_cache=reader_patterns_cache,
reshard_if_not_optimal=False,
debug=debug,
groups=groups,
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
@@ -1239,3 +1249,210 @@ def test_conv_core_nondivis(
use_1d_systolic_array,
config_override,
)


# The following test takes shapes from ResNet-50, UNet, and Stable Diffusion and sweeps the number of groups,
# all the way up to num_groups == num_in_channels (depthwise conv); a minimal standalone torch sketch of the
# grouped and depthwise cases follows this test.
@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
"batch_size, input_channels, output_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, groups, use_1d_systolic_array, config_override, use_shallow_conv_variant",
(
(1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, 2, True, None, False),
(1, 64, 64, 32, 32, 3, 3, 1, 1, 1, 1, 64, True, None, False),
(2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 1, True, None, False),
(2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 2, True, None, False),
(2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 8, True, None, False),
(1, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 1, True, None, False),
(8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 64, True, None, False),
(4, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 128, True, None, False),
(8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 128, True, None, False),
# (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 256, False, None, False), circular buffer error
# (16, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 256, False, None, False), # doesn't fit with bfloat16 weights
# (32, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 512, False, None, False), # doesn't fit with bfloat16 weights
(32, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, 40, False, None, False),
(32, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, 10, False, None, False),
(1, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 8, True, None, False),
(1, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 16, True, None, False),
(8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, 32, True, None, False),
(8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 2, False, None, False),
(8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 4, False, None, False),
(1, 320, 320, 32, 32, 3, 3, 1, 1, 1, 1, 2, False, None, False),
(1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, 320, False, None, False),
# (1, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, 1, False, None, False), # doesn't fit with bfloat16 weights
(2, 64, 32, 66, 10, 3, 3, 1, 1, 1, 1, 32, True, None, False),
(2, 32, 96, 132, 20, 3, 3, 1, 1, 1, 1, 2, True, None, False),
),
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
[ttnn.bfloat8_b, ttnn.bfloat16],
)
@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi])
@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT])
def test_conv_groups(
device,
use_program_cache,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
config_override,
use_shallow_conv_variant,
groups,
output_layout,
):
run_conv(
device,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
config_override,
use_shallow_conv_variant=use_shallow_conv_variant,
groups=groups,
output_layout=output_layout,
)
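The torch golden inside run_conv is just F.conv2d with groups passed through and a weight of shape
[out_channels, in_channels // groups, kH, kW]. A minimal standalone sketch of the grouped and depthwise
cases (shapes chosen for illustration only, not taken from the parametrization above):

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Grouped conv: each filter only sees in_channels // groups input channels,
# so the weight shape is [out_channels, in_channels // groups, kH, kW].
x = torch.randn(1, 64, 16, 16)
w_grouped = torch.randn(64, 64 // 2, 3, 3)   # groups = 2
out_grouped = F.conv2d(x, w_grouped, padding=1, groups=2)

# Depthwise extreme: groups == in_channels, so each filter sees exactly one input channel.
w_depthwise = torch.randn(64, 1, 3, 3)       # groups = 64
out_depthwise = F.conv2d(x, w_depthwise, padding=1, groups=64)

print(out_grouped.shape, out_depthwise.shape)  # both torch.Size([1, 64, 16, 16])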


@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override, use_shallow_conv_variant, groups",
(
# yolov4 convs with batch size 1
# unique convs in yolov4 (complete list) # groups: number
# (1, 32, 32, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 32), # groups: 32
# (1, 32, 32, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 32), # groups: 32
# (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64
# (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64
# (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64
# (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
# (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512
(1, 128, 128, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 2),  # groups: 2
),
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
# [ttnn.bfloat8_b, ttnn.bfloat16],
[ttnn.bfloat8_b],
)
@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi])
# @pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT])
def test_yolov4_conv_groups_larger_than_one(
device,
use_program_cache,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
config_override,
use_shallow_conv_variant,
groups,
output_layout,
):
if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b:
pytest.skip("Row major layout not compatible with bfloat8_b")
if output_layout == ttnn.ROW_MAJOR_LAYOUT and input_height >= 1056:
pytest.skip("OOM")
run_conv(
device,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
use_1d_systolic_array,
config_override,
use_shallow_conv_variant=use_shallow_conv_variant,
groups=groups,
padded_input_channels=16 if input_channels == 3 else None,
output_layout=output_layout,
)
114 changes: 114 additions & 0 deletions tt_eager/tensor/tensor_utils.cpp
@@ -220,6 +220,120 @@ Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout(
conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype()));
}

/*
Helper that converts a grouped weight tensor to an ungrouped weight tensor by zero-padding the extra input channels
*/
template <typename T>
static Tensor conv_group_weight_zero_pad_helper(
Tensor& conv_weight_tensor,
Shape& original_weight_shape,
Shape& output_weight_shape,
uint32_t num_groups,
DataType output_dtype) {
owned_buffer::Buffer<T> output_buffer = owned_buffer::create<T>(compute_volume(output_weight_shape));
auto conv_weight_tensor_buffer = borrowed_buffer::get_as<T>(conv_weight_tensor);

for (int curr_batch_idx = 0; curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) {
int new_batch_idx = curr_batch_idx;

// Find which group this filter belongs to; from that, compute the input-channel offset at which its real
// weights are placed (the remaining channels stay zero-padded)
auto group_size = original_weight_shape[0] / num_groups;
auto group_index = curr_batch_idx / group_size;
auto group_id = std::min(group_index, num_groups - 1);
int new_channel_start_idx = group_id * original_weight_shape[1];

for (int j = 0; j < original_weight_shape[1]; j++) {
for (int k = 0; k < original_weight_shape[2]; k++) {
for (int m = 0; m < original_weight_shape[3]; m++) {
// Get value from original weight tensor
auto value_flat_input_index =
compute_flat_indices({curr_batch_idx, j, k, m}, compute_strides(original_weight_shape));
auto value = conv_weight_tensor_buffer[value_flat_input_index];

// Copy value to output tensor at the adjusted position
auto new_channel_idx = new_channel_start_idx + j;
auto output_flat_input_index = compute_flat_indices(
{new_batch_idx, new_channel_idx, k, m}, compute_strides(output_weight_shape));
output_buffer[output_flat_input_index] = value;
}
}
}
}

auto output_tensor =
Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_weight_shape, output_dtype, Layout::ROW_MAJOR);
return output_tensor;
}
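The index math above assumes compute_strides and compute_flat_indices implement standard row-major (NCHW)
indexing; a Python sketch of that assumption (the names are reused purely for illustration):

def compute_strides(shape):
    # Row-major strides: stride[i] is the product of all dimensions after i.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def compute_flat_indices(indices, strides):
    # Dot product of the multi-dimensional index with the strides.
    return sum(i * s for i, s in zip(indices, strides))

# Element [o, c, k, m] of an [O, C, KH, KW] weight lands at o*C*KH*KW + c*KH*KW + k*KW + m in the flat buffer.
assert compute_flat_indices([1, 2, 0, 1], compute_strides([4, 3, 3, 3])) == 1 * 27 + 2 * 9 + 0 * 3 + 1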

/*
Converts convolution weights to a grouped layout with zero padding.
This function takes a weight tensor of shape [out_channels, in_channels // groups, H, W] and returns a newly
allocated tensor of shape [out_channels, in_channels, H, W]. The extra channels in dim 1 are filled with zeros,
so convolving the full weight tensor with the input is equivalent to the grouped convolution in which the input
channels are split into num_groups slices, one per grouped filter.
*/
Tensor convert_conv_weight_tensor_to_grouped_layout(
Tensor conv_weight_tensor, uint32_t num_groups, DataType output_dtype) {
TT_ASSERT(
conv_weight_tensor.get_layout() == Layout::ROW_MAJOR,
"Convolution weights should be in row major layout for adding the required padding");

// Define the output tensor shape: the weight tensor's channel dimension times num_groups, which should match
// the number of input channels being convolved with the weight tensor
auto original_conv_weight_tensor_shape_test = conv_weight_tensor.get_shape();
Shape original_conv_weight_tensor_shape = {
original_conv_weight_tensor_shape_test[0],
original_conv_weight_tensor_shape_test[1],
original_conv_weight_tensor_shape_test[2],
original_conv_weight_tensor_shape_test[3]};
Shape output_conv_weight_tensor_shape = {
original_conv_weight_tensor_shape[0],
original_conv_weight_tensor_shape[1] * num_groups,
original_conv_weight_tensor_shape[2],
original_conv_weight_tensor_shape[3]};

// Create a newly allocated, zero-initialized buffer whose element type matches the requested output data type
if (output_dtype == DataType::INT32) {
return conv_group_weight_zero_pad_helper<int32_t>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
num_groups,
output_dtype);
} else if (output_dtype == DataType::FLOAT32) {
return conv_group_weight_zero_pad_helper<float>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
num_groups,
output_dtype);
} else if (output_dtype == DataType::BFLOAT16) {
return conv_group_weight_zero_pad_helper<bfloat16>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
num_groups,
output_dtype);
} else if (output_dtype == DataType::UINT16) {
return conv_group_weight_zero_pad_helper<uint16_t>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
num_groups,
output_dtype);
} else if (output_dtype == DataType::UINT32) {
return conv_group_weight_zero_pad_helper<uint32_t>(
conv_weight_tensor,
original_conv_weight_tensor_shape,
output_conv_weight_tensor_shape,
num_groups,
output_dtype);
}

TT_THROW("Unsupported weight data type given when trying to add zero padding to weight tensor");
}
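A minimal PyTorch sketch of the equivalence this conversion relies on (the zero_pad_grouped_weight helper
below is illustrative, not part of this change): zero-padding each grouped filter out to the full
input-channel width lets a groups=1 convolution reproduce the grouped convolution.

import torch
import torch.nn.functional as F

def zero_pad_grouped_weight(weight: torch.Tensor, groups: int) -> torch.Tensor:
    # weight: [out_channels, in_channels // groups, kH, kW]
    out_channels, in_per_group, kh, kw = weight.shape
    padded = torch.zeros(out_channels, in_per_group * groups, kh, kw, dtype=weight.dtype)
    filters_per_group = out_channels // groups
    for o in range(out_channels):
        group_id = o // filters_per_group
        start = group_id * in_per_group          # offset of this group's input channels
        padded[o, start:start + in_per_group] = weight[o]
    return padded

torch.manual_seed(0)
x = torch.randn(2, 8, 14, 14)
w = torch.randn(4, 8 // 4, 3, 3)                 # groups = 4
grouped = F.conv2d(x, w, padding=1, groups=4)
dense = F.conv2d(x, zero_pad_grouped_weight(w, 4), padding=1, groups=1)
assert torch.allclose(grouped, dense, atol=1e-5)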

const Shape infer_dims_for_reshape(int N, int C, int H, int W, uint32_t old_volume) {
vector<int> ns{N, C, H, W};
int neg_idx = -1;