-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Tests] Added torch compile checks for functional ops tests #8092
base: main
Are you sure you want to change the base?
Changes from all commits
d4f16cb
d5ca7b2
9a120fa
6fce5da
829a2d3
7f91053
045465e
8410021
d185656
3c959d9
9cc8a54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
import pickle | ||
import random | ||
import re | ||
import sys | ||
from copy import deepcopy | ||
from pathlib import Path | ||
from unittest import mock | ||
|
@@ -186,7 +187,20 @@ def _check_functional_scripted_smoke(functional, input, *args, **kwargs): | |
functional_scripted(input.as_subclass(torch.Tensor), *args, **kwargs) | ||
|
||
|
||
def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs): | ||
def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs): | ||
"""Checks if the functional can be torch compiled and the compiled version can be called without error.""" | ||
if not isinstance(input, torch.Tensor): | ||
return | ||
|
||
functional_compiled = torch.compile(functional) | ||
functional_compiled(input, *args, **kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add a consistency check here? out = functional_compiled(input, *args, **kwargs)
+ expected = functional(input, *args, **kwargs)
+ torch.testing.assert_close(out, expected) |
||
|
||
explanation = torch._dynamo.explain(functional_compiled)(input, *args, **kwargs) | ||
# TODO: Set expected value to 0 once fixed the graph break related to function registration | ||
assert explanation.graph_break_count in (0, 1) | ||
|
||
|
||
def check_functional(functional, input, *args, check_scripted_smoke=True, check_torch_compile_smoke=True, **kwargs): | ||
unknown_input = object() | ||
with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))): | ||
functional(unknown_input, *args, **kwargs) | ||
|
@@ -204,6 +218,16 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, **kwar | |
if check_scripted_smoke: | ||
_check_functional_scripted_smoke(functional, input, *args, **kwargs) | ||
|
||
# Skip check on Windows as torch.compile does not work on Win32 | ||
if check_torch_compile_smoke and sys.platform != "win32": | ||
# Temporary fix to catch deprectation warning | ||
# This can be removed once https://github.com/pytorch/pytorch/pull/113023 is merged: | ||
import warnings | ||
|
||
with warnings.catch_warnings(): | ||
warnings.filterwarnings("ignore", category=DeprecationWarning) | ||
_check_functional_torch_compile_smoke(functional, input, *args, **kwargs) | ||
|
||
|
||
def check_functional_kernel_signature_match(functional, *, kernel, input_type): | ||
"""Checks if the signature of the functional matches the kernel signature.""" | ||
|
@@ -656,6 +680,7 @@ def test_functional(self, size, make_input): | |
size=size, | ||
antialias=True, | ||
check_scripted_smoke=not isinstance(size, int), | ||
check_torch_compile_smoke=False, | ||
) | ||
|
||
@pytest.mark.parametrize( | ||
|
@@ -3469,7 +3494,12 @@ def test_kernel(self, kernel, make_input): | |
) | ||
def test_functional(self, make_input): | ||
check_functional( | ||
F.resized_crop, make_input(self.INPUT_SIZE), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, antialias=True | ||
F.resized_crop, | ||
make_input(self.INPUT_SIZE), | ||
**self.CROP_KWARGS, | ||
size=self.OUTPUT_SIZE, | ||
antialias=True, | ||
check_torch_compile_smoke=False, | ||
) | ||
|
||
@pytest.mark.parametrize( | ||
|
@@ -3949,7 +3979,7 @@ def test_kernel_video(self): | |
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], | ||
) | ||
def test_functional(self, make_input): | ||
check_functional(F.perspective, make_input(), **self.MINIMAL_KWARGS) | ||
check_functional(F.perspective, make_input(), **self.MINIMAL_KWARGS, check_torch_compile_smoke=False) | ||
|
||
@pytest.mark.parametrize( | ||
("kernel", "input_type"), | ||
|
@@ -4106,7 +4136,7 @@ def test_kernel_video(self): | |
|
||
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) | ||
def test_functional(self, make_input): | ||
check_functional(F.equalize, make_input()) | ||
check_functional(F.equalize, make_input(), check_torch_compile_smoke=False) | ||
|
||
@pytest.mark.parametrize( | ||
("kernel", "input_type"), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,10 @@ | |
from ._video import Video | ||
|
||
|
||
# TODO: Fix this. We skip this method as it leads to | ||
# RecursionError: maximum recursion depth exceeded while calling a Python object | ||
# Keeping it here, leads to graph breaks between multiple functional ops instead of having a single graph | ||
@torch.compiler.disable | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to confirm, the issue only appears for subclasses but not for pure tensors? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, on subclasses when we try to convert result tensor back into its subclass.
This method is not called for pure tensors.
With this added we can torch compile functional ops with subclasses vs only image processing functional ops on tensors: from torchvision.transforms.v2.functional import horizontal_flip
from torchvision.tv_tensors import Image, BoundingBoxes
t_img = torch.randint(0, 256, size=(3, 46, 52), dtype=torch.uint8)
img = Image(t_img)
box = BoundingBoxes(torch.randint(0, 256, size=(5, 4), dtype=torch.uint8))
cfn = torch.compile(horizontal_flip)
o1 = cfn(t_img)
o2 = cfn(img)
o3 = cfn(box) vs from torchvision.transforms.v2.functional import horizontal_flip
from torchvision.tv_tensors import Image, BoundingBoxes
t_img = torch.randint(0, 256, size=(3, 46, 52), dtype=torch.uint8)
img = Image(t_img)
box = BoundingBoxes(torch.randint(0, 256, size=(5, 4), dtype=torch.uint8))
cfn = torch.compile(horizontal_flip)
o1 = cfn(t_img) # OK
o2 = cfn(img) # RecursionError: maximum recursion depth exceeded while calling a Python object
o3 = cfn(box) # RecursionError: maximum recursion depth exceeded while calling a Python object I think this is valuable. The only drawback is that we can't make a single graph for a pipeline of ops: from torchvision.transforms.v2.functional import horizontal_flip, vertical_flip
from torchvision.tv_tensors import Image
def func(x):
x = horizontal_flip(x)
x = vertical_flip(x)
return x
cfn = torch.compile(func)
x = Image(torch.randint(0, 256, size=(3, 46, 52), dtype=torch.uint8))
out = cfn(x) # produces 4 graphs
[aot_autograd.py:2047 INFO] TRACED GRAPH
===== Forward graph 0 =====
<eval_with_key>.4 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
def forward(self):
return ()
[aot_autograd.py:2047 INFO] TRACED GRAPH
===== Forward graph 4 =====
<eval_with_key>.37 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
def forward(self, arg0_1: u8[3, 46, 52]):
# File: /vision/torchvision/transforms/v2/functional/_geometry.py:55, code: return image.flip(-1)
rev: u8[3, 46, 52] = torch.ops.prims.rev.default(arg0_1, [2]); arg0_1 = None
return (rev,)
[aot_autograd.py:2047 INFO] TRACED GRAPH
===== Forward graph 5 =====
<eval_with_key>.47 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
def forward(self):
return ()
[aot_autograd.py:2047 INFO] TRACED GRAPH
===== Forward graph 6 =====
<eval_with_key>.52 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
def forward(self, arg0_1: u8[3, 46, 52]):
# File: /vision/torchvision/transforms/v2/functional/_geometry.py:112, code: return image.flip(-2)
rev: u8[3, 46, 52] = torch.ops.prims.rev.default(arg0_1, [1]); arg0_1 = None
return (rev,) vs just a failure. Hope it is more clear. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the explanation! Should we... Merge that fix right now separately from the tests? I agree this valuable to have to avoid failure. But I'm a bit more concerned about adding the tests (at least for now) because:
So I'm tempted to merge the fix but not merge the tests for now, and keep this PR open to keep track of the progress on the graph breaks. Then when we're in a better spot (e.g. no graph break across all functionals for tensors) maybe we can reconsider? WDYT @vfdev-5 @pmeier ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 on merging the fix. I don't really understand why we need to assert 0 or 1 graph breaks though. Can't we assert that kernels have 0 graph breaks and functionals 1 (until #8056 (comment)) and xfail / disable the ones that currently don't pass these checks? And since There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, let's add In the tests we can disable checking on subclasses and run checks on tensor images input only to reduce CI time.
Right now if kernel is not empty we always have 1 break. So, if non-empty kernel starts creating 2 graphs we'll see this. What we can do now is to set it to 1 and update kwargs on tests which produce empty kernels and thus 1 empty graph.
Is it possible to go as merging this PR (in two parts) and see the CI fail once we got rid of dict look-up graph break ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why would CI fail when we get rid of the dict look-up graph break? If we're asserting that there is 0 or 1 graph break, tests will always be green (and thus not very useful)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can set expected number of graph breaks to 1 and update functional op kwargs such that we never generate an empty graph, so that it will be broken by dict look-up. |
||
def wrap(wrappee, *, like, **kwargs): | ||
"""[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought of doing
type(input) == Tensor
instead ofisinstance
, so that we don't duplicate this check for subclasses.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need to be able to compile functional ops for subclasses as well. Right now dynamo does not fully support subclasses and in our case it can compile and even exec correctly if we just skip the last wrap to subclass.