diff --git a/.github/scripts/setup-env.sh b/.github/scripts/setup-env.sh index adb1256303f..24e7aa97986 100755 --- a/.github/scripts/setup-env.sh +++ b/.github/scripts/setup-env.sh @@ -102,6 +102,11 @@ echo '::group::Install TorchVision' python setup.py develop echo '::endgroup::' +echo '::group::Install torchvision-extra-decoders' +# This can be done after torchvision was built +pip install torchvision-extra-decoders +echo '::endgroup::' + echo '::group::Collect environment information' conda list python -m torch.utils.collect_env diff --git a/docs/source/io.rst b/docs/source/io.rst index 6a76f95e897..c3f2d658014 100644 --- a/docs/source/io.rst +++ b/docs/source/io.rst @@ -9,8 +9,8 @@ images and videos. Image Decoding -------------- -Torchvision currently supports decoding JPEG, PNG, WEBP and GIF images. JPEG -decoding can also be done on CUDA GPUs. +Torchvision currently supports decoding JPEG, PNG, WEBP, GIF, AVIF, and HEIC +images. JPEG decoding can also be done on CUDA GPUs. The main entry point is the :func:`~torchvision.io.decode_image` function, which you can use as an alternative to ``PIL.Image.open()``. It will decode images @@ -30,9 +30,10 @@ run transforms/preproc natively on tensors. :func:`~torchvision.io.decode_image` will automatically detect the image format, -and call the corresponding decoder. You can also use the lower-level -format-specific decoders which can be more powerful, e.g. if you want to -encode/decode JPEGs on CUDA. +and call the corresponding decoder (except for HEIC and AVIF images, see details +in :func:`~torchvision.io.decode_avif` and :func:`~torchvision.io.decode_heic`). +You can also use the lower-level format-specific decoders which can be more +powerful, e.g. if you want to encode/decode JPEGs on CUDA. .. autosummary:: :toctree: generated/ @@ -41,8 +42,10 @@ encode/decode JPEGs on CUDA. decode_image decode_jpeg encode_png - decode_gif decode_webp + decode_avif + decode_heic + decode_gif .. autosummary:: :toctree: generated/ diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 4bb18cf6b48..d2fed552c4f 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -508,11 +508,20 @@ are combining pairs of images together. These can be used after the dataloader Developer tools ^^^^^^^^^^^^^^^ +.. autosummary:: + :toctree: generated/ + :template: class.rst + + v2.Transform + .. autosummary:: :toctree: generated/ :template: function.rst v2.functional.register_kernel + v2.query_size + v2.query_chw + v2.get_bounding_boxes V1 API Reference diff --git a/gallery/transforms/plot_custom_transforms.py b/gallery/transforms/plot_custom_transforms.py index 19bc955b934..d1bd9455bfb 100644 --- a/gallery/transforms/plot_custom_transforms.py +++ b/gallery/transforms/plot_custom_transforms.py @@ -12,6 +12,8 @@ """ # %% +from typing import Any, Dict, List + import torch from torchvision import tv_tensors from torchvision.transforms import v2 @@ -89,33 +91,110 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured # A key feature of the builtin Torchvision V2 transforms is that they can accept # arbitrary input structure and return the same structure as output (with # transformed entries). For example, transforms can accept a single image, or a -# tuple of ``(img, label)``, or an arbitrary nested dictionary as input: +# tuple of ``(img, label)``, or an arbitrary nested dictionary as input. 
Here's
+# an example using the built-in transform :class:`~torchvision.transforms.v2.RandomHorizontalFlip`:

 structured_input = {
     "img": img,
     "annotations": (bboxes, label),
-    "something_that_will_be_ignored": (1, "hello")
+    "something that will be ignored": (1, "hello"),
+    "another tensor that is ignored": torch.arange(10),
 }
 structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

 assert isinstance(structured_output, dict)
-assert structured_output["something_that_will_be_ignored"] == (1, "hello")
+assert structured_output["something that will be ignored"] == (1, "hello")
+assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
+print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
+print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
+
+# %%
+# Basics: override the `transform()` method
+# -----------------------------------------
+#
+# In order to support arbitrary inputs in your custom transform, you will need
+# to inherit from :class:`~torchvision.transforms.v2.Transform` and override the
+# `.transform()` method (not the `forward()` method!). Below is a basic example:
+
+
+class MyCustomTransform(v2.Transform):
+    def transform(self, inpt: Any, params: Dict[str, Any]):
+        if type(inpt) == torch.Tensor:
+            print(f"I'm transforming an image of shape {inpt.shape}")
+            return inpt + 1  # dummy transformation
+        elif isinstance(inpt, tv_tensors.BoundingBoxes):
+            print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
+            return tv_tensors.wrap(inpt + 100, like=inpt)  # dummy transformation
+
+
+my_custom_transform = MyCustomTransform()
+structured_output = my_custom_transform(structured_input)
+
+assert isinstance(structured_output, dict)
+assert structured_output["something that will be ignored"] == (1, "hello")
+assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
+print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
 print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")

 # %%
-# If you want to reproduce this behavior in your own transform, we invite you to
-# look at our `code
-# `_
-# and adapt it to your needs.
-#
-# In brief, the core logic is to unpack the input into a flat list using `pytree
-# `_, and
-# then transform only the entries that can be transformed (the decision is made
-# based on the **class** of the entries, as all TVTensors are
-# tensor-subclasses) plus some custom logic that is out of score here - check the
-# code for details. The (potentially transformed) entries are then repacked and
-# returned, in the same structure as the input.
-#
-# We do not provide public dev-facing tools to achieve that at this time, but if
-# this is something that would be valuable to you, please let us know by opening
-# an issue on our `GitHub repo `_.
+# An important thing to note is that when we call ``my_custom_transform`` on
+# ``structured_input``, the input is flattened and then each individual part is
+# passed to ``transform()``. That is, ``transform()`` receives the input image,
+# then the bounding boxes, etc. Within ``transform()``, you can decide how to
+# transform each input, based on its type.
+#
+# If you're curious why the other tensor (``torch.arange()``) didn't get passed
+# to ``transform()``, see :ref:`this note ` for more
+# details.
+#
+# Advanced: The ``make_params()`` method
+# ----------------------------------------
+#
+# The ``make_params()`` method is called internally before calling
+# ``transform()`` on each input. This is typically useful to generate random
+# parameter values. In the example below, we use it to randomly apply the
+# transformation with a probability of 0.5.
+
+
+class MyRandomTransform(MyCustomTransform):
+    def __init__(self, p=0.5):
+        self.p = p
+        super().__init__()
+
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        apply_transform = (torch.rand(size=(1,)) < self.p).item()
+        params = dict(apply_transform=apply_transform)
+        return params
+
+    def transform(self, inpt: Any, params: Dict[str, Any]):
+        if not params["apply_transform"]:
+            print("Not transforming anything!")
+            return inpt
+        else:
+            return super().transform(inpt, params)
+
+
+my_random_transform = MyRandomTransform()
+
+torch.manual_seed(0)
+_ = my_random_transform(structured_input)  # transforms
+_ = my_random_transform(structured_input)  # doesn't transform
+
+# %%
+#
+# .. note::
+#
+#     It's important for such random parameter generation to happen within
+#     ``make_params()`` and not within ``transform()``, so that for a given
+#     transform call, the same RNG applies to all the inputs in the same way. If
+#     we were to perform the RNG within ``transform()``, we would risk e.g.
+#     transforming the image while *not* transforming the bounding boxes.
+#
+# The ``make_params()`` method takes the list of all the inputs as a parameter
+# (each of the elements in this list will later be passed to ``transform()``).
+# You can use ``flat_inputs`` to e.g. figure out the dimensions of the input,
+# using :func:`~torchvision.transforms.v2.query_chw` or
+# :func:`~torchvision.transforms.v2.query_size`.
+#
+# ``make_params()`` should return a dict (or actually, anything you want) that
+# will then be passed to ``transform()``.
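To make that last point concrete, here is a minimal sketch (not part of this diff) of a custom transform whose ``make_params()`` uses ``v2.query_size()`` to draw one random crop location that is then applied identically to every input of a call. The class name ``MyRandomCrop`` and its parameters are hypothetical and only for illustration; the sketch assumes the ``v2.Transform`` / ``make_params()`` / ``transform()`` API documented above and that all inputs are at least ``size`` pixels in each dimension.

```python
from typing import Any, Dict, List

import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F


class MyRandomCrop(v2.Transform):
    # Hypothetical example: crop every input of a call at the same random location.
    def __init__(self, size: int):
        super().__init__()
        self.size = size

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # query_size() inspects the flattened inputs and returns their common (H, W).
        height, width = v2.query_size(flat_inputs)
        top = int(torch.randint(0, height - self.size + 1, size=(1,)))
        left = int(torch.randint(0, width - self.size + 1, size=(1,)))
        return dict(top=top, left=left)

    def transform(self, inpt: Any, params: Dict[str, Any]):
        # The same params dict is passed for every input, so an image and its
        # bounding boxes end up cropped consistently.
        return F.crop(inpt, top=params["top"], left=params["left"], height=self.size, width=self.size)
```

Applied to an ``(image, bounding_boxes)`` pair, both entries would be cropped with the same ``top``/``left``, which is exactly why the random draw belongs in ``make_params()`` rather than ``transform()``.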
diff --git a/mypy.ini b/mypy.ini index d6f3cb16963..d8ab11d0d21 100644 --- a/mypy.ini +++ b/mypy.ini @@ -45,6 +45,10 @@ ignore_errors = True ignore_errors=True +[mypy-torchvision.models.maxvit.*] + +ignore_errors=True + [mypy-torchvision.models.detection.anchor_utils] ignore_errors = True diff --git a/packaging/post_build_script.sh b/packaging/post_build_script.sh index ae7542f9f8a..253980b98c3 100644 --- a/packaging/post_build_script.sh +++ b/packaging/post_build_script.sh @@ -1,2 +1,4 @@ #!/bin/bash LD_LIBRARY_PATH="/usr/local/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH" python packaging/wheel/relocate.py + +pip install torchvision-extra-decoders diff --git a/references/segmentation/v2_extras.py b/references/segmentation/v2_extras.py index e1a8b53e02b..2d9eb3e661a 100644 --- a/references/segmentation/v2_extras.py +++ b/references/segmentation/v2_extras.py @@ -10,13 +10,13 @@ def __init__(self, size, fill=0): self.size = size self.fill = v2._utils._setup_fill_arg(fill) - def _get_params(self, sample): + def make_params(self, sample): _, height, width = v2._utils.query_chw(sample) padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)] needs_padding = any(padding) return dict(padding=padding, needs_padding=needs_padding) - def _transform(self, inpt, params): + def transform(self, inpt, params): if not params["needs_padding"]: return inpt diff --git a/setup.py b/setup.py index c2be57e9775..956682e7ead 100644 --- a/setup.py +++ b/setup.py @@ -19,8 +19,6 @@ USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1" USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" -USE_HEIC = os.getenv("TORCHVISION_USE_HEIC", "0") == "1" # TODO enable by default! -USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default! USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) # Note: the GPU video decoding stuff used to be called "video codec", which @@ -51,8 +49,6 @@ print(f"{USE_PNG = }") print(f"{USE_JPEG = }") print(f"{USE_WEBP = }") -print(f"{USE_HEIC = }") -print(f"{USE_AVIF = }") print(f"{USE_NVJPEG = }") print(f"{NVCC_FLAGS = }") print(f"{USE_CPU_VIDEO_DECODER = }") @@ -336,36 +332,6 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") - if USE_HEIC: - heic_found, heic_include_dir, heic_library_dir = find_library(header="libheif/heif.h") - if heic_found: - print("Building torchvision with HEIC support") - print(f"{heic_include_dir = }") - print(f"{heic_library_dir = }") - if heic_include_dir is not None and heic_library_dir is not None: - # if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add. - include_dirs.append(heic_include_dir) - library_dirs.append(heic_library_dir) - libraries.append("heif") - define_macros += [("HEIC_FOUND", 1)] - else: - warnings.warn("Building torchvision without HEIC support") - - if USE_AVIF: - avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h") - if avif_found: - print("Building torchvision with AVIF support") - print(f"{avif_include_dir = }") - print(f"{avif_library_dir = }") - if avif_include_dir is not None and avif_library_dir is not None: - # if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add. 
- include_dirs.append(avif_include_dir) - library_dirs.append(avif_library_dir) - libraries.append("avif") - define_macros += [("AVIF_FOUND", 1)] - else: - warnings.warn("Building torchvision without AVIF support") - if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() diff --git a/test/smoke_test.py b/test/smoke_test.py index 3a44ae3efe9..38f0054e6b6 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -6,7 +6,7 @@ import torch import torchvision -from torchvision.io import decode_image, decode_jpeg, decode_webp, read_file +from torchvision.io import decode_avif, decode_heic, decode_image, decode_jpeg, read_file from torchvision.models import resnet50, ResNet50_Weights @@ -24,13 +24,46 @@ def smoke_test_torchvision_read_decode() -> None: img_jpg = decode_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")) if img_jpg.shape != (3, 606, 517): raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}") + img_png = decode_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png")) if img_png.shape != (4, 471, 354): raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}") + img_webp = decode_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp")) if img_webp.shape != (3, 100, 100): raise RuntimeError(f"Unexpected shape of img_webp: {img_webp.shape}") + if sys.platform == "linux": + pass + # TODO: Fix/uncomment below (the TODO below is mostly accurate but we're + # still observing some failures on some CUDA jobs. Most are working.) + # if torch.cuda.is_available(): + # # TODO: For whatever reason this only passes on the runners that + # # support CUDA. + # # Strangely, on the CPU runners where this fails, the AVIF/HEIC + # # tests (ran with pytest) are passing. This is likely related to a + # # libcxx symbol thing, and the proper libstdc++.so get loaded only + # # with pytest? Ugh. 
+ # img_avif = decode_avif(read_file(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.avif"))) + # if img_avif.shape != (3, 100, 100): + # raise RuntimeError(f"Unexpected shape of img_avif: {img_avif.shape}") + + # img_heic = decode_heic( + # read_file(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic")) + # ) + # if img_heic.shape != (3, 100, 100): + # raise RuntimeError(f"Unexpected shape of img_heic: {img_heic.shape}") + else: + try: + decode_avif(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.avif")) + except RuntimeError as e: + assert "torchvision-extra-decoders" in str(e) + + try: + decode_heic(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic")) + except RuntimeError as e: + assert "torchvision-extra-decoders" in str(e) + def smoke_test_torchvision_decode_jpeg(device: str = "cpu"): img_jpg_data = read_file(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")) diff --git a/test/test_image.py b/test/test_image.py index 4146d54ac78..b8e96773267 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -4,7 +4,6 @@ import os import re import sys -from contextlib import nullcontext from pathlib import Path import numpy as np @@ -14,11 +13,10 @@ import torchvision.transforms.v2.functional as F from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence -from torchvision._internally_replaced_utils import IN_FBCODE from torchvision.io.image import ( - _decode_avif, - _decode_heic, + decode_avif, decode_gif, + decode_heic, decode_image, decode_jpeg, decode_png, @@ -43,22 +41,11 @@ TOOSMALL_PNG = os.path.join(IMAGE_ROOT, "toosmall_png") IS_WINDOWS = sys.platform in ("win32", "cygwin") IS_MACOS = sys.platform == "darwin" +IS_LINUX = sys.platform == "linux" PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) WEBP_TEST_IMAGES_DIR = os.environ.get("WEBP_TEST_IMAGES_DIR", "") # See https://github.com/pytorch/vision/pull/8724#issuecomment-2503964558 -ROCM_WEBP_MESSAGE = "ROCM not built with webp support." - -# Hacky way of figuring out whether we compiled with libavif/libheif (those are -# currenlty disabled by default) -try: - _decode_avif(torch.arange(10, dtype=torch.uint8)) -except Exception as e: - DECODE_AVIF_ENABLED = "torchvision not compiled with libavif support" not in str(e) - -try: - _decode_heic(torch.arange(10, dtype=torch.uint8)) -except Exception as e: - DECODE_HEIC_ENABLED = "torchvision not compiled with libheif support" not in str(e) +HEIC_AVIF_MESSAGE = "AVIF and HEIF only available on linux." 
def _get_safe_image_name(name): @@ -866,19 +853,23 @@ def test_decode_gif(tmpdir, name, scripted): torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0) -decode_fun_and_match = [ - (decode_png, "Content is not png"), - (decode_jpeg, "Not a JPEG file"), - (decode_gif, re.escape("DGifOpenFileName() failed - 103")), - (decode_webp, "WebPGetFeatures failed."), -] -if DECODE_AVIF_ENABLED: - decode_fun_and_match.append((_decode_avif, "BMFF parsing failed")) -if DECODE_HEIC_ENABLED: - decode_fun_and_match.append((_decode_heic, "Invalid input: No 'ftyp' box")) - - -@pytest.mark.parametrize("decode_fun, match", decode_fun_and_match) +@pytest.mark.parametrize( + "decode_fun, match", + [ + (decode_png, "Content is not png"), + (decode_jpeg, "Not a JPEG file"), + (decode_gif, re.escape("DGifOpenFileName() failed - 103")), + (decode_webp, "WebPGetFeatures failed."), + pytest.param( + decode_avif, "BMFF parsing failed", marks=pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE) + ), + pytest.param( + decode_heic, + "Invalid input: No 'ftyp' box", + marks=pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE), + ), + ], +) def test_decode_bad_encoded_data(decode_fun, match): encoded_data = torch.randint(0, 256, (100,), dtype=torch.uint8) with pytest.raises(RuntimeError, match="Input tensor must be 1-dimensional"): @@ -934,13 +925,10 @@ def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename) img += 123 # make sure image buffer wasn't freed by underlying decoding lib -@pytest.mark.skipif(not DECODE_AVIF_ENABLED, reason="AVIF support not enabled.") -@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image)) -@pytest.mark.parametrize("scripted", (False, True)) -def test_decode_avif(decode_fun, scripted): +@pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE) +@pytest.mark.parametrize("decode_fun", (decode_avif,)) +def test_decode_avif(decode_fun): encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".avif"))) - if scripted: - decode_fun = torch.jit.script(decode_fun) img = decode_fun(encoded_bytes) assert img.shape == (3, 100, 100) assert img[None].is_contiguous(memory_format=torch.channels_last) @@ -949,16 +937,8 @@ def test_decode_avif(decode_fun, scripted): # Note: decode_image fails because some of these files have a (valid) signature # we don't recognize. We should probably use libmagic.... -decode_funs = [] -if DECODE_AVIF_ENABLED: - decode_funs.append(_decode_avif) -if DECODE_HEIC_ENABLED: - decode_funs.append(_decode_heic) - - -@pytest.mark.skipif(not decode_funs, reason="Built without avif and heic support.") -@pytest.mark.parametrize("decode_fun", decode_funs) -@pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE) +@pytest.mark.parametrize("decode_fun", (decode_avif, decode_heic)) @pytest.mark.parametrize( "mode, pil_mode", ( @@ -970,7 +950,7 @@ def test_decode_avif(decode_fun, scripted): @pytest.mark.parametrize( "filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"), ids=lambda p: p.name ) -def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, filename): +def test_decode_avif_heic_against_pil(decode_fun, mode, pil_mode, filename): if "reversed_dimg_order" in str(filename): # Pillow properly decodes this one, but we don't (order of parts of the # image is wrong). 
This is due to a bug that was recently fixed in @@ -980,8 +960,6 @@ def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, file import pillow_avif # noqa encoded_bytes = read_file(filename) - if scripted: - decode_fun = torch.jit.script(decode_fun) try: img = decode_fun(encoded_bytes, mode=mode) except RuntimeError as e: @@ -994,6 +972,7 @@ def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, file "no 'ispe' property", "'iref' has double references", "Invalid image grid", + "decode_heif failed: Invalid input: No 'meta' box", ) ): pytest.skip(reason="Expected failure, that's OK") @@ -1010,7 +989,7 @@ def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, file try: from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode)) except RuntimeError as e: - if "Invalid image grid" in str(e): + if any(s in str(e) for s in ("Invalid image grid", "Failed to decode image: Not implemented")): pytest.skip(reason="PIL failure") else: raise e @@ -1021,7 +1000,7 @@ def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, file g = make_grid([img, from_pil]) F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png")) - is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic" + is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "decode_heic" if mode == ImageReadMode.RGB and not is_decode_heic: # We don't compare torchvision's AVIF against PIL for RGB because # results look pretty different on RGBA images (other images are fine). @@ -1035,13 +1014,10 @@ def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, file torch.testing.assert_close(img, from_pil, rtol=0, atol=3) -@pytest.mark.skipif(not DECODE_HEIC_ENABLED, reason="HEIC support not enabled yet.") -@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image)) -@pytest.mark.parametrize("scripted", (False, True)) -def test_decode_heic(decode_fun, scripted): +@pytest.mark.skipif(not IS_LINUX, reason=HEIC_AVIF_MESSAGE) +@pytest.mark.parametrize("decode_fun", (decode_heic,)) +def test_decode_heic(decode_fun): encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".heic"))) - if scripted: - decode_fun = torch.jit.script(decode_fun) img = decode_fun(encoded_bytes) assert img.shape == (3, 100, 100) assert img[None].is_contiguous(memory_format=torch.channels_last) @@ -1080,13 +1056,5 @@ def test_mode_str(): assert decode_image(path, mode="RGBA").shape[0] == 4 -def test_avif_heic_fbcode(): - cm = nullcontext() if IN_FBCODE else pytest.raises(ImportError, match="cannot import") - with cm: - from torchvision.io import decode_heic # noqa - with cm: - from torchvision.io import decode_avif # noqa - - if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 3f2e5015863..85ef98cf7b8 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -159,7 +159,7 @@ def test__copy_paste(self, label_type): class TestFixedSizeCrop: - def test__get_params(self, mocker): + def test_make_params(self, mocker): crop_size = (7, 7) batch_shape = (10,) canvas_size = (11, 5) @@ -170,7 +170,7 @@ def test__get_params(self, mocker): make_image(size=canvas_size, color_space="RGB"), make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_shape[0]), ] - params = transform._get_params(flat_inputs) + params = 
transform.make_params(flat_inputs) assert params["needs_crop"] assert params["height"] <= crop_size[0] @@ -191,7 +191,7 @@ def test__transform_culling(self, mocker): is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool) mocker.patch( - "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", + "torchvision.prototype.transforms._geometry.FixedSizeCrop.make_params", return_value=dict( needs_crop=True, top=0, @@ -229,7 +229,7 @@ def test__transform_bounding_boxes_clamping(self, mocker): canvas_size = (10, 10) mocker.patch( - "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", + "torchvision.prototype.transforms._geometry.FixedSizeCrop.make_params", return_value=dict( needs_crop=True, top=0, diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index e16c0677c9f..fb49525ecfe 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1355,7 +1355,7 @@ def test_transform_bounding_boxes_correctness(self, format, center, seed): transform = transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, center=center) torch.manual_seed(seed) - params = transform._get_params([bounding_boxes]) + params = transform.make_params([bounding_boxes]) torch.manual_seed(seed) actual = transform(bounding_boxes) @@ -1369,14 +1369,14 @@ def test_transform_bounding_boxes_correctness(self, format, center, seed): @pytest.mark.parametrize("scale", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["scale"]) @pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["shear"]) @pytest.mark.parametrize("seed", list(range(10))) - def test_transform_get_params_bounds(self, degrees, translate, scale, shear, seed): + def test_transformmake_params_bounds(self, degrees, translate, scale, shear, seed): image = make_image() height, width = F.get_size(image) transform = transforms.RandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear) torch.manual_seed(seed) - params = transform._get_params([image]) + params = transform.make_params([image]) if isinstance(degrees, (int, float)): assert -degrees <= params["angle"] <= degrees @@ -1783,7 +1783,7 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed transform = transforms.RandomRotation(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, expand=expand, center=center) torch.manual_seed(seed) - params = transform._get_params([bounding_boxes]) + params = transform.make_params([bounding_boxes]) torch.manual_seed(seed) actual = transform(bounding_boxes) @@ -1795,11 +1795,11 @@ def test_transform_bounding_boxes_correctness(self, format, expand, center, seed @pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"]) @pytest.mark.parametrize("seed", list(range(10))) - def test_transform_get_params_bounds(self, degrees, seed): + def test_transformmake_params_bounds(self, degrees, seed): transform = transforms.RandomRotation(degrees=degrees) torch.manual_seed(seed) - params = transform._get_params([]) + params = transform.make_params([]) if isinstance(degrees, (int, float)): assert -degrees <= params["angle"] <= degrees @@ -1843,7 +1843,7 @@ def test_functional_image_fast_path_correctness(self, size, angle, expand): class TestContainerTransforms: class BuiltinTransform(transforms.Transform): - def _transform(self, inpt, params): + def transform(self, inpt, params): return inpt class PackedInputTransform(nn.Module): @@ -2996,7 +2996,7 @@ def test_transform_bounding_boxes_correctness(self, output_size, format, dtype, with 
freeze_rng_state(): torch.manual_seed(seed) - params = transform._get_params([bounding_boxes]) + params = transform.make_params([bounding_boxes]) assert not params.pop("needs_pad") del params["padding"] assert params.pop("needs_crop") @@ -3129,9 +3129,9 @@ def test_transform_image_correctness(self, param, value, dtype, device, seed): with freeze_rng_state(): torch.manual_seed(seed) - # This emulates the random apply check that happens before _get_params is called + # This emulates the random apply check that happens before make_params is called torch.rand(1) - params = transform._get_params([image]) + params = transform.make_params([image]) torch.manual_seed(seed) actual = transform(image) @@ -3159,7 +3159,7 @@ def test_transform_errors(self): transform = transforms.RandomErasing(value=[1, 2, 3, 4]) with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"): - transform._get_params([make_image()]) + transform.make_params([make_image()]) class TestGaussianBlur: @@ -3244,9 +3244,9 @@ def test_assertions(self): transforms.GaussianBlur(3, sigma={}) @pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0], (10, 12.0), [10]]) - def test__get_params(self, sigma): + def test_make_params(self, sigma): transform = transforms.GaussianBlur(3, sigma=sigma) - params = transform._get_params([]) + params = transform.make_params([]) if isinstance(sigma, float): assert params["sigma"][0] == params["sigma"][1] == sigma @@ -5251,7 +5251,7 @@ def test_transform_params_correctness(self, side_range, make_input, device): input = make_input() height, width = F.get_size(input) - params = transform._get_params([input]) + params = transform.make_params([input]) assert "padding" in params padding = params["padding"] @@ -5305,13 +5305,13 @@ def test_transform(self, make_input, device): check_transform(transforms.ScaleJitter(self.TARGET_SIZE), make_input(self.INPUT_SIZE, device=device)) - def test__get_params(self): + def test_make_params(self): input_size = self.INPUT_SIZE target_size = self.TARGET_SIZE scale_range = (0.5, 1.5) transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range) - params = transform._get_params([make_image(input_size)]) + params = transform.make_params([make_image(input_size)]) assert "size" in params size = params["size"] @@ -5544,7 +5544,7 @@ def split_on_pure_tensor(to_split): return pure_tensors[0] if pure_tensors else None, pure_tensors[1:], others class CopyCloneTransform(transforms.Transform): - def _transform(self, inpt, params): + def transform(self, inpt, params): return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy() @staticmethod @@ -5580,7 +5580,7 @@ def was_applied(output, inpt): class TestRandomIoUCrop: @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]]) - def test__get_params(self, device, options): + def test_make_params(self, device, options): orig_h, orig_w = size = (24, 32) image = make_image(size) bboxes = tv_tensors.BoundingBoxes( @@ -5596,7 +5596,7 @@ def test__get_params(self, device, options): n_samples = 5 for _ in range(n_samples): - params = transform._get_params(sample) + params = transform.make_params(sample) if options == [2.0]: assert len(params) == 0 @@ -5622,8 +5622,8 @@ def test__transform_empty_params(self, mocker): bboxes = tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4)) label = torch.tensor([1]) sample = [image, bboxes, label] - # Let's mock transform._get_params to control the 
output: - transform._get_params = mocker.MagicMock(return_value={}) + # Let's mock transform.make_params to control the output: + transform.make_params = mocker.MagicMock(return_value={}) output = transform(sample) torch.testing.assert_close(output, sample) @@ -5648,7 +5648,7 @@ def test__transform(self, mocker): is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool) params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area) - transform._get_params = mocker.MagicMock(return_value=params) + transform.make_params = mocker.MagicMock(return_value=params) output = transform(sample) # check number of bboxes vs number of labels: @@ -5662,13 +5662,13 @@ def test__transform(self, mocker): class TestRandomShortestSize: @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)]) - def test__get_params(self, min_size, max_size): + def test_make_params(self, min_size, max_size): canvas_size = (3, 10) transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size, antialias=True) sample = make_image(canvas_size) - params = transform._get_params([sample]) + params = transform.make_params([sample]) assert "size" in params size = params["size"] @@ -5685,14 +5685,14 @@ def test__get_params(self, min_size, max_size): class TestRandomResize: - def test__get_params(self): + def test_make_params(self): min_size = 3 max_size = 6 transform = transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True) for _ in range(10): - params = transform._get_params([]) + params = transform.make_params([]) assert isinstance(params["size"], list) and len(params["size"]) == 1 size = params["size"][0] @@ -6148,12 +6148,12 @@ def test_transform_image_correctness(self, quality, color_space, seed): @pytest.mark.parametrize("quality", [5, (10, 20)]) @pytest.mark.parametrize("seed", list(range(10))) - def test_transform_get_params_bounds(self, quality, seed): + def test_transformmake_params_bounds(self, quality, seed): transform = transforms.JPEG(quality=quality) with freeze_rng_state(): torch.manual_seed(seed) - params = transform._get_params([]) + params = transform.make_params([]) if isinstance(quality, int): assert params["quality"] == quality diff --git a/torchvision/csrc/io/image/cpu/decode_avif.cpp b/torchvision/csrc/io/image/cpu/decode_avif.cpp deleted file mode 100644 index c3ecd581e42..00000000000 --- a/torchvision/csrc/io/image/cpu/decode_avif.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include "decode_avif.h" -#include "../common.h" - -#if AVIF_FOUND -#include "avif/avif.h" -#endif // AVIF_FOUND - -namespace vision { -namespace image { - -#if !AVIF_FOUND -torch::Tensor decode_avif( - const torch::Tensor& encoded_data, - ImageReadMode mode) { - TORCH_CHECK( - false, "decode_avif: torchvision not compiled with libavif support"); -} -#else - -// This normally comes from avif_cxx.h, but it's not always present when -// installing libavif. So we just copy/paste it here. -struct UniquePtrDeleter { - void operator()(avifDecoder* decoder) const { - avifDecoderDestroy(decoder); - } -}; -using DecoderPtr = std::unique_ptr; - -torch::Tensor decode_avif( - const torch::Tensor& encoded_data, - ImageReadMode mode) { - // This is based on - // https://github.com/AOMediaCodec/libavif/blob/main/examples/avif_example_decode_memory.c - // Refer there for more detail about what each function does, and which - // structure/data is available after which call. 
- - validate_encoded_data(encoded_data); - - DecoderPtr decoder(avifDecoderCreate()); - TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder."); - - auto result = AVIF_RESULT_UNKNOWN_ERROR; - result = avifDecoderSetIOMemory( - decoder.get(), encoded_data.data_ptr(), encoded_data.numel()); - TORCH_CHECK( - result == AVIF_RESULT_OK, - "avifDecoderSetIOMemory failed:", - avifResultToString(result)); - - result = avifDecoderParse(decoder.get()); - TORCH_CHECK( - result == AVIF_RESULT_OK, - "avifDecoderParse failed: ", - avifResultToString(result)); - TORCH_CHECK( - decoder->imageCount == 1, "Avif file contains more than one image"); - - result = avifDecoderNextImage(decoder.get()); - TORCH_CHECK( - result == AVIF_RESULT_OK, - "avifDecoderNextImage failed:", - avifResultToString(result)); - - avifRGBImage rgb; - memset(&rgb, 0, sizeof(rgb)); - avifRGBImageSetDefaults(&rgb, decoder->image); - - // images encoded as 10 or 12 bits will be decoded as uint16. The rest are - // decoded as uint8. - auto use_uint8 = (decoder->image->depth <= 8); - rgb.depth = use_uint8 ? 8 : 16; - - auto return_rgb = - should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( - mode, decoder->alphaPresent); - - auto num_channels = return_rgb ? 3 : 4; - rgb.format = return_rgb ? AVIF_RGB_FORMAT_RGB : AVIF_RGB_FORMAT_RGBA; - rgb.ignoreAlpha = return_rgb ? AVIF_TRUE : AVIF_FALSE; - - auto out = torch::empty( - {rgb.height, rgb.width, num_channels}, - use_uint8 ? torch::kUInt8 : at::kUInt16); - rgb.pixels = (uint8_t*)out.data_ptr(); - rgb.rowBytes = rgb.width * avifRGBImagePixelSize(&rgb); - - result = avifImageYUVToRGB(decoder->image, &rgb); - TORCH_CHECK( - result == AVIF_RESULT_OK, - "avifImageYUVToRGB failed: ", - avifResultToString(result)); - - return out.permute({2, 0, 1}); // return CHW, channels-last -} -#endif // AVIF_FOUND - -} // namespace image -} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_avif.h b/torchvision/csrc/io/image/cpu/decode_avif.h deleted file mode 100644 index 7feee1adfcb..00000000000 --- a/torchvision/csrc/io/image/cpu/decode_avif.h +++ /dev/null @@ -1,14 +0,0 @@ -#pragma once - -#include -#include "../common.h" - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor decode_avif( - const torch::Tensor& encoded_data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); - -} // namespace image -} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_heic.cpp b/torchvision/csrc/io/image/cpu/decode_heic.cpp deleted file mode 100644 index e245c25f9d7..00000000000 --- a/torchvision/csrc/io/image/cpu/decode_heic.cpp +++ /dev/null @@ -1,135 +0,0 @@ -#include "decode_heic.h" -#include "../common.h" - -#if HEIC_FOUND -#include "libheif/heif_cxx.h" -#endif // HEIC_FOUND - -namespace vision { -namespace image { - -#if !HEIC_FOUND -torch::Tensor decode_heic( - const torch::Tensor& encoded_data, - ImageReadMode mode) { - TORCH_CHECK( - false, "decode_heic: torchvision not compiled with libheif support"); -} -#else - -torch::Tensor decode_heic( - const torch::Tensor& encoded_data, - ImageReadMode mode) { - validate_encoded_data(encoded_data); - - auto return_rgb = true; - - int height = 0; - int width = 0; - int num_channels = 0; - int stride = 0; - uint8_t* decoded_data = nullptr; - heif::Image img; - int bit_depth = 0; - - try { - heif::Context ctx; - ctx.read_from_memory_without_copy( - encoded_data.data_ptr(), encoded_data.numel()); - - // TODO properly error on (or support) image sequences. 
Right now, I think - // this function will always return the first image in a sequence, which is - // inconsistent with decode_gif (which returns a batch) and with decode_avif - // (which errors loudly). - // Why? I'm struggling to make sense of - // ctx.get_number_of_top_level_images(). It disagrees with libavif's - // imageCount. For example on some of the libavif test images: - // - // - colors-animated-12bpc-keyframes-0-2-3.avif - // avif num images = 5 - // heif num images = 1 // Why is this 1 when clearly this is supposed to - // be a sequence? - // - sofa_grid1x5_420.avif - // avif num images = 1 - // heif num images = 6 // If we were to error here we won't be able to - // decode this image which is otherwise properly - // decoded by libavif. - // I can't find a libheif function that does what we need here, or at least - // that agrees with libavif. - - // TORCH_CHECK( - // ctx.get_number_of_top_level_images() == 1, - // "heic file contains more than one image"); - - heif::ImageHandle handle = ctx.get_primary_image_handle(); - bit_depth = handle.get_luma_bits_per_pixel(); - - return_rgb = - should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video( - mode, handle.has_alpha_channel()); - - height = handle.get_height(); - width = handle.get_width(); - - num_channels = return_rgb ? 3 : 4; - heif_chroma chroma; - if (bit_depth == 8) { - chroma = return_rgb ? heif_chroma_interleaved_RGB - : heif_chroma_interleaved_RGBA; - } else { - // TODO: This, along with our 10bits -> 16bits range mapping down below, - // may not work on BE platforms - chroma = return_rgb ? heif_chroma_interleaved_RRGGBB_LE - : heif_chroma_interleaved_RRGGBBAA_LE; - } - - img = handle.decode_image(heif_colorspace_RGB, chroma); - - decoded_data = img.get_plane(heif_channel_interleaved, &stride); - } catch (const heif::Error& err) { - // We need this try/catch block and call TORCH_CHECK, because libheif may - // otherwise throw heif::Error that would just be reported as "An unknown - // exception occurred" when we move back to Python. - TORCH_CHECK(false, "decode_heif failed: ", err.get_message()); - } - TORCH_CHECK(decoded_data != nullptr, "Something went wrong during decoding."); - - auto dtype = (bit_depth == 8) ? torch::kUInt8 : at::kUInt16; - auto out = torch::empty({height, width, num_channels}, dtype); - uint8_t* out_ptr = (uint8_t*)out.data_ptr(); - - // decoded_data is *almost* the raw decoded data, but not quite: for some - // images, there may be some padding at the end of each row, i.e. when stride - // != row_size_in_bytes. So we can't copy decoded_data into the tensor's - // memory directly, we have to copy row by row. Oh, and if you think you can - // take a shortcut when stride == row_size_in_bytes and just do: - // out = torch::from_blob(decoded_data, ...) - // you can't, because decoded_data is owned by the heif::Image object and it - // gets freed when it gets out of scope! - auto row_size_in_bytes = width * num_channels * ((bit_depth == 8) ? 1 : 2); - for (auto h = 0; h < height; h++) { - memcpy( - out_ptr + h * row_size_in_bytes, - decoded_data + h * stride, - row_size_in_bytes); - } - if (bit_depth > 8) { - // Say bit depth is 10. decodec_data and out_ptr contain 10bits values - // over 2 bytes, stored into uint16_t. In torchvision a uint16 value is - // expected to be in [0, 2**16), so we have to map the 10bits value to that - // range. Note that other libraries like libavif do that mapping - // automatically. 
- // TODO: It's possible to avoid the memcpy call above in this case, and do - // the copy at the same time as the conversation. Whether it's worth it - // should be benchmarked. - auto out_ptr_16 = (uint16_t*)out_ptr; - for (auto p = 0; p < height * width * num_channels; p++) { - out_ptr_16[p] <<= (16 - bit_depth); - } - } - return out.permute({2, 0, 1}); -} -#endif // HEIC_FOUND - -} // namespace image -} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_heic.h b/torchvision/csrc/io/image/cpu/decode_heic.h deleted file mode 100644 index 10b414f554d..00000000000 --- a/torchvision/csrc/io/image/cpu/decode_heic.h +++ /dev/null @@ -1,14 +0,0 @@ -#pragma once - -#include -#include "../common.h" - -namespace vision { -namespace image { - -C10_EXPORT torch::Tensor decode_heic( - const torch::Tensor& data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); - -} // namespace image -} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp b/torchvision/csrc/io/image/cpu/decode_image.cpp index 9c1a7ff3ef4..43a688604f6 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.cpp +++ b/torchvision/csrc/io/image/cpu/decode_image.cpp @@ -1,8 +1,6 @@ #include "decode_image.h" -#include "decode_avif.h" #include "decode_gif.h" -#include "decode_heic.h" #include "decode_jpeg.h" #include "decode_png.h" #include "decode_webp.h" @@ -24,7 +22,7 @@ torch::Tensor decode_image( "Expected a non empty 1-dimensional tensor"); auto err_msg = - "Unsupported image file. Only jpeg, png and gif are currently supported."; + "Unsupported image file. Only jpeg, png, webp and gif are currently supported. For avif and heic format, please rely on `decode_avif` and `decode_heic` directly."; auto datap = data.data_ptr(); @@ -50,29 +48,6 @@ torch::Tensor decode_image( return decode_gif(data); } - // We assume the signature of an avif file is - // 0000 0020 6674 7970 6176 6966 - // xxxx xxxx f t y p a v i f - // We only check for the "ftyp avif" part. - // This is probably not perfect, but hopefully this should cover most files. - const uint8_t avif_signature[8] = { - 0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66}; // == "ftypavif" - TORCH_CHECK(data.numel() >= 12, err_msg); - if ((memcmp(avif_signature, datap + 4, 8) == 0)) { - return decode_avif(data, mode); - } - - // Similarly for heic we assume the signature is "ftypeheic" but some files - // may come as "ftypmif1" where the "heic" part is defined later in the file. - // We can't be re-inventing libmagic here. We might need to start relying on - // it though... 
- const uint8_t heic_signature[8] = { - 0x66, 0x74, 0x79, 0x70, 0x68, 0x65, 0x69, 0x63}; // == "ftypheic" - TORCH_CHECK(data.numel() >= 12, err_msg); - if ((memcmp(heic_signature, datap + 4, 8) == 0)) { - return decode_heic(data, mode); - } - const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF" const uint8_t webp_signature_end[7] = { 0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8" diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index f0ce91144a6..2ac29e6b1ee 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -23,10 +23,6 @@ static auto registry = &decode_jpeg) .op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor", &decode_webp) - .op("image::decode_heic(Tensor encoded_data, int mode) -> Tensor", - &decode_heic) - .op("image::decode_avif(Tensor encoded_data, int mode) -> Tensor", - &decode_avif) .op("image::encode_jpeg", &encode_jpeg) .op("image::read_file", &read_file) .op("image::write_file", &write_file) diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index 23493f3c030..3f47fdec65c 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -1,8 +1,6 @@ #pragma once -#include "cpu/decode_avif.h" #include "cpu/decode_gif.h" -#include "cpu/decode_heic.h" #include "cpu/decode_image.h" #include "cpu/decode_jpeg.h" #include "cpu/decode_png.h" diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 0dcbd7e9cea..03bd5d23cb2 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -1,9 +1,3 @@ -from typing import Any, Dict, Iterator - -import torch - -from ..utils import _log_api_usage_once - try: from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER except ModuleNotFoundError: @@ -22,7 +16,9 @@ VideoMetaData, ) from .image import ( + decode_avif, decode_gif, + decode_heic, decode_image, decode_jpeg, decode_png, @@ -61,6 +57,7 @@ "decode_image", "decode_jpeg", "decode_png", + "decode_avif", "decode_heic", "decode_webp", "decode_gif", @@ -74,10 +71,3 @@ "Video", "VideoReader", ] - -from .._internally_replaced_utils import IN_FBCODE - -if IN_FBCODE: - from .image import _decode_avif as decode_avif, _decode_heic as decode_heic - - __all__ += ["decode_avif", "decode_heic"] diff --git a/torchvision/io/image.py b/torchvision/io/image.py index cb48d0e6816..023898f33c6 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -296,6 +296,12 @@ def decode_image( after this function to convert the decoded image into a uint8 or float tensor. + .. note:: + + ``decode_image()`` doesn't work yet on AVIF or HEIC images. For these + formats, directly call :func:`~torchvision.io.decode_avif` or + :func:`~torchvision.io.decode_heic`. + Args: input (Tensor or str or ``pathlib.Path``): The image to decode. If a tensor is passed, it must be one dimensional uint8 tensor containing @@ -377,12 +383,73 @@ def decode_webp( return torch.ops.image.decode_webp(input, mode.value) -def _decode_avif( - input: torch.Tensor, - mode: ImageReadMode = ImageReadMode.UNCHANGED, -) -> torch.Tensor: - """ - Decode an AVIF image into a 3 dimensional RGB[A] Tensor. +# TODO_AVIF_HEIC: Better support for torchscript. Scripting decode_avif of +# decode_heic currently fails, mainly because of the logic +# _load_extra_decoders_once() (using global variables, try/except statements, +# etc.). 
+# The ops (torch.ops.extra_decoders_ns.decode_*) are otherwise torchscript-able, +# and users who need torchscript can always just wrap those. + +# TODO_AVIF_HEIC: decode_image() should work for those. The key technical issue +# we have here is that the format detection logic of decode_image() is +# implemented in torchvision, and torchvision has zero knowledge of +# torchvision-extra-decoders, so we cannot call the AVIF/HEIC C++ decoders +# (those in torchvision-extra-decoders) from there. +# A trivial check that could be done within torchvision would be to check the +# file extension, if a path was passed. We could also just implement the +# AVIF/HEIC detection logic in Python as a fallback, if the file detection +# didn't find any format. In any case: properly determining whether a file is +# HEIC is far from trivial, and relying on libmagic would probably be best + + +_EXTRA_DECODERS_ALREADY_LOADED = False + + +def _load_extra_decoders_once(): + global _EXTRA_DECODERS_ALREADY_LOADED + if _EXTRA_DECODERS_ALREADY_LOADED: + return + + try: + import torchvision_extra_decoders + + # torchvision-extra-decoders only supports linux for now. BUT, users on + # e.g. MacOS can still install it: they will get the pure-python + # 0.0.0.dev version: + # https://pypi.org/project/torchvision-extra-decoders/0.0.0.dev0, which + # is a dummy version that was created to reserve the namespace on PyPI. + # We have to check that expose_extra_decoders() exists for those users, + # so we can properly error on non-Linux archs. + assert hasattr(torchvision_extra_decoders, "expose_extra_decoders") + except (AssertionError, ImportError) as e: + raise RuntimeError( + "In order to enable the AVIF and HEIC decoding capabilities of " + "torchvision, you need to `pip install torchvision-extra-decoders`. " + "Just install the package, you don't need to update your code. " + "This is only supported on Linux, and this feature is still in BETA stage. " + "Please let us know of any issue: https://github.com/pytorch/vision/issues/new/choose. " + "Note that `torchvision-extra-decoders` is released under the LGPL license. " + ) from e + + # This will expose torch.ops.extra_decoders_ns.decode_avif and torch.ops.extra_decoders_ns.decode_heic + torchvision_extra_decoders.expose_extra_decoders() + + _EXTRA_DECODERS_ALREADY_LOADED = True + + +def decode_avif(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: + """Decode an AVIF image into a 3 dimensional RGB[A] Tensor. + + .. warning:: + In order to enable the AVIF decoding capabilities of torchvision, you + first need to run ``pip install torchvision-extra-decoders``. Just + install the package, you don't need to update your code. This is only + supported on Linux, and this feature is still in BETA stage. Please let + us know of any issue: + https://github.com/pytorch/vision/issues/new/choose. Note that + `torchvision-extra-decoders + `_ is + released under the LGPL license. The values of the output tensor are in uint8 in [0, 255] for most images. 
If the image has a bit-depth of more than 8, then the output tensor is uint16 @@ -401,16 +468,25 @@ def _decode_avif( Returns: Decoded image (Tensor[image_channels, image_height, image_width]) """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(_decode_avif) - if isinstance(mode, str): - mode = ImageReadMode[mode.upper()] - return torch.ops.image.decode_avif(input, mode.value) - - -def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: - """ - Decode an HEIC image into a 3 dimensional RGB[A] Tensor. + _load_extra_decoders_once() + if input.dtype != torch.uint8: + raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}") + return torch.ops.extra_decoders_ns.decode_avif(input, mode.value) + + +def decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: + """Decode an HEIC image into a 3 dimensional RGB[A] Tensor. + + .. warning:: + In order to enable the AVIF decoding capabilities of torchvision, you + first need to run ``pip install torchvision-extra-decoders``. Just + install the package, you don't need to update your code. This is only + supported on Linux, and this feature is still in BETA stage. Please let + us know of any issue: + https://github.com/pytorch/vision/issues/new/choose. Note that + `torchvision-extra-decoders + `_ is + released under the LGPL license. The values of the output tensor are in uint8 in [0, 255] for most images. If the image has a bit-depth of more than 8, then the output tensor is uint16 @@ -429,8 +505,7 @@ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN Returns: Decoded image (Tensor[image_channels, image_height, image_width]) """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(_decode_heic) - if isinstance(mode, str): - mode = ImageReadMode[mode.upper()] - return torch.ops.image.decode_heic(input, mode.value) + _load_extra_decoders_once() + if input.dtype != torch.uint8: + raise RuntimeError(f"Input tensor must have uint8 data type, got {input.dtype}") + return torch.ops.extra_decoders_ns.decode_heic(input, mode.value) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 9f768ed555d..2e3dbed65a2 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -26,6 +26,10 @@ install PyAV on your system. 
""" ) + try: + FFmpegError = av.FFmpegError # from av 14 https://github.com/PyAV-Org/PyAV/blob/main/CHANGELOG.rst + except AttributeError: + FFmpegError = av.AVError except ImportError: av = ImportError( """\ @@ -155,7 +159,13 @@ def write_video( for img in video_array: frame = av.VideoFrame.from_ndarray(img, format="rgb24") - frame.pict_type = "NONE" + try: + frame.pict_type = "NONE" + except TypeError: + from av.video.frame import PictureType # noqa + + frame.pict_type = PictureType.NONE + for packet in stream.encode(frame): container.mux(packet) @@ -215,7 +225,7 @@ def _read_from_stream( try: # TODO check if stream needs to always be the video stream here or not container.seek(seek_offset, any_frame=False, backward=True, stream=stream) - except av.AVError: + except FFmpegError: # TODO add some warnings in this case # print("Corrupted file?", container.name) return [] @@ -228,7 +238,7 @@ def _read_from_stream( buffer_count += 1 continue break - except av.AVError: + except FFmpegError: # TODO add a warning pass # ensure that the results are sorted wrt the pts @@ -350,7 +360,7 @@ def read_video( ) info["audio_fps"] = container.streams.audio[0].rate - except av.AVError: + except FFmpegError: # TODO raise a warning? pass @@ -441,10 +451,10 @@ def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[in video_time_base = video_stream.time_base try: pts = _decode_video_timestamps(container) - except av.AVError: + except FFmpegError: warnings.warn(f"Failed decoding frames for file {filename}") video_fps = float(video_stream.average_rate) - except av.AVError as e: + except FFmpegError as e: msg = f"Failed to open container for {filename}; Caught error: {e}" warnings.warn(msg, RuntimeWarning) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index e1f9c630939..91918026c97 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -443,7 +443,13 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: Compute the bounding boxes around the provided masks. Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with - ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + ``0 <= x1 <= x2`` and ``0 <= y1 <= y2``. + + .. warning:: + + In most cases the output will guarantee ``x1 < x2`` and ``y1 < y2``. But + if the input is degenerate, e.g. if a mask is a single row or a single + column, then the output may have x1 = x2 or y1 = y2. Args: masks (Tensor[N, H, W]): masks to transform where N is the number of masks diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index b04e1fe5a2a..e7c501aabe0 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -35,7 +35,7 @@ def __init__( self.padding_mode = padding_mode - def _check_inputs(self, flat_inputs: List[Any]) -> None: + def check_inputs(self, flat_inputs: List[Any]) -> None: if not has_any( flat_inputs, PIL.Image.Image, @@ -53,7 +53,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel." 
) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_size(flat_inputs) new_height = min(height, self.crop_height) new_width = min(width, self.crop_width) @@ -107,7 +107,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: needs_pad=needs_pad, ) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_crop"]: inpt = self._call_kernel( F.crop, diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index bab2c70812e..6ea6256b171 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -39,7 +39,7 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]] ) self.dims = dims - def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: + def transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: return inpt.as_subclass(torch.Tensor) @@ -61,7 +61,7 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i ) self.dims = dims - def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: + def transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: dims = self.dims[type(inpt)] if dims is None: return inpt.as_subclass(torch.Tensor) diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index 3532abb3759..025cd13a766 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -15,7 +15,7 @@ def __init__(self, num_categories: int = -1): super().__init__() self.num_categories = num_categories - def _transform(self, inpt: proto_tv_tensors.Label, params: Dict[str, Any]) -> proto_tv_tensors.OneHotLabel: + def transform(self, inpt: proto_tv_tensors.Label, params: Dict[str, Any]) -> proto_tv_tensors.OneHotLabel: num_categories = self.num_categories if num_categories == -1 and inpt.categories is not None: num_categories = len(inpt.categories) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index b1dd5083408..93d4ba45d65 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -96,7 +96,7 @@ def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: An ) return super()._call_kernel(functional, inpt, *args, **kwargs) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: img_c, img_h, img_w = query_chw(flat_inputs) if self.value is not None and not (len(self.value) in (1, img_c)): @@ -134,7 +134,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(i=i, j=j, h=h, w=w, v=v) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["v"] is not None: inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace) @@ -181,7 +181,7 @@ def forward(self, *inputs): params = { "labels": labels, "batch_size": labels.shape[0], - **self._get_params( + **self.make_params( [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform] ), } @@ -190,7 +190,7 @@ def forward(self, *inputs): # after an image or video. 
However, we need to handle them in _transform, so we make sure to set them to True needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True flat_outputs = [ - self._transform(inpt, params) if needs_transform else inpt + self.transform(inpt, params) if needs_transform else inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) ] @@ -243,10 +243,10 @@ class MixUp(_BaseMixUpCutMix): It can also be a callable that takes the same input as the transform, and returns the labels. """ - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type] - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: lam = params["lam"] if inpt is params["labels"]: @@ -292,7 +292,7 @@ class CutMix(_BaseMixUpCutMix): It can also be a callable that takes the same input as the transform, and returns the labels. """ - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: lam = float(self._dist.sample(())) # type: ignore[arg-type] H, W = query_size(flat_inputs) @@ -314,7 +314,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(box=box, lam_adjusted=lam_adjusted) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if inpt is params["labels"]: return self._mixup_label(inpt, lam=params["lam_adjusted"]) elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt): @@ -361,9 +361,9 @@ def __init__(self, quality: Union[int, Sequence[int]]): self.quality = quality - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item() return dict(quality=quality) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.jpeg, inpt, quality=params["quality"]) diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 49b4a8d8b10..7a471e7c1f6 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -25,7 +25,7 @@ def __init__(self, num_output_channels: int = 1): super().__init__() self.num_output_channels = num_output_channels - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels) @@ -46,11 +46,11 @@ class RandomGrayscale(_RandomApplyTransform): def __init__(self, p: float = 0.1) -> None: super().__init__(p=p) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: num_input_channels, *_ = query_chw(flat_inputs) return dict(num_input_channels=num_input_channels) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"]) @@ -64,7 +64,7 @@ class RGB(Transform): def __init__(self): super().__init__() - def _transform(self, inpt: Any, 
params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.grayscale_to_rgb, inpt) @@ -142,7 +142,7 @@ def _check_input( def _generate_value(left: float, right: float) -> float: return torch.empty(1).uniform_(left, right).item() - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: fn_idx = torch.randperm(4) b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1]) @@ -152,7 +152,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: output = inpt brightness_factor = params["brightness_factor"] contrast_factor = params["contrast_factor"] @@ -173,11 +173,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class RandomChannelPermutation(Transform): """Randomly permute the channels of an image or video""" - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) return dict(permutation=torch.randperm(num_channels)) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.permute_channels, inpt, params["permutation"]) @@ -220,7 +220,7 @@ def __init__( self.saturation = saturation self.p = p - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: num_channels, *_ = query_chw(flat_inputs) params: Dict[str, Any] = { key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None @@ -235,7 +235,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None return params - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["brightness_factor"] is not None: inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"]) if params["contrast_factor"] is not None and params["contrast_before"]: @@ -264,7 +264,7 @@ class RandomEqualize(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomEqualize - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.equalize, inpt) @@ -281,7 +281,7 @@ class RandomInvert(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomInvert - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.invert, inpt) @@ -304,7 +304,7 @@ def __init__(self, bits: int, p: float = 0.5) -> None: super().__init__(p=p) self.bits = bits - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.posterize, inpt, bits=self.bits) @@ -332,7 +332,7 @@ def __init__(self, threshold: float, p: float = 0.5) -> None: super().__init__(p=p) self.threshold = threshold - def _transform(self, inpt: Any, params: 
Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.solarize, inpt, threshold=self.threshold) @@ -349,7 +349,7 @@ class RandomAutocontrast(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomAutocontrast - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.autocontrast, inpt) @@ -372,5 +372,5 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: super().__init__(p=p) self.sharpness_factor = sharpness_factor - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor) diff --git a/torchvision/transforms/v2/_deprecated.py b/torchvision/transforms/v2/_deprecated.py index a664cb3fbbd..51a4f076e49 100644 --- a/torchvision/transforms/v2/_deprecated.py +++ b/torchvision/transforms/v2/_deprecated.py @@ -46,5 +46,5 @@ def __init__(self) -> None: ) super().__init__() - def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor: + def transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor: return _F.to_tensor(inpt) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 5d6b1841d7f..c2461418a42 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -44,7 +44,7 @@ class RandomHorizontalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomHorizontalFlip - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.horizontal_flip, inpt) @@ -62,7 +62,7 @@ class RandomVerticalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomVerticalFlip - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.vertical_flip, inpt) @@ -156,7 +156,7 @@ def __init__( self.max_size = max_size self.antialias = antialias - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel( F.resize, inpt, @@ -189,7 +189,7 @@ def __init__(self, size: Union[int, Sequence[int]]): super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.center_crop, inpt, output_size=self.size) @@ -268,7 +268,7 @@ def __init__( self._log_ratio = torch.log(torch.tensor(self.ratio)) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_size(flat_inputs) area = height * width @@ -306,7 +306,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(top=i, left=j, height=h, width=w) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel( F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias ) @@ -363,10 +363,10 @@ def 
_call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: An ) return super()._call_kernel(functional, inpt, *args, **kwargs) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.five_crop, inpt, self.size) - def _check_inputs(self, flat_inputs: List[Any]) -> None: + def check_inputs(self, flat_inputs: List[Any]) -> None: if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask): raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") @@ -408,11 +408,11 @@ def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: An ) return super()._call_kernel(functional, inpt, *args, **kwargs) - def _check_inputs(self, flat_inputs: List[Any]) -> None: + def check_inputs(self, flat_inputs: List[Any]) -> None: if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask): raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip) @@ -483,7 +483,7 @@ def __init__( self._fill = _setup_fill_arg(fill) self.padding_mode = padding_mode - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] @@ -535,7 +535,7 @@ def __init__( if side_range[0] < 1.0 or side_range[0] > side_range[1]: raise ValueError(f"Invalid side range provided {side_range}.") - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) @@ -551,7 +551,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(padding=padding) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) return self._call_kernel(F.pad, inpt, **params, fill=fill) @@ -618,11 +618,11 @@ def __init__( self.center = center - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() return dict(angle=angle) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) return self._call_kernel( F.rotate, @@ -716,7 +716,7 @@ def __init__( self.center = center - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_size(flat_inputs) angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item() @@ -743,7 +743,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: shear = (shear_x, shear_y) return dict(angle=angle, translate=translate, scale=scale, shear=shear) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = 
_get_fill(self._fill, type(inpt)) return self._call_kernel( F.affine, @@ -839,7 +839,7 @@ def __init__( self._fill = _setup_fill_arg(fill) self.padding_mode = padding_mode - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: padded_height, padded_width = query_size(flat_inputs) if self.padding is not None: @@ -897,7 +897,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: padding=padding, ) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_pad"]: fill = _get_fill(self._fill, type(inpt)) inpt = self._call_kernel(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) @@ -952,7 +952,7 @@ def __init__( self.fill = fill self._fill = _setup_fill_arg(fill) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: height, width = query_size(flat_inputs) distortion_scale = self.distortion_scale @@ -982,7 +982,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints) return dict(coefficients=perspective_coeffs) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) return self._call_kernel( F.perspective, @@ -1051,7 +1051,7 @@ def __init__( self.fill = fill self._fill = _setup_fill_arg(fill) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: size = list(query_size(flat_inputs)) dx = torch.rand([1, 1] + size) * 2 - 1 @@ -1074,7 +1074,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 return dict(displacement=displacement) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) return self._call_kernel( F.elastic, @@ -1132,7 +1132,7 @@ def __init__( self.options = sampler_options self.trials = trials - def _check_inputs(self, flat_inputs: List[Any]) -> None: + def check_inputs(self, flat_inputs: List[Any]) -> None: if not ( has_all(flat_inputs, tv_tensors.BoundingBoxes) and has_any(flat_inputs, PIL.Image.Image, tv_tensors.Image, is_pure_tensor) @@ -1142,7 +1142,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: "and bounding boxes. Sample can also contain masks." 
) - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_h, orig_w = query_size(flat_inputs) bboxes = get_bounding_boxes(flat_inputs) @@ -1194,7 +1194,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if len(params) < 1: return inpt @@ -1262,7 +1262,7 @@ def __init__( self.interpolation = interpolation self.antialias = antialias - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_height, orig_width = query_size(flat_inputs) scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) @@ -1272,7 +1272,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(size=(new_height, new_width)) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel( F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias ) @@ -1327,7 +1327,7 @@ def __init__( self.interpolation = interpolation self.antialias = antialias - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_height, orig_width = query_size(flat_inputs) min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] @@ -1340,7 +1340,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(size=(new_height, new_width)) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel( F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias ) @@ -1406,11 +1406,11 @@ def __init__( self.interpolation = interpolation self.antialias = antialias - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: size = int(torch.randint(self.min_size, self.max_size, ())) return dict(size=[size]) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel( F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias ) diff --git a/torchvision/transforms/v2/_meta.py b/torchvision/transforms/v2/_meta.py index 01a356f46f5..1890b43115a 100644 --- a/torchvision/transforms/v2/_meta.py +++ b/torchvision/transforms/v2/_meta.py @@ -19,7 +19,7 @@ def __init__(self, format: Union[str, tv_tensors.BoundingBoxFormat]) -> None: super().__init__() self.format = format - def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes: + def transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes: return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value, arg-type] @@ -32,5 +32,5 @@ class ClampBoundingBoxes(Transform): _transformed_types = (tv_tensors.BoundingBoxes,) - def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes: + def transform(self, inpt: 
tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes: return F.clamp_bounding_boxes(inpt) # type: ignore[return-value] diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 93198f0009d..d38a6ad8767 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -14,7 +14,7 @@ # TODO: do we want/need to expose this? class Identity(Transform): - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return inpt @@ -34,7 +34,7 @@ def __init__(self, lambd: Callable[[Any], Any], *types: Type): self.lambd = lambd self.types = types or self._transformed_types - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(inpt, self.types): return self.lambd(inpt) else: @@ -99,11 +99,11 @@ def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tenso self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector - def _check_inputs(self, sample: Any) -> Any: + def check_inputs(self, sample: Any) -> Any: if has_any(sample, PIL.Image.Image): raise TypeError(f"{type(self).__name__}() does not support PIL images.") - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: shape = inpt.shape n = shape[-3] * shape[-2] * shape[-1] if n != self.transformation_matrix.shape[0]: @@ -157,11 +157,11 @@ def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = self.std = list(std) self.inplace = inplace - def _check_inputs(self, sample: Any) -> Any: + def check_inputs(self, sample: Any) -> Any: if has_any(sample, PIL.Image.Image): raise TypeError(f"{type(self).__name__}() does not support PIL images.") - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace) @@ -197,11 +197,11 @@ def __init__( if not 0.0 < self.sigma[0] <= self.sigma[1]: raise ValueError(f"sigma values should be positive and of the form (min, max). 
Got {self.sigma}") - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() return dict(sigma=[sigma, sigma]) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params) @@ -228,7 +228,7 @@ def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None: self.sigma = sigma self.clip = clip - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, sigma=self.sigma, clip=self.clip) @@ -272,7 +272,7 @@ def __init__( self.dtype = dtype self.scale = scale - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(self.dtype, torch.dtype): # For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype # is a simple torch.dtype @@ -335,7 +335,7 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None: super().__init__() self.dtype = dtype - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.to_dtype, inpt, dtype=self.dtype, scale=True) @@ -432,11 +432,11 @@ def forward(self, *inputs: Any) -> Any: ) params = dict(valid=valid, labels=labels) - flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs] + flat_outputs = [self.transform(inpt, params) for inpt in flat_inputs] return tree_unflatten(flat_outputs, spec) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: is_label = params["labels"] is not None and any(inpt is label for label in params["labels"]) is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)) diff --git a/torchvision/transforms/v2/_temporal.py b/torchvision/transforms/v2/_temporal.py index c59d5078d46..687b50188a8 100644 --- a/torchvision/transforms/v2/_temporal.py +++ b/torchvision/transforms/v2/_temporal.py @@ -22,5 +22,5 @@ def __init__(self, num_samples: int): super().__init__() self.num_samples = num_samples - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.uniform_temporal_subsample, inpt, self.num_samples) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index b7eced5a287..5f274589709 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -15,6 +15,11 @@ class Transform(nn.Module): + """Base class to implement your own v2 transforms. + + See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py` for + more details. + """ # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. 
@@ -24,31 +29,44 @@ def __init__(self) -> None: super().__init__() _log_api_usage_once(self) - def _check_inputs(self, flat_inputs: List[Any]) -> None: + def check_inputs(self, flat_inputs: List[Any]) -> None: pass - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + # When v2 was introduced, this method was private and called + # `_get_params()`. Now it's publicly exposed as `make_params()`. It cannot + # be exposed as `get_params()` because there is already a `get_params()` + # method for v2 transforms: it's the v1 `get_params()` that we have to + # keep in order to guarantee 100% BC with v1. (It's defined in + # __init_subclass__ below). + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + """Method to override for custom transforms. + + See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py`""" return dict() def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: kernel = _get_kernel(functional, type(inpt), allow_passthrough=True) return kernel(inpt, *args, **kwargs) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + """Method to override for custom transforms. + + See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py`""" raise NotImplementedError def forward(self, *inputs: Any) -> Any: + """Do not override this! Use ``transform()`` instead.""" flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) - self._check_inputs(flat_inputs) + self.check_inputs(flat_inputs) needs_transform_list = self._needs_transform_list(flat_inputs) - params = self._get_params( + params = self.make_params( [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform] ) flat_outputs = [ - self._transform(inpt, params) if needs_transform else inpt + self.transform(inpt, params) if needs_transform else inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) ] @@ -153,23 +171,23 @@ def __init__(self, p: float = 0.5) -> None: def forward(self, *inputs: Any) -> Any: # We need to almost duplicate `Transform.forward()` here since we always want to check the inputs, but return # early afterwards in case the random check triggers. The same result could be achieved by calling - # `super().forward()` after the random check, but that would call `self._check_inputs` twice. + # `super().forward()` after the random check, but that would call `self.check_inputs` twice.
inputs = inputs if len(inputs) > 1 else inputs[0] flat_inputs, spec = tree_flatten(inputs) - self._check_inputs(flat_inputs) + self.check_inputs(flat_inputs) if torch.rand(1) >= self.p: return inputs needs_transform_list = self._needs_transform_list(flat_inputs) - params = self._get_params( + params = self.make_params( [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform] ) flat_outputs = [ - self._transform(inpt, params) if needs_transform else inpt + self.transform(inpt, params) if needs_transform else inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) ] diff --git a/torchvision/transforms/v2/_type_conversion.py b/torchvision/transforms/v2/_type_conversion.py index 7c7439b1d02..bf9f7185239 100644 --- a/torchvision/transforms/v2/_type_conversion.py +++ b/torchvision/transforms/v2/_type_conversion.py @@ -20,7 +20,7 @@ class PILToTensor(Transform): _transformed_types = (PIL.Image.Image,) - def _transform(self, inpt: PIL.Image.Image, params: Dict[str, Any]) -> torch.Tensor: + def transform(self, inpt: PIL.Image.Image, params: Dict[str, Any]) -> torch.Tensor: return F.pil_to_tensor(inpt) @@ -33,7 +33,7 @@ class ToImage(Transform): _transformed_types = (is_pure_tensor, PIL.Image.Image, np.ndarray) - def _transform( + def transform( self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] ) -> tv_tensors.Image: return F.to_image(inpt) @@ -66,7 +66,7 @@ def __init__(self, mode: Optional[str] = None) -> None: super().__init__() self.mode = mode - def _transform( + def transform( self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] ) -> PIL.Image.Image: return F.to_pil_image(inpt, mode=self.mode) @@ -80,5 +80,5 @@ class ToPureTensor(Transform): _transformed_types = (tv_tensors.TVTensor,) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: + def transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: return inpt.as_subclass(torch.Tensor) diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index e7cde4c5c33..dd65ca4d9c9 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -151,6 +151,10 @@ def _parse_labels_getter(labels_getter: Union[str, Callable[[Any], Any], None]) def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes: + """Return the Bounding Boxes in the input. + + Assumes only one ``BoundingBoxes`` object is present. + """ # This assumes there is only one bbox per sample as per the general convention try: return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.BoundingBoxes)) @@ -159,6 +163,7 @@ def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes: def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: + """Return Channel, Height, and Width.""" chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs @@ -173,6 +178,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: def query_size(flat_inputs: List[Any]) -> Tuple[int, int]: + """Return Height and Width.""" sizes = { tuple(get_size(inpt)) for inpt in flat_inputs diff --git a/version.txt b/version.txt index d7f91f260c5..c6241d3d941 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.20.0a0 +0.22.0a0
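
The torchvision/io/video.py hunks above make the PyAV error handling work across releases: PyAV 14 renamed the base exception from av.AVError to av.FFmpegError. A minimal sketch of the same compatibility guard for user code, assuming only that some version of PyAV is installed ("clip.mp4" is a placeholder path):

import av

try:
    # PyAV >= 14 exposes the base exception as FFmpegError.
    FFmpegError = av.FFmpegError
except AttributeError:
    # Older PyAV releases only expose av.AVError.
    FFmpegError = av.AVError

try:
    container = av.open("clip.mp4")
except FFmpegError as err:
    # Open/decode failures raise subclasses of the base error on either version.
    print(f"PyAV could not open the file: {err}")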
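
The masks_to_boxes docstring change above relaxes the guarantee to ``0 <= x1 <= x2`` because a degenerate mask can collapse the box to zero height or width. A quick check that reproduces the behaviour the new warning describes:

import torch
from torchvision.ops import masks_to_boxes

# A mask that is "on" only along a single row: the box collapses vertically.
mask = torch.zeros(1, 5, 5, dtype=torch.bool)
mask[0, 2, 1:4] = True

print(masks_to_boxes(mask))  # tensor([[1., 2., 3., 2.]]) -> y1 == y2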
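
The Transform base-class hunk above is the core of the rename: check_inputs(), make_params() and transform() are now the public hooks that forward() calls, and forward() itself should be left alone. Below is a sketch of a custom transform written against these hooks; the class, its name, and the fixed 4x4 crop are made up for illustration, and query_size is assumed importable from torchvision.transforms.v2 as documented above:

from typing import Any, Dict, List

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
from torchvision.transforms.v2 import query_size  # assumed public, see the _utils.py hunk


class RandomFixedCrop(v2.Transform):
    """Illustrative only: crop a random crop_size x crop_size window, consistently across inputs."""

    def __init__(self, crop_size: int = 4):
        super().__init__()
        self.crop_size = crop_size

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # make_params() sees all transformable inputs at once, so the random
        # location is sampled once and shared by the image and the boxes.
        height, width = query_size(flat_inputs)
        top = int(torch.randint(0, height - self.crop_size + 1, ()))
        left = int(torch.randint(0, width - self.crop_size + 1, ()))
        return dict(top=top, left=left)

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Called once per transformable input; params comes from make_params().
        return v2.functional.crop(
            inpt, top=params["top"], left=params["left"], height=self.crop_size, width=self.crop_size
        )


img = tv_tensors.Image(torch.zeros(3, 8, 8))
boxes = tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 6, 6]]), format="XYXY", canvas_size=(8, 8))
out_img, out_boxes = RandomFixedCrop()(img, boxes)
print(out_img.shape, out_boxes.canvas_size)  # torch.Size([3, 4, 4]) (4, 4)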
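
The _utils.py docstring additions above cover the three developer helpers used by make_params() implementations. A small sketch of what they return on a flattened sample, assuming the helpers are importable from torchvision.transforms.v2 as the new docs suggest:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import get_bounding_boxes, query_chw, query_size  # assumed public

img = tv_tensors.Image(torch.zeros(3, 32, 64))
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0, 0, 10, 10]]), format="XYXY", canvas_size=(32, 64)
)
flat_inputs = [img, boxes, "a label that is ignored"]

print(query_chw(flat_inputs))                     # (3, 32, 64) from the image
print(query_size(flat_inputs))                    # (32, 64), consistent across image and boxes
print(get_bounding_boxes(flat_inputs) is boxes)   # True: the single BoundingBoxes entry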