Commit 62e0def

Fix codegen failure
meenakshiramanathan1 committed Nov 15, 2024
1 parent 4b24864 commit 62e0def
Showing 5 changed files with 58 additions and 6 deletions.
20 changes: 16 additions & 4 deletions forge/forge/op/eval/forge/eltwise_nary.py
@@ -144,7 +144,13 @@ def get_eltwise_shape_and_broadcast():
         axis = attr[0]
 
         output_shape = list(ops[0])
-        output_shape.insert(axis, len(ops))
+        if axis == -1:
+            output_shape.append(len(ops))
+        elif axis < -1:
+            # Shift negative axes by one: list.insert indexes against the input rank, but the stacked output has one extra dimension.
+            output_shape.insert(axis + 1, len(ops))
+        else:
+            output_shape.insert(axis, len(ops))
         return output_shape, []
 
     elif type == "interleave":
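
A note on the new branching, since the in-line comment is terse: Python's list.insert counts a negative index against the current list length, but the stacked output has one more dimension than each input, so negative axes must be shifted by +1 (and axis == -1 is special-cased with append, because insert(-1 + 1, ...) == insert(0, ...) would prepend). A standalone sketch of the off-by-one, with made-up shapes:

ops_shape = [2, 3, 4]      # shape of each input tensor
n = 5                      # number of tensors being stacked

wrong = list(ops_shape)
wrong.insert(-2, n)        # [2, 5, 3, 4] -- one slot too early
right = list(ops_shape)
right.insert(-2 + 1, n)    # [2, 3, 5, 4] -- matches torch.stack(..., dim=-2)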
@@ -246,11 +252,17 @@ def decompose(type, attr, dc, inputs):
         new_inputs = []
         for inp in inputs:
             inp_shape = inp.shape.as_list()
-            inp_shape.insert(axis, 1)
-            new_inp = dc.op("reshape", [inp], (*inp_shape,))
+            if axis == -1:
+                inp_shape.append(1)
+            elif axis < -1:
+                # Shift negative axes by one: list.insert indexes against the input rank, but the unsqueezed shape has one extra dimension.
+                inp_shape.insert(axis + 1, 1)
+            else:
+                inp_shape.insert(axis, 1)
+            new_inp = dc.op_with_named_attrs("reshape", [inp], {"shape": (*inp_shape,)}, (*inp_shape,))
             new_inputs.append(new_inp)
 
-        output = dc.op("concatenate", new_inputs, (axis,))
+        output = dc.op_with_named_attrs("concatenate", new_inputs, {"dim": (axis)}, (axis,))
         dc.fuse(output)
 
     if type == "concatenate":
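For context, this decompose path lowers stack through the usual identity -- unsqueeze every input at the stack axis, then concatenate along that axis -- and the switch from dc.op to dc.op_with_named_attrs presumably exists so the attributes survive into generated code by name. The identity itself is easy to confirm in plain PyTorch (shapes here are arbitrary):

import torch

xs = [torch.rand(2, 3) for _ in range(4)]
dim = 1
stacked = torch.stack(xs, dim=dim)                            # shape (2, 4, 3)
via_cat = torch.cat([x.unsqueeze(dim) for x in xs], dim=dim)  # same result
assert torch.equal(stacked, via_cat)
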
2 changes: 2 additions & 0 deletions forge/forge/python_codegen.py
@@ -76,6 +76,8 @@ def pytorch_df_str_from_str(df: str, name):
         return "torch.int16"
     elif df == "int64":
         return "torch.int64"
+    elif df == "uint1":
+        return "torch.bool"
     else:
         logger.warning(f"Invalid data format: {df} for constant/parameter '{name}', defaulting to float32")
         return "torch.float32"
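TVM uses the dtype string "uint1" for boolean tensors, and PyTorch offers no usable torch.uint1, so torch.bool is the natural target. Hypothetical calls showing the helper's behavior (the import path and argument names are assumptions for illustration):

from forge.python_codegen import pytorch_df_str_from_str  # assumed import path

pytorch_df_str_from_str("uint1", "mask")   # -> "torch.bool" (new branch)
pytorch_df_str_from_str("int64", "ids")    # -> "torch.int64"
pytorch_df_str_from_str("bogus", "w")      # -> warns, falls back to "torch.float32"
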
2 changes: 1 addition & 1 deletion forge/forge/tvm_to_python.py
@@ -1246,7 +1246,7 @@ def populate_cast_args(graph, nid, compiler_cfg):
     node = graph["nodes"][nid]
     args = []
     dtype = node["attrs"]["dtype"][0][0]
-    args.append(("dtype", "torch." + f"{dtype}"))
+    args.append(("dtype", pytorch_df_str_from_str(dtype, node["forge_name"])))
     return args


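This appears to be the codegen failure the commit title refers to: prefixing the raw TVM dtype with "torch." emits names like "torch.uint1" into the generated Python, which raises AttributeError on PyTorch releases that lack that attribute, while routing through pytorch_df_str_from_str emits a valid dtype. A sketch of the failure mode as inferred from the diff:

import torch

dtype = "uint1"                  # dtype string as it arrives from TVM
old_arg = "torch." + f"{dtype}"  # "torch.uint1" -- not a usable dtype reference
# the fixed path emits "torch.bool" instead, so the generated cast works:
x = torch.tensor([0, 1, 1]).to(torch.bool)
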
38 changes: 38 additions & 0 deletions forge/test/mlir/test_ops.py
@@ -1663,3 +1663,41 @@ def forward(self, inp_1):
     co_out = [co.to("cpu") for co in co_out]
     fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
     assert all([compare_with_golden(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+@pytest.mark.xfail(
+    reason="Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first"
+)
+@pytest.mark.parametrize(
+    "params",
+    [
+        ([(1, 256, 24, 24), (1, 256, 24, 24)], -4),
+        ([(5, 64, 128, 128), (5, 64, 128, 128)], -3),
+        ([(1, 30, 30, 16), (1, 30, 30, 16)], -2),
+        ([(1, 30, 30, 16), (1, 30, 30, 16)], 3),
+        ([(5, 64, 128, 128), (5, 64, 128, 128)], -1),
+        ([(1, 256, 24, 24), (1, 256, 24, 24)], 4),
+        ([(1, 256, 24, 24), (1, 256, 24, 24)], 2),
+        ([(5, 64, 128, 128), (5, 64, 128, 128)], 1),
+        ([(1, 30, 30, 16), (1, 30, 30, 16)], 0),
+    ],
+)
+def test_stack(params):
+    class Stack(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
+
+        def forward(self, *tensors):
+            return torch.stack(tensors, dim=self.dim)
+
+    input_shapes, dim = params
+    inputs = [torch.rand(shape) for shape in input_shapes]
+
+    framework_model = Stack(dim)
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name="stack_sanity")
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
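
The parametrization above covers both negative and positive dims, including the boundary case dim == input rank (torch.stack allows dim up to the output's new rank). A quick check of the expected golden shapes in plain PyTorch:

import torch

xs = [torch.rand(1, 256, 24, 24) for _ in range(2)]
print(torch.stack(xs, dim=4).shape)   # torch.Size([1, 256, 24, 24, 2])
print(torch.stack(xs, dim=-4).shape)  # torch.Size([1, 2, 256, 24, 24])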
2 changes: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 def test_codegen(test_device, variant):
     # Configurations
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     # Load model (with tokenizer)
     tokenizer = download_model(AutoTokenizer.from_pretrained, variant)
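Finally, the codegen test moves its stopping point from INIT_COMPILE to SPLIT_GRAPH. A minimal sketch of how that configuration is applied, using only names visible in this diff (the precise meaning of each CompileDepth stage is assumed, not documented here):

import forge

compiler_cfg = forge.config._get_global_compiler_config()
compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
# forge.compile(...) calls after this presumably halt once the graph is split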
