Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tests] Added torch compile checks for functional ops tests #8092

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

vfdev-5
Copy link
Collaborator

@vfdev-5 vfdev-5 commented Nov 2, 2023

Related to #8056

Description:

  • Added torch compile checks for functional op tests except few problematic ops

Copy link

pytorch-bot bot commented Nov 2, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8092

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 9cc8a54 with merge base 15c166a (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@vfdev-5 vfdev-5 requested a review from NicolasHug November 6, 2023 13:37
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Nov 6, 2023
…13023)

Usage of `from pkg_resources import packaging` leads to a deprecation warning:
```
DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
```
and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092

Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency:
https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19
Pull Request resolved: #113023
Approved by: https://github.com/Skylion007
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
…torch#113023)

Usage of `from pkg_resources import packaging` leads to a deprecation warning:
```
DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
```
and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092

Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency:
https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19
Pull Request resolved: pytorch#113023
Approved by: https://github.com/Skylion007
Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @vfdev-5 - do we have an idea of the duration of this new check? When trying locally, the compilation was taking a fair bit of a time, I wonder if this will make the test significantly slower or not

def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs):
def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs):
"""Checks if the functional can be torch compiled and the compiled version can be called without error."""
if not isinstance(input, tv_tensors.Image):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just check for type(input) == Tensor directly?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just replicated the same check from _check_functional_scripted_smoke, however I agree that we can test all subclasses

Comment on lines 200 to 201
assert explanation.graph_count in (1, 2)
assert explanation.graph_break_count in (0, 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these 2 asserts are redundant, we can probably just assert the break counts?
Also, could we assert the exact number of graph breaks, instead of checking that it's 0 or 1 ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we can just have one assert assert explanation.graph_break_count in (0, 1).
Why assert explanation.graph_break_count in (0, 1) and not assert explanation.graph_break_count == 1, because there are ops with configs such to_dtype on f32 and dtype=torch.float32, scale=False that there is no graph to build and explanation object reports explanation.graph_break_count == 0.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking that we could have this mapping kernel -> expected_break_count so that we can more easily track which ones need to be addressed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case I have to combine somehow check_torch_compile_smoke flag to enable/disable torch compile tests and expected_break_count. For example, replace boolean check_torch_compile_smoke by torch_compile_expected_break_count: Optional[int] (if None, tests are disabled) ?

Annotated exceptions with encountered errors
@vfdev-5 vfdev-5 force-pushed the enable-torch-compile-equalize branch 2 times, most recently from 73fd62f to d558440 Compare November 8, 2023 10:31
@vfdev-5 vfdev-5 force-pushed the enable-torch-compile-equalize branch from d558440 to d185656 Compare November 8, 2023 10:38
Comment on lines +192 to +193
if not isinstance(input, torch.Tensor):
return
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought of doing type(input) == Tensor instead of isinstance, so that we don't duplicate this check for subclasses.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to be able to compile functional ops for subclasses as well. Right now dynamo does not fully support subclasses and in our case it can compile and even exec correctly if we just skip the last wrap to subclass.

# TODO: Fix this. We skip this method as it leads to
# RecursionError: maximum recursion depth exceeded while calling a Python object
# Keeping it here, leads to graph breaks between multiple functional ops instead of having a single graph
@torch.compiler.disable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm, the issue only appears for subclasses but not for pure tensors?
Was this creating a graph break for pure tensors? Because if not, it might be best to ignore that altogether for now?

Copy link
Collaborator Author

@vfdev-5 vfdev-5 Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, on subclasses when we try to convert result tensor back into its subclass.

Was this creating a graph break for pure tensors?

This method is not called for pure tensors.

Because if not, it might be best to ignore that altogether for now?

With this added we can torch compile functional ops with subclasses vs only image processing functional ops on tensors:

from torchvision.transforms.v2.functional import horizontal_flip
from torchvision.tv_tensors import Image, BoundingBoxes

t_img = torch.randint(0, 256, size=(3, 46, 52), dtype=torch.uint8)
img = Image(t_img)
box = BoundingBoxes(torch.randint(0, 256, size=(5, 4), dtype=torch.uint8))

cfn = torch.compile(horizontal_flip)
o1 = cfn(t_img)
o2 = cfn(img)
o3 = cfn(box)

vs

from torchvision.transforms.v2.functional import horizontal_flip
from torchvision.tv_tensors import Image, BoundingBoxes

t_img = torch.randint(0, 256, size=(3, 46, 52), dtype=torch.uint8)
img = Image(t_img)
box = BoundingBoxes(torch.randint(0, 256, size=(5, 4), dtype=torch.uint8))

cfn = torch.compile(horizontal_flip)
o1 = cfn(t_img)  # OK
o2 = cfn(img)    # RecursionError: maximum recursion depth exceeded while calling a Python object
o3 = cfn(box)    # RecursionError: maximum recursion depth exceeded while calling a Python object

I think this is valuable.

The only drawback is that we can't make a single graph for a pipeline of ops:

from torchvision.transforms.v2.functional import horizontal_flip, vertical_flip
from torchvision.tv_tensors import Image

def func(x):
    x = horizontal_flip(x)
    x = vertical_flip(x)
    return x

cfn = torch.compile(func)

x = Image(torch.randint(0, 256, size=(3, 46, 52), dtype=torch.uint8))
out = cfn(x)  # produces 4 graphs

[aot_autograd.py:2047 INFO] TRACED GRAPH
 ===== Forward graph 0 =====
 <eval_with_key>.4 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
    def forward(self):
        return ()

[aot_autograd.py:2047 INFO] TRACED GRAPH
 ===== Forward graph 4 =====
 <eval_with_key>.37 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: u8[3, 46, 52]):
        # File: /vision/torchvision/transforms/v2/functional/_geometry.py:55, code: return image.flip(-1)
        rev: u8[3, 46, 52] = torch.ops.prims.rev.default(arg0_1, [2]);  arg0_1 = None
        return (rev,)
        
