diff --git a/forge/forge/op/eval/forge/eltwise_nary.py b/forge/forge/op/eval/forge/eltwise_nary.py
index 684d1d09..4d99aff5 100644
--- a/forge/forge/op/eval/forge/eltwise_nary.py
+++ b/forge/forge/op/eval/forge/eltwise_nary.py
@@ -144,7 +144,12 @@ def get_eltwise_shape_and_broadcast():
 
         axis = attr[0]
         output_shape = list(ops[0])
-        output_shape.insert(axis, len(ops))
+        if axis == -1:
+            output_shape.append(len(ops))
+        elif axis < 0:
+            output_shape.insert(axis + 1, len(ops))
+        else:
+            output_shape.insert(axis, len(ops))
         return output_shape, []
 
     elif type == "interleave":
@@ -246,11 +251,16 @@ def decompose(type, attr, dc, inputs):
         new_inputs = []
         for inp in inputs:
             inp_shape = inp.shape.as_list()
-            inp_shape.insert(axis, 1)
-            new_inp = dc.op("reshape", [inp], (*inp_shape,))
+            if axis == -1:
+                inp_shape.append(1)
+            elif axis < 0:
+                inp_shape.insert(axis + 1, 1)
+            else:
+                inp_shape.insert(axis, 1)
+            new_inp = dc.op_with_named_attrs("reshape", [inp], {"shape": (*inp_shape,)}, (*inp_shape,))
             new_inputs.append(new_inp)
 
-        output = dc.op("concatenate", new_inputs, (axis,))
+        output = dc.op_with_named_attrs("concatenate", new_inputs, {"dim": (axis)}, (axis,))
         dc.fuse(output)
 
     if type == "concatenate":
diff --git a/forge/forge/tvm_to_python.py b/forge/forge/tvm_to_python.py
index e5c28434..e5aa0254 100644
--- a/forge/forge/tvm_to_python.py
+++ b/forge/forge/tvm_to_python.py
@@ -1239,6 +1239,8 @@ def populate_cast_args(graph, nid, compiler_cfg):
     node = graph["nodes"][nid]
     args = []
     dtype = node["attrs"]["dtype"][0][0]
+    if dtype == "uint1":
+        dtype = "bool"
     args.append(("dtype", "torch." + f"{dtype}"))
 
     return args
diff --git a/forge/test/mlir/test_ops.py b/forge/test/mlir/test_ops.py
index ab4f2b95..1218c952 100644
--- a/forge/test/mlir/test_ops.py
+++ b/forge/test/mlir/test_ops.py
@@ -1510,3 +1510,43 @@ def forward(self, x):
     co_out = [co.to("cpu") for co in co_out]
     fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
     assert all([compare_with_golden(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
+
+
+@pytest.mark.xfail(
+    reason="Unable to reshape a tensor in TILE_LAYOUT to non-tile height and width! Please convert the tensor to ROW_MAJOR_LAYOUT first"
+)
+@pytest.mark.parametrize(
+    "params",
+    [
+        ([(1, 256, 24, 24), (1, 256, 24, 24)], -4),
+        ([(5, 64, 128, 128), (5, 64, 128, 128)], -3),
+        ([(1, 30, 30, 16), (1, 30, 30, 16)], -2),
+        ([(1, 30, 30, 16), (1, 30, 30, 16)], 3),
+        ([(5, 64, 128, 128), (5, 64, 128, 128)], -1),
+        ([(1, 256, 24, 24), (1, 256, 24, 24)], 4),
+        ([(1, 256, 24, 24), (1, 256, 24, 24)], 2),
+        ([(5, 64, 128, 128), (5, 64, 128, 128)], 1),
+        ([(1, 30, 30, 16), (1, 30, 30, 16)], 0),
+    ],
+)
+def test_stack(params):
+    class Stack(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
+
+        def forward(self, *tensors):
+            return torch.stack(tensors, dim=self.dim)
+
+    input_shapes, dim = params
+    inputs = [torch.rand(shape) for shape in input_shapes]
+
+    framework_model = Stack(dim)
+    fw_out = framework_model(*inputs)
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name="stack_sanity")
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
+    assert all([compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)])
diff --git a/forge/test/model_demos/high_prio/nlp/pytorch/test_codegen.py b/forge/test/model_demos/high_prio/nlp/pytorch/test_codegen.py
index 19f5d135..3580e04d 100644
--- a/forge/test/model_demos/high_prio/nlp/pytorch/test_codegen.py
+++ b/forge/test/model_demos/high_prio/nlp/pytorch/test_codegen.py
@@ -22,7 +22,7 @@ def test_codegen(test_device, variant):
 
     # Configurations
     compiler_cfg = forge.config._get_global_compiler_config()
-    compiler_cfg.compile_depth = forge.CompileDepth.INIT_COMPILE
+    compiler_cfg.compile_depth = forge.CompileDepth.SPLIT_GRAPH
 
     # Load model (with tokenizer)
     tokenizer = download_model(AutoTokenizer.from_pretrained, variant)
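
Editor's sketch (not part of the patch): a minimal standalone check of the axis-normalization rule introduced in eltwise_nary.py above, assuming plain Python and torch. The helper name insert_stack_dim is hypothetical; it only mirrors the if/elif/else branches and compares them against torch.stack output shapes.

import torch

def insert_stack_dim(shape, axis, num_tensors):
    # Hypothetical helper mirroring the patched branches: list.insert with a
    # negative index places the new entry *before* that slot of the original
    # list, so negative axes are shifted by one (and -1 is appended) to match
    # torch.stack, whose output has one more dimension than each input.
    shape = list(shape)
    if axis == -1:
        shape.append(num_tensors)
    elif axis < 0:
        shape.insert(axis + 1, num_tensors)
    else:
        shape.insert(axis, num_tensors)
    return shape

tensors = [torch.rand(1, 30, 30, 16) for _ in range(2)]
for dim in range(-5, 5):  # all valid stack dims for 4-D inputs
    expected = list(torch.stack(tensors, dim=dim).shape)
    assert insert_stack_dim(tensors[0].shape, dim, len(tensors)) == expected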