Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tests] Added torch compile checks for functional ops tests #8092

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
30 changes: 27 additions & 3 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pickle
import random
import re
import sys
from copy import deepcopy
from pathlib import Path
from unittest import mock
Expand Down Expand Up @@ -186,7 +187,20 @@ def _check_functional_scripted_smoke(functional, input, *args, **kwargs):
functional_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)


def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs):
def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs):
"""Checks if the functional can be torch compiled and the compiled version can be called without error."""
if not isinstance(input, torch.Tensor):
return
Comment on lines +192 to +193
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought of doing type(input) == Tensor instead of isinstance, so that we don't duplicate this check for subclasses.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to be able to compile functional ops for subclasses as well. Right now dynamo does not fully support subclasses and in our case it can compile and even exec correctly if we just skip the last wrap to subclass.


functional_compiled = torch.compile(functional)
functional_compiled(input, *args, **kwargs)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a consistency check here?

     out = functional_compiled(input, *args, **kwargs)
+    expected = functional(input, *args, **kwargs)
+    torch.testing.assert_close(out, expected)


explanation = torch._dynamo.explain(functional_compiled)(input, *args, **kwargs)
# TODO: Set expected value to 0 once fixed the graph break related to function registration
assert explanation.graph_break_count in (0, 1)


def check_functional(functional, input, *args, check_scripted_smoke=True, check_torch_compile_smoke=True, **kwargs):
unknown_input = object()
with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
functional(unknown_input, *args, **kwargs)
Expand All @@ -204,6 +218,10 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, **kwar
if check_scripted_smoke:
_check_functional_scripted_smoke(functional, input, *args, **kwargs)

# Skip check on Windows as torch.compile does not work on Win32
if check_torch_compile_smoke and sys.platform != "win32":
_check_functional_torch_compile_smoke(functional, input, *args, **kwargs)


def check_functional_kernel_signature_match(functional, *, kernel, input_type):
"""Checks if the signature of the functional matches the kernel signature."""
Expand Down Expand Up @@ -656,6 +674,7 @@ def test_functional(self, size, make_input):
size=size,
antialias=True,
check_scripted_smoke=not isinstance(size, int),
check_torch_compile_smoke=False,
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -3469,7 +3488,12 @@ def test_kernel(self, kernel, make_input):
)
def test_functional(self, make_input):
check_functional(
F.resized_crop, make_input(self.INPUT_SIZE), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, antialias=True
F.resized_crop,
make_input(self.INPUT_SIZE),
**self.CROP_KWARGS,
size=self.OUTPUT_SIZE,
antialias=True,
check_torch_compile_smoke=False,
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -3949,7 +3973,7 @@ def test_kernel_video(self):
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
)
def test_functional(self, make_input):
check_functional(F.perspective, make_input(), **self.MINIMAL_KWARGS)
check_functional(F.perspective, make_input(), **self.MINIMAL_KWARGS, check_torch_compile_smoke=False)

@pytest.mark.parametrize(
("kernel", "input_type"),
Expand Down
4 changes: 4 additions & 0 deletions torchvision/tv_tensors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from ._video import Video


# TODO: Fix this. We skip this method as it leads to
# RecursionError: maximum recursion depth exceeded while calling a Python object
# Keeping it here, leads to graph breaks between multiple functional ops instead of having a single graph
@torch.compiler.disable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm, the issue only appears for subclasses but not for pure tensors?
Was this creating a graph break for pure tensors? Because if not, it might be best to ignore that altogether for now?

Copy link
Collaborator Author

@vfdev-5 vfdev-5 Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, on subclasses when we try to convert result tensor back into its subclass.

Was this creating a graph break for pure tensors?

This method is not called for pure tensors.

Because if not, it might be best to ignore that altogether for now?

With this added we can torch compile functional ops with subclasses vs only image processing functional ops on tensors:

from torchvision.transforms.v2.functional import horizontal_flip
from torchvision.tv_tensors import Image, BoundingBoxes

t_img = torch.randint(0, 256, size=(3, 46, 52), dtype=torch.uint8)
img = Image(t_img)
box = BoundingBoxes(torch.randint(0, 256, size=(5, 4), dtype=torch.uint8))

cfn = torch.compile(horizontal_flip)
o1 = cfn(t_img)
o2 = cfn(img)
o3 = cfn(box)

vs

from torchvision.transforms.v2.functional import horizontal_flip
from torchvision.tv_tensors import Image, BoundingBoxes

t_img = torch.randint(0, 256, size=(3, 46, 52), dtype=torch.uint8)
img = Image(t_img)
box = BoundingBoxes(torch.randint(0, 256, size=(5, 4), dtype=torch.uint8))

cfn = torch.compile(horizontal_flip)
o1 = cfn(t_img)  # OK
o2 = cfn(img)    # RecursionError: maximum recursion depth exceeded while calling a Python object
o3 = cfn(box)    # RecursionError: maximum recursion depth exceeded while calling a Python object

I think this is valuable.

The only drawback is that we can't make a single graph for a pipeline of ops:

from torchvision.transforms.v2.functional import horizontal_flip, vertical_flip
from torchvision.tv_tensors import Image

def func(x):
    x = horizontal_flip(x)
    x = vertical_flip(x)
    return x

cfn = torch.compile(func)

x = Image(torch.randint(0, 256, size=(3, 46, 52), dtype=torch.uint8))
out = cfn(x)  # produces 4 graphs

[aot_autograd.py:2047 INFO] TRACED GRAPH
 ===== Forward graph 0 =====
 <eval_with_key>.4 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
    def forward(self):
        return ()

[aot_autograd.py:2047 INFO] TRACED GRAPH
 ===== Forward graph 4 =====
 <eval_with_key>.37 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: u8[3, 46, 52]):
        # File: /vision/torchvision/transforms/v2/functional/_geometry.py:55, code: return image.flip(-1)
        rev: u8[3, 46, 52] = torch.ops.prims.rev.default(arg0_1, [2]);  arg0_1 = None
        return (rev,)
        
