diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 4f8d0027bd6..9cba49440d1 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -7,6 +7,7 @@ import pickle import random import re +import sys from copy import deepcopy from pathlib import Path from unittest import mock @@ -186,7 +187,20 @@ def _check_functional_scripted_smoke(functional, input, *args, **kwargs): functional_scripted(input.as_subclass(torch.Tensor), *args, **kwargs) -def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs): +def _check_functional_torch_compile_smoke(functional, input, *args, **kwargs): + """Checks if the functional can be torch compiled and the compiled version can be called without error.""" + if not isinstance(input, torch.Tensor): + return + + functional_compiled = torch.compile(functional) + functional_compiled(input, *args, **kwargs) + + explanation = torch._dynamo.explain(functional_compiled)(input, *args, **kwargs) + # TODO: Set expected value to 0 once the graph break related to function registration is fixed + assert explanation.graph_break_count in (0, 1) + + +def check_functional(functional, input, *args, check_scripted_smoke=True, check_torch_compile_smoke=True, **kwargs): unknown_input = object() with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))): functional(unknown_input, *args, **kwargs) @@ -204,6 +218,16 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, **kwar if check_scripted_smoke: _check_functional_scripted_smoke(functional, input, *args, **kwargs) + # Skip check on Windows as torch.compile does not work on Win32 + if check_torch_compile_smoke and sys.platform != "win32": + # Temporary fix to catch deprecation warning + # This can be removed once https://github.com/pytorch/pytorch/pull/113023 is merged: + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + 
_check_functional_torch_compile_smoke(functional, input, *args, **kwargs) + def check_functional_kernel_signature_match(functional, *, kernel, input_type): """Checks if the signature of the functional matches the kernel signature.""" @@ -656,6 +680,7 @@ def test_functional(self, size, make_input): size=size, antialias=True, check_scripted_smoke=not isinstance(size, int), + check_torch_compile_smoke=False, ) @pytest.mark.parametrize( @@ -3469,7 +3494,12 @@ def test_kernel(self, kernel, make_input): ) def test_functional(self, make_input): check_functional( - F.resized_crop, make_input(self.INPUT_SIZE), **self.CROP_KWARGS, size=self.OUTPUT_SIZE, antialias=True + F.resized_crop, + make_input(self.INPUT_SIZE), + **self.CROP_KWARGS, + size=self.OUTPUT_SIZE, + antialias=True, + check_torch_compile_smoke=False, ) @pytest.mark.parametrize( @@ -3949,7 +3979,7 @@ def test_kernel_video(self): [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], ) def test_functional(self, make_input): - check_functional(F.perspective, make_input(), **self.MINIMAL_KWARGS) + check_functional(F.perspective, make_input(), **self.MINIMAL_KWARGS, check_torch_compile_smoke=False) @pytest.mark.parametrize( ("kernel", "input_type"), @@ -4106,7 +4136,7 @@ def test_kernel_video(self): @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video]) def test_functional(self, make_input): - check_functional(F.equalize, make_input()) + check_functional(F.equalize, make_input(), check_torch_compile_smoke=False) @pytest.mark.parametrize( ("kernel", "input_type"), diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py index d55e10e8620..2e58d9d4c6a 100644 --- a/torchvision/tv_tensors/__init__.py +++ b/torchvision/tv_tensors/__init__.py @@ -8,6 +8,10 @@ from ._video import Video +# TODO: Fix this. 
We skip this function as it leads to +# RecursionError: maximum recursion depth exceeded while calling a Python object +# Keeping it here leads to graph breaks between multiple functional ops instead of having a single graph +@torch.compiler.disable +def wrap(wrappee, *, like, **kwargs): """[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.