diff --git a/forge/csrc/passes/commute_utils.cpp b/forge/csrc/passes/commute_utils.cpp
index 8059a581..6d6e2899 100644
--- a/forge/csrc/passes/commute_utils.cpp
+++ b/forge/csrc/passes/commute_utils.cpp
@@ -45,9 +45,12 @@ std::tuple can_commute_reshape_through_dim(
     {
         if (input_shape_vec[i] == output_shape_vec[dim])
         {
+            log_trace("output_shape_vec.size() = {}", output_shape_vec.size());
+            log_trace("i value = {}", i);
             // check whether volume above and below matching dim is the same
             if ((volume_above(input_shape_vec, i) == volume_above(output_shape_vec, dim)) and
-                (volume_below(input_shape_vec, i) == volume_below(output_shape_vec, dim)))
+                (volume_below(input_shape_vec, i) == volume_below(output_shape_vec, dim)) and
+                (i < output_shape_vec.size()))
             {
                 can_commute = true;
                 new_dim = i;
diff --git a/forge/forge/op/eval/forge/eltwise_nary.py b/forge/forge/op/eval/forge/eltwise_nary.py
index 684d1d09..c20f7868 100644
--- a/forge/forge/op/eval/forge/eltwise_nary.py
+++ b/forge/forge/op/eval/forge/eltwise_nary.py
@@ -142,9 +142,14 @@ def get_eltwise_shape_and_broadcast():
     elif type == "stack":
         axis = attr[0]
         output_shape = list(ops[0])
-        output_shape.insert(axis, len(ops))
+        if axis == -1:
+            output_shape.append(len(ops))
+        elif axis < 0:
+            output_shape.insert(axis + 1, len(ops))
+        else:
+            output_shape.insert(axis, len(ops))
         return output_shape, []

     elif type == "interleave":
@@ -246,11 +251,16 @@ def decompose(type, attr, dc, inputs):
         new_inputs = []
         for inp in inputs:
             inp_shape = inp.shape.as_list()
-            inp_shape.insert(axis, 1)
-            new_inp = dc.op("reshape", [inp], (*inp_shape,))
+            if axis == -1:
+                inp_shape.append(1)
+            elif axis < 0:
+                inp_shape.insert(axis + 1, 1)
+            else:
+                inp_shape.insert(axis, 1)
+            new_inp = dc.op_with_named_attrs("reshape", [inp], {"shape": (*inp_shape,)}, (*inp_shape,))
             new_inputs.append(new_inp)

-        output = dc.op("concatenate", new_inputs, (axis,))
+        output = dc.op_with_named_attrs("concatenate", new_inputs, {"dim": axis}, (axis,))
         dc.fuse(output)

     if type == "concatenate":
diff --git a/forge/test/mlir/test_ops.py b/forge/test/mlir/test_ops.py
index 4ac41e6c..ec5c9d5a 100644
--- a/forge/test/mlir/test_ops.py
+++ b/forge/test/mlir/test_ops.py
@@ -49,6 +49,49 @@ def forward(self, x):
 ]
+
+@pytest.mark.parametrize(
+    "params",
+    [
+        ([(1, 256, 24, 24), (1, 256, 24, 24)], -4),
+        ([(5, 64, 128, 128), (5, 64, 128, 128)], -3),
+        ([(1, 30, 30, 16), (1, 30, 30, 16)], -2),
+        ([(1, 30, 30, 16), (1, 30, 30, 16)], 3),
+        ([(5, 64, 128, 128), (5, 64, 128, 128)], -1),
+        ([(1, 256, 24, 24), (1, 256, 24, 24)], 4),
+        ([(1, 256, 24, 24), (1, 256, 24, 24)], 2),
+        ([(5, 64, 128, 128), (5, 64, 128, 128)], 1),
+        ([(1, 30, 30, 16), (1, 30, 30, 16)], 0),
+    ],
+)
+def test_stack(params):
+    class Stack(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
+
+        def forward(self, *tensors):
+            return torch.stack(tensors, dim=self.dim)
+
+    input_shapes, dim = params
+    if dim == -3 or dim == 1:
+        pytest.xfail("Tensor rank is not 4")
+    else:
+        pytest.xfail(
+            "Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first"
+        )
+    inputs = [torch.rand(shape) for shape in input_shapes]
+
+    framework_model = Stack(dim)
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name="stack_sanity")
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
+    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+
 @pytest.mark.parametrize("shape", shapes)
 def test_exp(shape):
     class exp(nn.Module):
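For reference, the eltwise_nary.py change lowers stack to an unsqueeze-style reshape of every input (a singleton dimension inserted at the stack axis) followed by a concatenate along that axis. Below is a minimal PyTorch sketch of that equivalence, independent of the Forge decompose API; the helper name stack_via_reshape_concat is illustrative only.

import torch


def stack_via_reshape_concat(tensors, dim):
    # Mirror of the decompose logic in the patch: insert a singleton dimension
    # at `dim` (negative dims are shifted by one, since the stacked output has
    # rank + 1), then concatenate along that dimension.
    reshaped = []
    for t in tensors:
        shape = list(t.shape)
        if dim == -1:
            shape.append(1)
        elif dim < 0:
            shape.insert(dim + 1, 1)
        else:
            shape.insert(dim, 1)
        reshaped.append(t.reshape(*shape))
    return torch.cat(reshaped, dim=dim)


if __name__ == "__main__":
    # Check against torch.stack for a rank-4 input and several valid axes.
    xs = [torch.rand(1, 30, 30, 16) for _ in range(2)]
    for d in [-5, -3, -1, 0, 2, 4]:
        assert torch.equal(stack_via_reshape_concat(xs, d), torch.stack(xs, dim=d))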