Revert "[fx] change from #users to num_users in graph printout (pytor…
Browse files Browse the repository at this point in the history
…ch#101140)"

This reverts commit e568c5a.

Reverted pytorch#101140 on behalf of https://github.com/jeanschmidt: there are internal changes to this commit that are preventing landing, so I am reverting to unblock the diff train ([comment](pytorch#101140 (comment)))
pytorchmergebot committed May 15, 2023
1 parent 616208b commit 66eef31
Showing 8 changed files with 30 additions and 30 deletions.
6 changes: 3 additions & 3 deletions docs/source/fx.rst
@@ -633,9 +633,9 @@ examine our traced module:
 # This print-out returns:
 """
 graph():
-%x : [num_users=1] = placeholder[target=x]
-%y : [num_users=1] = placeholder[target=y]
-%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
+%x : [#users=1] = placeholder[target=x]
+%y : [#users=1] = placeholder[target=y]
+%add : [#users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
 return add
 """

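For context, a minimal sketch (not part of this commit) that reproduces a printout of the shape shown in the doc above; tracing a free function named `add` rather than a module is an assumption for illustration:

```python
import torch.fx

# Tracing a two-input add yields a graph like the one in the doc above:
# two placeholders feeding a call_function[target=operator.add] node.
def add(x, y):
    return x + y

traced = torch.fx.symbolic_trace(add)
print(traced.graph)
# After this revert, each node line is annotated as [#users=N],
# e.g. "%x : [#users=1] = placeholder[target=x]".
```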
10 changes: 5 additions & 5 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -396,11 +396,11 @@ def test_qnnpack_quantizer_conv_linear(self):
 """
 This test fails because the linear decomposition changes due to the presence of
 a permute node. In the below, linear 1 is decomposed as
-%t_default : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant2,), kwargs = {})
-%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_default,), kwargs = {memory_format: torch.contiguous_format}) # noqa: B950
-%_unsafe_view_default : [num_users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%clone_default, [8, 16]), kwargs = {}) # noqa: B950
-%mm_default : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%_unsafe_view_default, %t_default), kwargs = {}) # noqa: B950
-%view_default : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%mm_default, [2, 2, 2, 8]), kwargs = {}) # noqa: B950
+%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant2,), kwargs = {})
+%clone_default : [#users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_default,), kwargs = {memory_format: torch.contiguous_format}) # noqa: B950
+%_unsafe_view_default : [#users=1] = call_function[target=torch.ops.aten._unsafe_view.default](args = (%clone_default, [8, 16]), kwargs = {}) # noqa: B950
+%mm_default : [#users=1] = call_function[target=torch.ops.aten.mm.default](args = (%_unsafe_view_default, %t_default), kwargs = {}) # noqa: B950
+%view_default : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%mm_default, [2, 2, 2, 8]), kwargs = {}) # noqa: B950
 Note the presence of clone and unsafe_view. This is due to the permute.
 """
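The docstring above quotes an aten-level decomposition of linear in the presence of a permute. A hedged sketch of how such a trace can be obtained; the function `f`, the shapes, and the use of `make_fx` are assumptions, not part of the test:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Trace a permute followed by a linear down to aten ops; depending on the
# decompositions in use, linear may lower to t/clone/_unsafe_view/mm/view
# as quoted in the docstring above.
def f(x, weight, bias):
    x = x.permute(0, 2, 3, 1)  # the permute that perturbs the decomposition
    return torch.nn.functional.linear(x, weight, bias)

x = torch.randn(2, 16, 2, 2)
weight = torch.randn(8, 16)
bias = torch.randn(8)

gm = make_fx(f)(x, weight, bias)
print(gm.graph)  # node annotations use the [#users=N] form after this revert
```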
2 changes: 1 addition & 1 deletion test/test_fx.py
@@ -1273,7 +1273,7 @@ def forward(self, x):
 traced = symbolic_trace(st)
 traced.graph.lint()
 stringed = str(traced.graph)
-for s in ['args', 'kwargs', 'num_users']:
+for s in ['args', 'kwargs', '#users']:
 assert s in stringed
 
 def test_custom_proxy_type(self):
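A standalone, hedged version of the check this test performs; the traced function `f` is an assumption (the test itself traces a module defined elsewhere in the file):

```python
import torch
import torch.fx

def f(x):
    return torch.relu(x + 1)

# The string form of a traced graph mentions args, kwargs, and (after this
# revert) the #users annotation on each node.
stringed = str(torch.fx.symbolic_trace(f).graph)
for s in ['args', 'kwargs', '#users']:
    assert s in stringed
```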
10 changes: 5 additions & 5 deletions torch/fx/__init__.py
@@ -27,11 +27,11 @@ def forward(self, x):
 print(symbolic_traced.graph)
 """
 graph():
-%x : [num_users=1] = placeholder[target=x]
-%param : [num_users=1] = get_attr[target=param]
-%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
-%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
-%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
+%x : [#users=1] = placeholder[target=x]
+%param : [#users=1] = get_attr[target=param]
+%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
+%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
+%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
 return clamp
 """
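A sketch of a module that yields a printout like the docstring excerpt above; the class name, parameter shape, and linear dimensions are assumptions chosen to exercise the five node kinds shown (placeholder, get_attr, call_function, call_module, call_method):

```python
import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        # add -> call_function, linear -> call_module, clamp -> call_method
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

symbolic_traced = torch.fx.symbolic_trace(MyModule())
print(symbolic_traced.graph)  # each node is annotated with [#users=N]
```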
8 changes: 4 additions & 4 deletions torch/fx/experimental/const_fold.py
@@ -229,13 +229,13 @@ def mod_partition(node: torch.fx.Node):
 # because we are fetching attributes directly from the root module, instead of
 # fetching them from const_gm. Example: The const_gm must have some format like:
 # graph():
-#     %inp : [num_users=1] = placeholder[target=const_inp]
-#     %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {})
+#     %inp : [#users=1] = placeholder[target=const_inp]
+#     %add : [#users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {})
 #     return add
 # We replace that with the following, which does not have any placeholders:
 # graph():
-#     %inp_1 : [num_users=1] = get_attr[target=const_inp]
-#     %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
+#     %inp_1 : [#users=1] = get_attr[target=const_inp]
+#     %add : [#users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
 #     return add
 root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
 for node in root_const_gm.graph.nodes:
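The comment block above describes swapping placeholder nodes for get_attr nodes so that constants are fetched from the root module. A hedged, simplified sketch of that kind of rewrite; the helper name and the assumption that each constant already exists as an attribute on the module are illustrative, and this is not the actual const_fold implementation:

```python
import torch.fx

def replace_placeholders_with_get_attr(gm: torch.fx.GraphModule, attr_names):
    """Replace each placeholder with a get_attr to the corresponding attribute.

    Assumes the attributes named in attr_names already exist on gm.
    """
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    for node, attr in zip(placeholders, attr_names):
        # Insert a get_attr node where the placeholder used to be.
        with gm.graph.inserting_before(node):
            get_attr = gm.graph.get_attr(attr)
        node.replace_all_uses_with(get_attr)
        gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm
```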
12 changes: 6 additions & 6 deletions torch/fx/graph.py
@@ -700,12 +700,12 @@ def forward(self, x):
 .. code-block:: text
 graph(x):
-%linear_weight : [num_users=1] = self.linear.weight
-%add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
-%linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
-%relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
-%sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
-%topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
+%linear_weight : [#users=1] = self.linear.weight
+%add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
+%linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
+%relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
+%sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
+%topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
 return topk_1
 For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
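Since this hunk touches the ``Graph`` class docstring, here is a small hedged sketch (not part of this commit) of building a graph directly with the `torch.fx.Graph` API; printing it uses the same per-node annotation affected by this revert:

```python
import torch.fx

graph = torch.fx.Graph()
x = graph.placeholder('x')              # prints as %x : [#users=1] = placeholder[target=x]
relu = graph.call_method('relu', (x,))  # a call_method node using x
graph.output(relu)
print(graph)
```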
6 changes: 3 additions & 3 deletions torch/fx/node.py
@@ -472,18 +472,18 @@ def format_node(self,
 return None
 maybe_typename = f'{_type_repr(self.type)} ' if self.type else ''
 default_val = '(default=' + str(self.args[0]) + ')' if self.args else ''
-return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}'
+return f'%{self.name} : {maybe_typename}[#users={len(self.users)}] = {self.op}[target={self.target}]{default_val}'
 elif self.op == 'get_attr':
 maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
-return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
+return f'%{self.name} : {maybe_typename}[#users={len(self.users)}] = ' \
 f'{self.op}[target={self._pretty_print_target(self.target)}]'
 elif self.op == 'output':
 if self.type and maybe_return_typename:
 maybe_return_typename[0] = f' -> {_type_repr(self.type)}'
 return f'return {self.args[0]}'
 else:
 maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else ''
-return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \
+return f'%{self.name} : {maybe_typename}[#users={len(self.users)}] = ' \
 f'{self.op}[target={self._pretty_print_target(self.target)}](' \
 f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})'

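To see the string that `format_node` produces (the annotation this commit renames back from `num_users` to `#users`), a hedged sketch; the function `f` and its contents are assumptions:

```python
import torch.fx

def f(x):
    y = x + 1
    return y.relu(), y.sigmoid()  # the add node has two users: [#users=2]

gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    print(node.format_node())
```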
6 changes: 3 additions & 3 deletions torch/fx/passes/reinplace.py
@@ -81,9 +81,9 @@ def run_node(self, node: Node):
 # For multi-output views, we want to map each output view to the base,
 # but this mapping involves two separate nodes in FX IR.
 # e.g. "a, b = x_1.split(...)" becomes:
-# %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
-# %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
-# %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
+# %split_tensor : [#users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {})
+# %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {})
+# %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {})
 # And we'd like to set:
 # getitem1.meta['view_of'] = x_1
 elif node.target is _operator.getitem:
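A hedged sketch reproducing the multi-output-view pattern described in the comment above; the function `f`, the split size, and the input shape are assumptions, and tracing with `make_fx` should yield a split-style node followed by `operator.getitem` nodes:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x_1):
    a, b = x_1.split(2)  # one multi-output view node, two getitem nodes
    return a + b

gm = make_fx(f)(torch.randn(4, 3))
print(gm.graph)
# Expect something like:
#   %split_tensor : [#users=2] = call_function[target=torch.ops.aten.split.Tensor](...)
#   %getitem : [#users=1] = call_function[target=operator.getitem](...)
#   %getitem_1 : [#users=1] = call_function[target=operator.getitem](...)
```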