[aot_autograd.py:2047 INFO] TRACED GRAPH
 ===== Forward graph 5 =====
 <eval_with_key>.47 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
    def forward(self):
        return ()
        
[aot_autograd.py:2047 INFO] TRACED GRAPH
 ===== Forward graph 6 =====
 <eval_with_key>.52 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: u8[3, 46, 52]):
        # File: /vision/torchvision/transforms/v2/functional/_geometry.py:112, code: return image.flip(-2)
        rev: u8[3, 46, 52] = torch.ops.prims.rev.default(arg0_1, [1]);  arg0_1 = None
        return (rev,)

vs just a failure.

Hope it is more clear.

Copy link
Member

@NicolasHug NicolasHug Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation!

Should we... Merge that fix right now separately from the tests? I agree this valuable to have to avoid failure. But I'm a bit more concerned about adding the tests (at least for now) because:

  • they add a lot of time (1min for testing tensors only, 4mins if we test all subclasses)
  • checking that 1 or 0 graph break happens doesn't add a ton of value over what we already know TBH. It'd be more valuable if we could assert there are 0 graph breaks so that our test ensures there's never a regression (right now, it won't catch a functional that goes from 0 to 1 graph break).
  • explain has proven to be unreliable in counting graph breaks e.g. in Removing graph breaks in transforms #8056 (comment) (although IDK what else we could be using).

So I'm tempted to merge the fix but not merge the tests for now, and keep this PR open to keep track of the progress on the graph breaks. Then when we're in a better spot (e.g. no graph break across all functionals for tensors) maybe we can reconsider? WDYT @vfdev-5 @pmeier ?

Copy link
Collaborator

@pmeier pmeier Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 on merging the fix.

I don't really understand why we need to assert 0 or 1 graph breaks though. Can't we assert that kernels have 0 graph breaks and functionals 1 (until #8056 (comment)) and xfail / disable the ones that currently don't pass these checks?

And since explain has its own set of issues, maybe we just run the test on kernels with full_graph=True and ignore the functionals for now? That way we circumvent the explain issue. Plus, we don't lose much value in doing so, since our functionals for the most part are just boilerplate the same. Meaning, if the specific kernel works and functionals work in general, the specific functional should work as well.

Copy link
Collaborator Author

@vfdev-5 vfdev-5 Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, let's add @torch.compiler.disable in another PR. => #8110

In the tests we can disable checking on subclasses and run checks on tensor images input only to reduce CI time.

right now, it won't catch a functional that goes from 0 to 1 graph break

Right now if kernel is not empty we always have 1 break. So, if non-empty kernel starts creating 2 graphs we'll see this. What we can do now is to set it to 1 and update kwargs on tests which produce empty kernels and thus 1 empty graph.

So I'm tempted to merge the fix but not merge the tests for now, and keep this PR open to keep track of the progress on the graph breaks.

Is it possible to go as merging this PR (in two parts) and see the CI fail once we got rid of dict look-up graph break ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would CI fail when we get rid of the dict look-up graph break? If we're asserting that there is 0 or 1 graph break, tests will always be green (and thus not very useful)?

Copy link
Collaborator Author

@vfdev-5 vfdev-5 Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can set expected number of graph breaks to 1 and update functional op kwargs such that we never generate an empty graph, so that it will be broken by dict look-up.

return

functional_compiled = torch.compile(functional)
functional_compiled(input, *args, **kwargs)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a consistency check here?

     out = functional_compiled(input, *args, **kwargs)
+    expected = functional(input, *args, **kwargs)
+    torch.testing.assert_close(out, expected)

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Nov 10, 2023
…13023)

Usage of `from pkg_resources import packaging` leads to a deprecation warning:
```
DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
```
and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092

Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency:
https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19
Pull Request resolved: #113023
Approved by: https://github.com/Skylion007, https://github.com/malfet
pytorchmergebot pushed a commit to zabboud/pytorch that referenced this pull request Nov 10, 2023
…torch#113023)

Usage of `from pkg_resources import packaging` leads to a deprecation warning:
```
DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
```
and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092

Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency:
https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19
Pull Request resolved: pytorch#113023
Approved by: https://github.com/Skylion007, https://github.com/malfet
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
…torch#113023)

Usage of `from pkg_resources import packaging` leads to a deprecation warning:
```
DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
```
and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092

Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency:
https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19
Pull Request resolved: pytorch#113023
Approved by: https://github.com/Skylion007
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
…torch#113023)

Usage of `from pkg_resources import packaging` leads to a deprecation warning:
```
DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
```
and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092

Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency:
https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19
Pull Request resolved: pytorch#113023
Approved by: https://github.com/Skylion007, https://github.com/malfet
jeffdaily pushed a commit to ROCm/pytorch that referenced this pull request May 29, 2024
…torch#113023)

Usage of `from pkg_resources import packaging` leads to a deprecation warning:
```
DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
```
and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092

Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency:
https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19
Pull Request resolved: pytorch#113023
Approved by: https://github.com/Skylion007, https://github.com/malfet
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants