Skip to content

Commit

Permalink
Add fix for monodle failure (#629)
Browse files Browse the repository at this point in the history
  • Loading branch information
meenakshiramanathan1 authored Nov 13, 2024
1 parent 1bee7d0 commit bb22c1b
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 33 deletions.
41 changes: 17 additions & 24 deletions forge/forge/op/eval/forge/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,57 +226,52 @@ def eval(self, tensors):
assert len(tensors) <= 3, "ConvTranspose ops should have up to three inputs (input, weight, bias)"
assert len(tensors) >= 2, "ConvTranspose ops should have at least two inputs (input, weight)"
t_ops = to_torch_operands(*tensors)

activations = t_ops[0]
weights = t_ops[1]
bias = t_ops[2] if len(t_ops) == 3 else None

stride = [self.stride_height, self.stride_width]
dilation = [self.dilation_height, self.dilation_width]
groups = self.groups
padding = [
self.padding_left,
self.padding_right,
self.padding_top,
self.padding_bottom,
]
# TODO: Add support for asymmetric padding cases in convtranspose2d
if self.padding_left == self.padding_right and self.padding_top == self.padding_bottom:
padding = (self.padding_top, self.padding_left)
else:
assert (
False
), "Currently, different left and right padding or different top and bottom padding isn't supported. Please check out this issue (https://github.com/tenstorrent/tt-forge-fe/issues/665) for more details."

channel_last = self.channel_last
if channel_last:
activations = activations.permute((0, 3, 1, 2))

padded_activations = torch.nn.functional.pad(
activations,
padding,
)
if t_ops[1].dtype == torch.int8:
target_dtype = torch.int32
padded_activations, weights = padded_activations.float(), weights.float()
activations, weights = activations.float(), weights.float()
if bias is not None:
bias = bias.float()
else:
target_dtype = torch.float32

result = torch.nn.functional.conv_transpose2d(
padded_activations,
activations,
weights,
bias=bias,
stride=stride,
padding=0,
padding=padding,
dilation=dilation,
groups=groups,
)

if channel_last:
result = result.permute((0, 2, 3, 1))

result = result.to(target_dtype)
return result

def shape(self, tensor_shapes):
act, weight = tensor_shapes[:2]
batch_size = act[0]
cout = weight[1]
cout = weight[1] * self.groups

h_in = act[-3] if self.channel_last else act[-2]
w_in = act[-2] if self.channel_last else act[-1]
Expand All @@ -285,22 +280,20 @@ def shape(self, tensor_shapes):
output_padding_width = 0

h_out = (
((h_in - 1) * self.stride_height)
- (2 * (self.padding_top + self.padding_bottom))
+ (self.dilation_height * (weight[-2] - 1))
(h_in - 1) * self.stride_height
- (self.padding_top + self.padding_bottom)
+ self.dilation_height * (weight[-2] - 1)
+ output_padding_height
+ 1
)
w_out = (
((w_in - 1) * self.stride_width)
- (2 * (self.padding_left + self.padding_right))
+ (self.dilation_width * (weight[-1] - 1))
(w_in - 1) * self.stride_width
- (self.padding_left + self.padding_right)
+ self.dilation_width * (weight[-1] - 1)
+ output_padding_width
+ 1
)

out_shape = [batch_size, h_out, w_out, cout] if self.channel_last else [batch_size, cout, h_out, w_out]

return out_shape, []

def decompose(self, dc, inputs):
Expand Down
9 changes: 8 additions & 1 deletion forge/forge/tvm_to_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,14 @@ def populate_conv2d_transpose_args(graph, nid, compiler_cfg):
)
)

in_channel = next((n["attrs"]["shape"][0][0][0] for n in graph["nodes"] if n["name"] == "model.weight"), None)
in_channel = None
for input_ in node["inputs"]:
input_nid = input_[0]
input_node = graph["nodes"][input_nid]
if input_node["op"] == "parameter" and input_node["name"].endswith("weight"):
in_channel = input_node["attrs"]["shape"][0][0][0]
break

groups = int(node["attrs"]["groups"][0][0])
assert groups == 1 or (in_channel is not None and groups == in_channel), "Only supports group of 1 or in_channel"
args.append(
Expand Down
23 changes: 16 additions & 7 deletions forge/test/mlir/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,16 +1374,25 @@ def forward(self, a):

@pytest.mark.xfail(reason="Found Unsupported operations while lowering from TTForge to TTIR in forward graph")
@pytest.mark.parametrize(
"in_channels, out_channels, kernel_size, stride, padding, groups, bias, dilation, padding_mode",
"in_channels, out_channels, kernel_size, stride, padding, groups, bias, dilation, padding_mode, input_shape",
[
(16, 33, (3, 3), 2, 0, 1, True, 1, "zeros"),
(16, 33, (3, 3), 2, 0, 1, False, 1, "zeros"),
(16, 33, (3, 5), 2, 0, 1, True, 1, "zeros"),
(16, 33, (3, 3), 2, 0, 1, True, 1, "zeros", (16, 50, 100)),
(16, 32, (3, 5), 2, 1, 1, True, 1, "zeros", (16, 50, 100)),
(16, 16, (3, 3), 1, 1, 16, True, 1, "zeros", (16, 50, 100)),
(16, 33, (3, 3), 2, 0, 1, True, 1, "zeros", (20, 16, 50, 100)),
(16, 33, (3, 3), 2, 0, 1, False, 1, "zeros", (20, 16, 50, 100)),
(16, 33, (3, 5), 2, 0, 1, True, 1, "zeros", (20, 16, 50, 100)),
(16, 16, (5, 5), 1, 2, 1, True, 1, "zeros", (20, 16, 50, 100)),
(16, 32, (3, 5), 2, 1, 1, True, 1, "zeros", (20, 16, 50, 100)),
(16, 32, (3, 3), 4, 1, 1, False, 1, "zeros", (20, 16, 50, 100)),
(16, 16, (3, 3), 2, 2, 1, True, 1, "zeros", (20, 16, 50, 100)),
(16, 16, (3, 3), 1, 1, 16, True, 1, "zeros", (20, 16, 50, 100)),
],
)
@pytest.mark.push
def test_convtranspose2d(in_channels, out_channels, kernel_size, stride, padding, groups, bias, dilation, padding_mode):
inputs = [torch.randn(20, 16, 50, 100)]
def test_convtranspose2d(
in_channels, out_channels, kernel_size, stride, padding, groups, bias, dilation, padding_mode, input_shape
):
inputs = [torch.randn(*input_shape)]

framework_model = torch.nn.ConvTranspose2d(
in_channels=in_channels,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def test_monodle_pytorch(test_device):
# PyBuda configuration parameters
compiler_cfg = forge.config._get_global_compiler_config()
compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH

model_name = "monodle_pytorch"

Expand Down

0 comments on commit bb22c1b

Please sign in to comment.