Skip to content

Commit

Permalink
[aotinductor] Fixes a problem where a tensor is returned more than once (pytorch#112177)
Browse files Browse the repository at this point in the history

Pull Request resolved: pytorch#112177
Approved by: https://github.com/zhxchen17
  • Loading branch information
desertfire authored and Skylion007 committed Nov 14, 2023
1 parent fa9ab72 commit 20177c5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
16 changes: 16 additions & 0 deletions test/inductor/test_aot_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,18 @@ def forward(self, x):
x = torch.randn(5, device=self.device)
self.check_model(Model(self.device), (x,))

def test_repeat_output(self):
    """Regression test for pytorch#112177: compiling a model whose graph
    returns the same tensor object more than once must not fail.

    Uses ``self.check_model`` (defined on the enclosing TestCase) to
    compare eager vs. AOTInductor-compiled outputs on ``self.device``.
    """

    class Model(torch.nn.Module):
        # No custom state is needed, so the inherited
        # torch.nn.Module.__init__ suffices — the boilerplate
        # __init__ that only called super().__init__() was removed.
        def forward(self, x):
            # Return the identical tensor twice to exercise the
            # duplicate-output path in the exported graph.
            y = torch.sin(x)
            return y, y

    example_inputs = (torch.randn(3, 10, device=self.device),)
    self.check_model(Model(), example_inputs)


class AOTInductorTestABICompatibleCpu(TestCase):
device = "cpu"
Expand All @@ -1036,6 +1048,8 @@ class AOTInductorTestABICompatibleCpu(TestCase):
"test_freezing": TestFailure(("abi_compatible_cpu",), is_skip=True),
"test_normal_functional": TestFailure(("abi_compatible_cpu",)),
"test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
# There is a double-free issue which will be fixed in another PR
"test_repeat_output": TestFailure(("abi_compatible_cpu",), is_skip=True),
"test_sdpa": TestFailure(("abi_compatible_cpu",)),
"test_sdpa_2": TestFailure(("abi_compatible_cpu",)),
"test_simple_dynamic": TestFailure(("abi_compatible_cpu",)),
Expand All @@ -1058,6 +1072,8 @@ class AOTInductorTestABICompatibleCuda(TestCase):
{
"test_dup_unbacked_sym_decl": TestFailure(("abi_compatible_cuda",)),
"test_normal_functional": TestFailure(("abi_compatible_cuda",)),
# There is a double-free issue which will be fixed in another PR
"test_repeat_output": TestFailure(("abi_compatible_cuda",), is_skip=True),
},
)

Expand Down
4 changes: 3 additions & 1 deletion torch/_export/exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict, buf
# Step 2: Find the all the buffers that were mutated and update them
if node.op == "output":
user_output_nodes = []
for return_node in node.all_input_nodes:
# In the case that the same node is returned multiple times,
# node.all_input_nodes will only iterate that node once
for return_node in pytree.tree_flatten(node.args)[0]:
return_node_name = return_node.name
# we found a param/buffer mutation
if return_node_name in buffers_to_mutate:
Expand Down

0 comments on commit 20177c5

Please sign in to comment.