[aot_autograd.py:2047 INFO] TRACED GRAPH
 ===== Forward graph 5 =====
 <eval_with_key>.47 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
    def forward(self):
        return ()
        
[aot_autograd.py:2047 INFO] TRACED GRAPH
 ===== Forward graph 6 =====
 <eval_with_key>.52 from /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py:509 in wrapped class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: u8[3, 46, 52]):
        # File: /vision/torchvision/transforms/v2/functional/_geometry.py:112, code: return image.flip(-2)
        rev: u8[3, 46, 52] = torch.ops.prims.rev.default(arg0_1, [1]);  arg0_1 = None
        return (rev,)

vs just a failure.

Hope it is more clear.

Copy link
Member

@NicolasHug NicolasHug Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation!

Should we... Merge that fix right now separately from the tests? I agree this valuable to have to avoid failure. But I'm a bit more concerned about adding the tests (at least for now) because:

  • they add a lot of time (1min for testing tensors only, 4mins if we test all subclasses)
  • checking that 1 or 0 graph break happens doesn't add a ton of value over what we already know TBH. It'd be more valuable if we could assert there are 0 graph breaks so that our test ensures there's never a regression (right now, it won't catch a functional that goes from 0 to 1 graph break).
  • explain has proven to be unreliable in counting graph breaks e.g. in Removing graph breaks in transforms #8056 (comment) (although IDK what else we could be using).

So I'm tempted to merge the fix but not merge the tests for now, and keep this PR open to keep track of the progress on the graph breaks. Then when we're in a better spot (e.g. no graph break across all functionals for tensors) maybe we can reconsider? WDYT @vfdev-5 @pmeier ?

Copy link
Collaborator

@pmeier pmeier Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 on merging the fix.

I don't really understand why we need to assert 0 or 1 graph breaks though. Can't we assert that kernels have 0 graph breaks and functionals 1 (until #8056 (comment)) and xfail / disable the ones that currently don't pass these checks?

And since explain has its own set of issues, maybe we just run the test on kernels with full_graph=True and ignore the functionals for now? That way we circumvent the explain issue. Plus, we don't lose much value in doing so, since our functionals for the most part are just boilerplate the same. Meaning, if the specific kernel works and functionals work in general, the specific functional should work as well.

Copy link
Collaborator Author

@vfdev-5 vfdev-5 Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, let's add @torch.compiler.disable in another PR. => #8110

In the tests we can disable checking on subclasses and run checks on tensor images input only to reduce CI time.

right now, it won't catch a functional that goes from 0 to 1 graph break

Right now if kernel is not empty we always have 1 break. So, if non-empty kernel starts creating 2 graphs we'll see this. What we can do now is to set it to 1 and update kwargs on tests which produce empty kernels and thus 1 empty graph.

So I'm tempted to merge the fix but not merge the tests for now, and keep this PR open to keep track of the progress on the graph breaks.

Is it possible to go as merging this PR (in two parts) and see the CI fail once we got rid of dict look-up graph break ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would CI fail when we get rid of the dict look-up graph break? If we're asserting that there is 0 or 1 graph break, tests will always be green (and thus not very useful)?

Copy link
Collaborator Author

@vfdev-5 vfdev-5 Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can set expected number of graph breaks to 1 and update functional op kwargs such that we never generate an empty graph, so that it will be broken by dict look-up.

def wrap(wrappee, *, like, **kwargs):
"""[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.

Expand Down
Loading