Skip to content

Commit

Permalink
Fix stack op sanity failure
Browse files — browse the repository at this point in the history
  • Loading branch information
meenakshiramanathan1 committed Nov 6, 2024
1 parent 8eb266a commit 5d7e4ef
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 6 deletions.
5 changes: 4 additions & 1 deletion forge/csrc/passes/commute_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ std::tuple<bool, int> can_commute_reshape_through_dim(
{
if (input_shape_vec[i] == output_shape_vec[dim])
{
log_trace("output_shape_vec.size() = {}", output_shape_vec.size());
log_trace("i value = {}", i);
// check whether volume above and below matching dim is the same
if ((volume_above(input_shape_vec, i) == volume_above(output_shape_vec, dim)) and
(volume_below(input_shape_vec, i) == volume_below(output_shape_vec, dim)))
(volume_below(input_shape_vec, i) == volume_below(output_shape_vec, dim)) and
(i < output_shape_vec.size()))
{
can_commute = true;
new_dim = i;
Expand Down
19 changes: 14 additions & 5 deletions forge/forge/op/eval/forge/eltwise_nary.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,13 @@ def get_eltwise_shape_and_broadcast():

elif type == "stack":
axis = attr[0]

output_shape = list(ops[0])
output_shape.insert(axis, len(ops))
if axis == -1:
output_shape.append(len(ops))
elif axis < 0:
output_shape.insert(axis + 1, len(ops))
else:
output_shape.insert(axis, len(ops))
return output_shape, []

elif type == "interleave":
Expand Down Expand Up @@ -246,11 +250,16 @@ def decompose(type, attr, dc, inputs):
new_inputs = []
for inp in inputs:
inp_shape = inp.shape.as_list()
inp_shape.insert(axis, 1)
new_inp = dc.op("reshape", [inp], (*inp_shape,))
if axis == -1:
inp_shape.append(len(ops))
elif axis < 0:
inp_shape.insert(axis + 1, len(ops))
else:
inp_shape.insert(axis, len(ops))
new_inp = dc.op_with_named_attrs("reshape", [inp], {"shape": (*inp_shape,)}, (*inp_shape,))
new_inputs.append(new_inp)

output = dc.op("concatenate", new_inputs, (axis,))
output = dc.op_with_named_attrs("concatenate", new_inputs, {"dim": (axis)}, (axis,))
dc.fuse(output)

if type == "concatenate":
Expand Down
43 changes: 43 additions & 0 deletions forge/test/mlir/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,49 @@ def forward(self, x):
]


# Every configuration below is currently expected to fail in the compiler
# backend. The failures are declared as per-case xfail *marks* rather than by
# calling pytest.xfail() inside the test body: pytest.xfail() raises
# immediately, which made everything after it unreachable — the test never
# actually compiled or ran the model. With marks, the body executes, the
# expected failure is still reported as XFAIL, and a fixed case surfaces as
# XPASS instead of being silently skipped forever.
_RANK_XFAIL = pytest.mark.xfail(reason="Tensor rank is not 4")
_TILE_XFAIL = pytest.mark.xfail(
    reason="Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first"
)


@pytest.mark.parametrize(
    "params",
    [
        pytest.param(([(1, 256, 24, 24), (1, 256, 24, 24)], -4), marks=_TILE_XFAIL),
        # dim == -3 and dim == 1 fail earlier, with a rank error.
        pytest.param(([(5, 64, 128, 128), (5, 64, 128, 128)], -3), marks=_RANK_XFAIL),
        pytest.param(([(1, 30, 30, 16), (1, 30, 30, 16)], -2), marks=_TILE_XFAIL),
        pytest.param(([(1, 30, 30, 16), (1, 30, 30, 16)], 3), marks=_TILE_XFAIL),
        pytest.param(([(5, 64, 128, 128), (5, 64, 128, 128)], -1), marks=_TILE_XFAIL),
        pytest.param(([(1, 256, 24, 24), (1, 256, 24, 24)], 4), marks=_TILE_XFAIL),
        pytest.param(([(1, 256, 24, 24), (1, 256, 24, 24)], 2), marks=_TILE_XFAIL),
        pytest.param(([(5, 64, 128, 128), (5, 64, 128, 128)], 1), marks=_RANK_XFAIL),
        pytest.param(([(1, 30, 30, 16), (1, 30, 30, 16)], 0), marks=_TILE_XFAIL),
    ],
)
def test_stack(params):
    """Stack N equal-shape tensors along `dim` and compare the compiled
    model's output against the torch golden result (PCC >= 0.99).

    params: ([shape, shape, ...], dim) — the input tensor shapes and the
    stack dimension under test.
    """

    class Stack(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.dim = dim

        def forward(self, *tensors):
            return torch.stack(tensors, dim=self.dim)

    input_shapes, dim = params
    inputs = [torch.rand(shape) for shape in input_shapes]

    framework_model = Stack(dim)
    fw_out = framework_model(*inputs)

    compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name="stack_sanity")
    co_out = compiled_model(*inputs)

    # Golden comparison: move device outputs to host and normalize the
    # framework output to a list so both sides zip together.
    co_out = [co.to("cpu") for co in co_out]
    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])


@pytest.mark.parametrize("shape", shapes)
def test_exp(shape):
class exp(nn.Module):
Expand Down

0 comments on commit 5d7e4ef

Please sign in to comment.