-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Tests] Added torch compile checks for functional ops tests #8092
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8092
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 9cc8a54 with merge base 15c166a (): BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…13023) Usage of `from pkg_resources import packaging` leads to a deprecation warning: ``` DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html ``` and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092 Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency: https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19 Pull Request resolved: #113023 Approved by: https://github.com/Skylion007
…torch#113023) Usage of `from pkg_resources import packaging` leads to a deprecation warning: ``` DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html ``` and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092 Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency: https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19 Pull Request resolved: pytorch#113023 Approved by: https://github.com/Skylion007
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @vfdev-5 - do we have an idea of the duration of this new check? When trying locally, the compilation was taking a fair bit of a time, I wonder if this will make the test significantly slower or not
test/test_transforms_v2.py
Outdated
def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs): | ||
def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs): | ||
"""Checks if the functional can be torch compiled and the compiled version can be called without error.""" | ||
if not isinstance(input, tv_tensors.Image): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we just check for type(input) == Tensor
directly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just replicated the same check from _check_functional_scripted_smoke
, however I agree that we can test all subclasses
test/test_transforms_v2.py
Outdated
assert explanation.graph_count in (1, 2) | ||
assert explanation.graph_break_count in (0, 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think these 2 asserts are redundant, we can probably just assert the break counts?
Also, could we assert the exact number of graph breaks, instead of checking that it's 0 or 1 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we can just have one assert assert explanation.graph_break_count in (0, 1)
.
Why assert explanation.graph_break_count in (0, 1)
and not assert explanation.graph_break_count == 1
, because there are ops with configs such to_dtype
on f32 and dtype=torch.float32, scale=False
that there is no graph to build and explanation object reports explanation.graph_break_count == 0
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking that we could have this mapping kernel -> expected_break_count
so that we can more easily track which ones need to be addressed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case I have to combine somehow check_torch_compile_smoke
flag to enable/disable torch compile tests and expected_break_count
. For example, replace boolean check_torch_compile_smoke
by torch_compile_expected_break_count: Optional[int]
(if None, tests are disabled) ?
Annotated exceptions with encountered errors
73fd62f
to
d558440
Compare
d558440
to
d185656
Compare
if not isinstance(input, torch.Tensor): | ||
return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought of doing type(input) == Tensor
instead of isinstance
, so that we don't duplicate this check for subclasses.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need to be able to compile functional ops for subclasses as well. Right now dynamo does not fully support subclasses and in our case it can compile and even exec correctly if we just skip the last wrap to subclass.
# TODO: Fix this. We skip this method as it leads to | ||
# RecursionError: maximum recursion depth exceeded while calling a Python object | ||
# Keeping it here, leads to graph breaks between multiple functional ops instead of having a single graph | ||
@torch.compiler.disable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to confirm, the issue only appears for subclasses but not for pure tensors?
Was this creating a graph break for pure tensors? Because if not, it might be best to ignore that altogether for now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, on subclasses when we try to convert result tensor back into its subclass.
Was this creating a graph break for pure tensors?
This method is not called for pure tensors.
Because if not, it might be best to ignore that altogether for now?
With this added we can torch compile functional ops with subclasses vs only image processing functional ops on tensors:
from torchvision.transforms.v2.functional import horizontal_flip
from torchvision.tv_tensors import Image, BoundingBoxes
t_img = torch.randint(0, 256, size=(3, 46, 52), dtype=torch.uint8)
img = Image(t_img)
box = BoundingBoxes(torch.randint(0, 256, size=(5, 4), dtype=torch.uint8))
cfn = torch.compile(horizontal_flip)
o1 = cfn(t_img)
o2 = cfn(img)
o3 = cfn(box)
vs
from torchvision.transforms.v2.functional import horizontal_flip
from torchvision.tv_tensors import Image, BoundingBoxes
t_img = torch.randint(0, 256, size=(3, 46, 52), dtype=torch.uint8)
img = Image(t_img)
box = BoundingBoxes(torch.randint(0, 256, size=(5, 4), dtype=torch.uint8))
cfn = torch.compile(horizontal_flip)
o1 = cfn(t_img) # OK
o2 = cfn(img) # RecursionError: maximum recursion depth exceeded while calling a Python object
o3 = cfn(box) # RecursionError: maximum recursion depth exceeded while calling a Python object
I think this is valuable.
The only drawback is that we can't make a single graph for a pipeline of ops:
from torchvision.transforms.v2.functional import horizontal_flip, vertical_flip
from torchvision.tv_tensors import Image
def func(x):
x = horizontal_flip(x)
x = vertical_flip(x)
return x
cfn = torch.compile(func)
x = Image(torch.randint(0, 256, size=(3, 46, 52), dtype=torch.uint8))
out = cfn(x) # produces 4 graphs
[aot_autograd.py:2047 INFO] TRACED GRAPH
===== Forward graph 0 =====
<eval_with_key>.4 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
def forward(self):
return ()
[aot_autograd.py:2047 INFO] TRACED GRAPH
===== Forward graph 4 =====
<eval_with_key>.37 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
def forward(self, arg0_1: u8[3, 46, 52]):
# File: /vision/torchvision/transforms/v2/functional/_geometry.py:55, code: return image.flip(-1)
rev: u8[3, 46, 52] = torch.ops.prims.rev.default(arg0_1, [2]); arg0_1 = None
return (rev,)
[aot_autograd.py:2047 INFO] TRACED GRAPH
===== Forward graph 5 =====
<eval_with_key>.47 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
def forward(self):
return ()
[aot_autograd.py:2047 INFO] TRACED GRAPH
===== Forward graph 6 =====
<eval_with_key>.52 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
def forward(self, arg0_1: u8[3, 46, 52]):
# File: /vision/torchvision/transforms/v2/functional/_geometry.py:112, code: return image.flip(-2)
rev: u8[3, 46, 52] = torch.ops.prims.rev.default(arg0_1, [1]); arg0_1 = None
return (rev,)
vs just a failure.
Hope it is more clear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation!
Should we... Merge that fix right now separately from the tests? I agree this valuable to have to avoid failure. But I'm a bit more concerned about adding the tests (at least for now) because:
- they add a lot of time (1min for testing tensors only, 4mins if we test all subclasses)
- checking that 1 or 0 graph break happens doesn't add a ton of value over what we already know TBH. It'd be more valuable if we could assert there are 0 graph breaks so that our test ensures there's never a regression (right now, it won't catch a functional that goes from 0 to 1 graph break).
explain
has proven to be unreliable in counting graph breaks e.g. in Removing graph breaks in transforms #8056 (comment) (although IDK what else we could be using).
So I'm tempted to merge the fix but not merge the tests for now, and keep this PR open to keep track of the progress on the graph breaks. Then when we're in a better spot (e.g. no graph break across all functionals for tensors) maybe we can reconsider? WDYT @vfdev-5 @pmeier ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍 on merging the fix.
I don't really understand why we need to assert 0 or 1 graph breaks though. Can't we assert that kernels have 0 graph breaks and functionals 1 (until #8056 (comment)) and xfail / disable the ones that currently don't pass these checks?
And since explain
has its own set of issues, maybe we just run the test on kernels with full_graph=True
and ignore the functionals for now? That way we circumvent the explain
issue. Plus, we don't lose much value in doing so, since our functionals for the most part are just boilerplate the same. Meaning, if the specific kernel works and functionals work in general, the specific functional should work as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, let's add @torch.compiler.disable
in another PR. => #8110
In the tests we can disable checking on subclasses and run checks on tensor images input only to reduce CI time.
right now, it won't catch a functional that goes from 0 to 1 graph break
Right now if kernel is not empty we always have 1 break. So, if non-empty kernel starts creating 2 graphs we'll see this. What we can do now is to set it to 1 and update kwargs on tests which produce empty kernels and thus 1 empty graph.
So I'm tempted to merge the fix but not merge the tests for now, and keep this PR open to keep track of the progress on the graph breaks.
Is it possible to go as merging this PR (in two parts) and see the CI fail once we got rid of dict look-up graph break ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why would CI fail when we get rid of the dict look-up graph break? If we're asserting that there is 0 or 1 graph break, tests will always be green (and thus not very useful)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can set expected number of graph breaks to 1 and update functional op kwargs such that we never generate an empty graph, so that it will be broken by dict look-up.
return | ||
|
||
functional_compiled = torch.compile(functional) | ||
functional_compiled(input, *args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add a consistency check here?
out = functional_compiled(input, *args, **kwargs)
+ expected = functional(input, *args, **kwargs)
+ torch.testing.assert_close(out, expected)
…13023) Usage of `from pkg_resources import packaging` leads to a deprecation warning: ``` DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html ``` and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092 Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency: https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19 Pull Request resolved: #113023 Approved by: https://github.com/Skylion007, https://github.com/malfet
…torch#113023) Usage of `from pkg_resources import packaging` leads to a deprecation warning: ``` DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html ``` and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092 Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency: https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19 Pull Request resolved: pytorch#113023 Approved by: https://github.com/Skylion007, https://github.com/malfet
…torch#113023) Usage of `from pkg_resources import packaging` leads to a deprecation warning: ``` DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html ``` and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092 Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency: https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19 Pull Request resolved: pytorch#113023 Approved by: https://github.com/Skylion007
…torch#113023) Usage of `from pkg_resources import packaging` leads to a deprecation warning: ``` DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html ``` and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092 Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency: https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19 Pull Request resolved: pytorch#113023 Approved by: https://github.com/Skylion007, https://github.com/malfet
…torch#113023) Usage of `from pkg_resources import packaging` leads to a deprecation warning: ``` DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html ``` and in strict tests where warnings are errors, this leads to CI breaks, e.g.: pytorch/vision#8092 Replacing `pkg_resources.package` with `package` as it is now a pytorch dependency: https://github.com/pytorch/pytorch/blob/fa9045a8725214c05ae4dcec5a855820b861155e/requirements.txt#L19 Pull Request resolved: pytorch#113023 Approved by: https://github.com/Skylion007, https://github.com/malfet
Related to #8056
Description: