Commit 78062c0

Merge branch 'main' into patch-1

bhack authored Dec 18, 2024
2 parents 1ec78df + f7b1cfa
Showing 37 changed files with 471 additions and 589 deletions.
5 changes: 5 additions & 0 deletions .github/scripts/setup-env.sh
@@ -102,6 +102,11 @@ echo '::group::Install TorchVision'
 python setup.py develop
 echo '::endgroup::'
 
+echo '::group::Install torchvision-extra-decoders'
+# This can be done after torchvision has been built
+pip install torchvision-extra-decoders
+echo '::endgroup::'
+
 echo '::group::Collect environment information'
 conda list
 python -m torch.utils.collect_env
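As a quick sanity check after this step, the optional decoders can be probed from Python. This is a sketch, not part of the commit; it assumes only the error-message contract that test/smoke_test.py (further down in this diff) also relies on:

import torch
from torchvision.io import decode_avif

try:
    # Garbage bytes fail to decode either way; what matters is *which* error.
    decode_avif(torch.zeros(8, dtype=torch.uint8))
except RuntimeError as e:
    if "torchvision-extra-decoders" in str(e):
        print("torchvision-extra-decoders is NOT installed")
    else:
        print("Extra decoders are installed; the dummy input just failed to decode.")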
15 changes: 9 additions & 6 deletions docs/source/io.rst
@@ -9,8 +9,8 @@ images and videos.
 Image Decoding
 --------------
 
-Torchvision currently supports decoding JPEG, PNG, WEBP and GIF images. JPEG
-decoding can also be done on CUDA GPUs.
+Torchvision currently supports decoding JPEG, PNG, WEBP, GIF, AVIF, and HEIC
+images. JPEG decoding can also be done on CUDA GPUs.
 
 The main entry point is the :func:`~torchvision.io.decode_image` function, which
 you can use as an alternative to ``PIL.Image.open()``. It will decode images
@@ -30,9 +30,10 @@ run transforms/preproc natively on tensors.
 :func:`~torchvision.io.decode_image` will automatically detect the image format,
-and call the corresponding decoder. You can also use the lower-level
-format-specific decoders which can be more powerful, e.g. if you want to
-encode/decode JPEGs on CUDA.
+and call the corresponding decoder (except for HEIC and AVIF images, see details
+in :func:`~torchvision.io.decode_avif` and :func:`~torchvision.io.decode_heic`).
+You can also use the lower-level format-specific decoders which can be more
+powerful, e.g. if you want to encode/decode JPEGs on CUDA.
 
 .. autosummary::
     :toctree: generated/
@@ -41,8 +42,10 @@ encode/decode JPEGs on CUDA.
     decode_image
     decode_jpeg
     encode_png
-    decode_gif
     decode_webp
+    decode_avif
+    decode_heic
+    decode_gif
 
 .. autosummary::
     :toctree: generated/
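To make these entry points concrete, here is a short usage sketch (not part of the diff; the file path is hypothetical, and the CUDA call assumes a CUDA-enabled build):

import torch
from torchvision.io import decode_image, decode_jpeg, read_file

# decode_image auto-detects the format (JPEG/PNG/WEBP/GIF) from the bytes.
img = decode_image("photo.jpg")  # CHW uint8 tensor

# Lower-level, format-specific path: decode a JPEG directly on the GPU.
data = read_file("photo.jpg")  # raw encoded bytes as a uint8 tensor
if torch.cuda.is_available():
    img_gpu = decode_jpeg(data, device="cuda")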
9 changes: 9 additions & 0 deletions docs/source/transforms.rst
@@ -508,11 +508,20 @@ are combining pairs of images together. These can be used after the dataloader
 Developer tools
 ^^^^^^^^^^^^^^^
 
+.. autosummary::
+    :toctree: generated/
+    :template: class.rst
+
+    v2.Transform
+
 .. autosummary::
     :toctree: generated/
     :template: function.rst
 
     v2.functional.register_kernel
+    v2.query_size
+    v2.query_chw
+    v2.get_bounding_boxes
 
 
 V1 API Reference
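A minimal sketch of how the newly documented hooks fit together (illustrative only; ``MyTVTensor`` is a hypothetical subclass, and the kernel body is a stand-in):

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


class MyTVTensor(tv_tensors.TVTensor):
    # Hypothetical TVTensor subclass, used only for this sketch.
    pass


@F.register_kernel(functional=F.hflip, tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(inpt, *args, **kwargs):
    # Called whenever F.hflip sees a MyTVTensor.
    out = inpt.flip(-1)
    return tv_tensors.wrap(out, like=inpt)


t = MyTVTensor(torch.rand(3, 5, 5))
_ = F.hflip(t)  # now dispatched to hflip_my_tv_tensor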
117 changes: 98 additions & 19 deletions gallery/transforms/plot_custom_transforms.py
@@ -12,6 +12,8 @@
 """
 
 # %%
+from typing import Any, Dict, List
+
 import torch
 from torchvision import tv_tensors
 from torchvision.transforms import v2
@@ -89,33 +91,110 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured
 # A key feature of the builtin Torchvision V2 transforms is that they can accept
 # arbitrary input structure and return the same structure as output (with
 # transformed entries). For example, transforms can accept a single image, or a
-# tuple of ``(img, label)``, or an arbitrary nested dictionary as input:
+# tuple of ``(img, label)``, or an arbitrary nested dictionary as input. Here's
+# an example using the built-in transform :class:`~torchvision.transforms.v2.RandomHorizontalFlip`:
 
 structured_input = {
     "img": img,
     "annotations": (bboxes, label),
-    "something_that_will_be_ignored": (1, "hello")
+    "something that will be ignored": (1, "hello"),
+    "another tensor that is ignored": torch.arange(10),
 }
 structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)
 
 assert isinstance(structured_output, dict)
-assert structured_output["something_that_will_be_ignored"] == (1, "hello")
+assert structured_output["something that will be ignored"] == (1, "hello")
+assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
+print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
 print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
 
+# %%
+# Basics: override the `transform()` method
+# -----------------------------------------
+#
+# In order to support arbitrary inputs in your custom transform, you will need
+# to inherit from :class:`~torchvision.transforms.v2.Transform` and override the
+# `.transform()` method (not the `forward()` method!). Below is a basic example:
+
+
+class MyCustomTransform(v2.Transform):
+    def transform(self, inpt: Any, params: Dict[str, Any]):
+        # Exact type check on purpose: TVTensors are tensor subclasses and are
+        # handled by the elif branch below.
+        if type(inpt) == torch.Tensor:
+            print(f"I'm transforming an image of shape {inpt.shape}")
+            return inpt + 1  # dummy transformation
+        elif isinstance(inpt, tv_tensors.BoundingBoxes):
+            print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
+            return tv_tensors.wrap(inpt + 100, like=inpt)  # dummy transformation
+
+
+my_custom_transform = MyCustomTransform()
+structured_output = my_custom_transform(structured_input)
+
+assert isinstance(structured_output, dict)
+assert structured_output["something that will be ignored"] == (1, "hello")
+assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
+print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
+print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
+
 # %%
-# If you want to reproduce this behavior in your own transform, we invite you to
-# look at our `code
-# <https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/_transform.py>`_
-# and adapt it to your needs.
-#
-# In brief, the core logic is to unpack the input into a flat list using `pytree
-# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
-# then transform only the entries that can be transformed (the decision is made
-# based on the **class** of the entries, as all TVTensors are
-# tensor-subclasses) plus some custom logic that is out of scope here - check the
-# code for details. The (potentially transformed) entries are then repacked and
-# returned, in the same structure as the input.
-#
-# We do not provide public dev-facing tools to achieve that at this time, but if
-# this is something that would be valuable to you, please let us know by opening
-# an issue on our `GitHub repo <https://github.com/pytorch/vision/issues>`_.
+# An important thing to note is that when we call ``my_custom_transform`` on
+# ``structured_input``, the input is flattened and then each individual part is
+# passed to ``transform()``. That is, ``transform()`` receives the input image,
+# then the bounding boxes, etc. Within ``transform()``, you can decide how to
+# transform each input, based on its type. (A sketch of this flattening
+# mechanism follows below, after this file's diff.)
+#
+# If you're curious why the other tensor (``torch.arange()``) didn't get passed
+# to ``transform()``, see :ref:`this note <passthrough_heuristic>` for more
+# details.
+#
+# Advanced: The ``make_params()`` method
+# --------------------------------------
+#
+# The ``make_params()`` method is called internally before calling
+# ``transform()`` on each input. This is typically useful to generate random
+# parameter values. In the example below, we use it to randomly apply the
+# transformation with a probability of 0.5:
+
+
+class MyRandomTransform(MyCustomTransform):
+    def __init__(self, p=0.5):
+        self.p = p
+        super().__init__()
+
+    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        apply_transform = (torch.rand(size=(1,)) < self.p).item()
+        params = dict(apply_transform=apply_transform)
+        return params
+
+    def transform(self, inpt: Any, params: Dict[str, Any]):
+        if not params["apply_transform"]:
+            print("Not transforming anything!")
+            return inpt
+        else:
+            return super().transform(inpt, params)
+
+
+my_random_transform = MyRandomTransform()
+
+torch.manual_seed(0)
+_ = my_random_transform(structured_input)  # transforms
+_ = my_random_transform(structured_input)  # doesn't transform
+
+# %%
+#
+# .. note::
+#
+#     It's important for such random parameter generation to happen within
+#     ``make_params()`` and not within ``transform()``, so that for a given
+#     transform call, the same random parameters apply to all the inputs in
+#     the same way. If we were to perform the RNG within ``transform()``, we
+#     would risk e.g. transforming the image while *not* transforming the
+#     bounding boxes.
+#
+# The ``make_params()`` method takes the list of all the inputs as a parameter
+# (each of the elements in this list will later be passed to ``transform()``).
+# You can use ``flat_inputs`` to e.g. figure out the dimensions of the input,
+# using :func:`~torchvision.transforms.v2.query_chw` or
+# :func:`~torchvision.transforms.v2.query_size`.
+#
+# ``make_params()`` should return a dict (or actually, anything you want) that
+# will then be passed to ``transform()``.
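Two illustrative sketches to accompany the new text (neither is part of the commit). First, the flattening mechanism described above, using the private ``torch.utils._pytree`` utility:

import torch
from torch.utils import _pytree as pytree

sample = {"img": torch.rand(3, 4, 4), "annotations": (torch.tensor([0, 10, 10, 20]), 3)}

# tree_flatten unpacks any nested structure into a flat list of leaves plus a
# spec that remembers how to rebuild the structure.
flat_inputs, spec = pytree.tree_flatten(sample)
# A transform decides per leaf whether to transform it or pass it through,
# then repacks the (possibly transformed) leaves.
transformed = [x + 1 if isinstance(x, torch.Tensor) else x for x in flat_inputs]
out = pytree.tree_unflatten(transformed, spec)
assert out.keys() == sample.keys()

Second, a hedged sketch of ``make_params()`` combined with ``query_size``, so that one size-dependent random parameter is shared by all inputs of a sample:

from typing import Any, Dict, List

import torch
from torchvision.transforms import v2


class MyTopShift(v2.Transform):
    # Illustrative only: draws one offset per call, so the image and its
    # bounding boxes would see the same value.
    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = v2.query_size(flat_inputs)
        top = int(torch.randint(0, max(height // 4, 1), size=(1,)).item())
        return dict(top=top)

    def transform(self, inpt: Any, params: Dict[str, Any]):
        print(f"Would shift {type(inpt).__name__} by top={params['top']}")
        return inpt  # no-op placeholder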
4 changes: 4 additions & 0 deletions mypy.ini
@@ -45,6 +45,10 @@ ignore_errors = True
 
 ignore_errors=True
 
+[mypy-torchvision.models.maxvit.*]
+
+ignore_errors=True
+
 [mypy-torchvision.models.detection.anchor_utils]
 
 ignore_errors = True
2 changes: 2 additions & 0 deletions packaging/post_build_script.sh
@@ -1,2 +1,4 @@
 #!/bin/bash
 LD_LIBRARY_PATH="/usr/local/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH" python packaging/wheel/relocate.py
+
+pip install torchvision-extra-decoders
4 changes: 2 additions & 2 deletions references/segmentation/v2_extras.py
@@ -10,13 +10,13 @@ def __init__(self, size, fill=0):
         self.size = size
         self.fill = v2._utils._setup_fill_arg(fill)
 
-    def _get_params(self, sample):
+    def make_params(self, sample):
         _, height, width = v2._utils.query_chw(sample)
         padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
         needs_padding = any(padding)
         return dict(padding=padding, needs_padding=needs_padding)
 
-    def _transform(self, inpt, params):
+    def transform(self, inpt, params):
         if not params["needs_padding"]:
             return inpt
 
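For context, this rename tracks the v2 ``Transform`` hooks going public (``make_params``/``transform`` replace the private ``_get_params``/``_transform``). A hedged usage sketch of the class above, assuming the usual segmentation preset composition:

from torchvision.transforms import v2

# PadIfSmaller guarantees a minimum spatial size, so a fixed-size crop is safe.
transforms = v2.Compose([
    PadIfSmaller(size=480, fill=0),
    v2.RandomCrop(size=480),
])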
34 changes: 0 additions & 34 deletions setup.py
@@ -19,8 +19,6 @@
 USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
 USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
 USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1"
-USE_HEIC = os.getenv("TORCHVISION_USE_HEIC", "0") == "1"  # TODO enable by default!
-USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1"  # TODO enable by default!
 USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
 NVCC_FLAGS = os.getenv("NVCC_FLAGS", None)
 # Note: the GPU video decoding stuff used to be called "video codec", which
@@ -51,8 +49,6 @@
 print(f"{USE_PNG = }")
 print(f"{USE_JPEG = }")
 print(f"{USE_WEBP = }")
-print(f"{USE_HEIC = }")
-print(f"{USE_AVIF = }")
 print(f"{USE_NVJPEG = }")
 print(f"{NVCC_FLAGS = }")
 print(f"{USE_CPU_VIDEO_DECODER = }")
@@ -336,36 +332,6 @@ def make_image_extension():
     else:
         warnings.warn("Building torchvision without WEBP support")
 
-    if USE_HEIC:
-        heic_found, heic_include_dir, heic_library_dir = find_library(header="libheif/heif.h")
-        if heic_found:
-            print("Building torchvision with HEIC support")
-            print(f"{heic_include_dir = }")
-            print(f"{heic_library_dir = }")
-            if heic_include_dir is not None and heic_library_dir is not None:
-                # if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add.
-                include_dirs.append(heic_include_dir)
-                library_dirs.append(heic_library_dir)
-            libraries.append("heif")
-            define_macros += [("HEIC_FOUND", 1)]
-        else:
-            warnings.warn("Building torchvision without HEIC support")
-
-    if USE_AVIF:
-        avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h")
-        if avif_found:
-            print("Building torchvision with AVIF support")
-            print(f"{avif_include_dir = }")
-            print(f"{avif_library_dir = }")
-            if avif_include_dir is not None and avif_library_dir is not None:
-                # if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add.
-                include_dirs.append(avif_include_dir)
-                library_dirs.append(avif_library_dir)
-            libraries.append("avif")
-            define_macros += [("AVIF_FOUND", 1)]
-        else:
-            warnings.warn("Building torchvision without AVIF support")
-
     if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA):
         nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()
 
35 changes: 34 additions & 1 deletion test/smoke_test.py
@@ -6,7 +6,7 @@
 
 import torch
 import torchvision
-from torchvision.io import decode_image, decode_jpeg, decode_webp, read_file
+from torchvision.io import decode_avif, decode_heic, decode_image, decode_jpeg, read_file
 from torchvision.models import resnet50, ResNet50_Weights
 
 
@@ -24,13 +24,46 @@ def smoke_test_torchvision_read_decode() -> None:
     img_jpg = decode_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
     if img_jpg.shape != (3, 606, 517):
         raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
+
     img_png = decode_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
     if img_png.shape != (4, 471, 354):
         raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
+
     img_webp = decode_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
     if img_webp.shape != (3, 100, 100):
         raise RuntimeError(f"Unexpected shape of img_webp: {img_webp.shape}")
+
+    if sys.platform == "linux":
+        pass
+        # TODO: Fix/uncomment below (the TODO below is mostly accurate but we're
+        # still observing some failures on some CUDA jobs. Most are working.)
+        # if torch.cuda.is_available():
+        #     # TODO: For whatever reason this only passes on the runners that
+        #     # support CUDA.
+        #     # Strangely, on the CPU runners where this fails, the AVIF/HEIC
+        #     # tests (ran with pytest) are passing. This is likely related to a
+        #     # libcxx symbol thing, and the proper libstdc++.so get loaded only
+        #     # with pytest? Ugh.
+        #     img_avif = decode_avif(read_file(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.avif")))
+        #     if img_avif.shape != (3, 100, 100):
+        #         raise RuntimeError(f"Unexpected shape of img_avif: {img_avif.shape}")
+
+        #     img_heic = decode_heic(
+        #         read_file(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic"))
+        #     )
+        #     if img_heic.shape != (3, 100, 100):
+        #         raise RuntimeError(f"Unexpected shape of img_heic: {img_heic.shape}")
+    else:
+        try:
+            decode_avif(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.avif"))
+        except RuntimeError as e:
+            assert "torchvision-extra-decoders" in str(e)
+
+        try:
+            decode_heic(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic"))
+        except RuntimeError as e:
+            assert "torchvision-extra-decoders" in str(e)
 
 
 def smoke_test_torchvision_decode_jpeg(device: str = "cpu"):
     img_jpg_data = read_file(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
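For reference, outside the smoke test the new decoders follow the usual read-then-decode pattern (a sketch with a hypothetical path; requires the torchvision-extra-decoders package):

from torchvision.io import decode_avif, read_file

data = read_file("my_image.avif")  # raw encoded bytes as a uint8 tensor
img = decode_avif(data)  # CHW uint8 tensor
print(img.shape)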