Skip to content

Commit

Permalink
port prototype tests to new utilities (#8022)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Oct 5, 2023
1 parent 67f3ce2 commit 1d646d4
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 116 deletions.
82 changes: 0 additions & 82 deletions test/prototype_common_utils.py

This file was deleted.

67 changes: 34 additions & 33 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,42 @@
import collections.abc
import re

import PIL.Image
import pytest
import torch

from common_utils import assert_equal
from common_utils import assert_equal, make_bounding_boxes, make_detection_masks, make_image, make_video

from prototype_common_utils import make_label
from torchvision.prototype import transforms, tv_tensors
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image

from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from transforms_v2_legacy_utils import (
DEFAULT_EXTRA_DIMS,
make_bounding_boxes,
make_detection_mask,
make_image,
make_video,
)

BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]

def _parse_categories(categories):
if categories is None:
num_categories = int(torch.randint(1, 11, ()))
elif isinstance(categories, int):
num_categories = categories
categories = [f"category{idx}" for idx in range(num_categories)]
elif isinstance(categories, collections.abc.Sequence) and all(isinstance(category, str) for category in categories):
categories = list(categories)
num_categories = len(categories)
else:
raise pytest.UsageError(
f"`categories` can either be `None` (default), an integer, or a sequence of strings, "
f"but got '{categories}' instead."
)
return categories, num_categories

def parametrize(transforms_with_inputs):
return pytest.mark.parametrize(
("transform", "input"),
[
pytest.param(
transform,
input,
id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}",
)
for transform, inputs in transforms_with_inputs
for idx, input in enumerate(inputs)
],
)

def make_label(*, extra_dims=(), categories=10, dtype=torch.int64, device="cpu"):
categories, num_categories = _parse_categories(categories)
# The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
# regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
data = torch.testing.make_tensor(extra_dims, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
return tv_tensors.Label(data, categories=categories)


class TestSimpleCopyPaste:
Expand Down Expand Up @@ -167,7 +168,7 @@ def test__get_params(self, mocker):

flat_inputs = [
make_image(size=canvas_size, color_space="RGB"),
make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=batch_shape),
make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_shape[0]),
]
params = transform._get_params(flat_inputs)

Expand Down Expand Up @@ -203,9 +204,9 @@ def test__transform_culling(self, mocker):
)

bounding_boxes = make_bounding_boxes(
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_size
)
masks = make_detection_mask(size=canvas_size, batch_dims=(batch_size,))
masks = make_detection_masks(size=canvas_size, num_masks=batch_size)
labels = make_label(extra_dims=(batch_size,))

transform = transforms.FixedSizeCrop((-1, -1))
Expand Down Expand Up @@ -241,7 +242,7 @@ def test__transform_bounding_boxes_clamping(self, mocker):
)

bounding_boxes = make_bounding_boxes(
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_size
)
mock = mocker.patch(
"torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes", wraps=clamp_bounding_boxes
Expand Down Expand Up @@ -389,27 +390,27 @@ def make_tv_tensors():

pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", num_boxes=num_objects, dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
"masks": make_detection_masks(size=size, num_masks=num_objects, dtype=torch.long),
}

yield (pil_image, target)

tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", num_boxes=num_objects, dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
"masks": make_detection_masks(size=size, num_masks=num_objects, dtype=torch.long),
}

yield (tensor_image, target)

tv_tensor_image = make_image(size=size, color_space="RGB")
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", num_boxes=num_objects, dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
"masks": make_detection_masks(size=size, num_masks=num_objects, dtype=torch.long),
}

yield (tv_tensor_image, target)
Expand Down
1 change: 0 additions & 1 deletion test/transforms_v2_legacy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
The following legacy modules depend on this module
- test_transforms_v2_consistency.py
- test_prototype_transforms.py
"""

import collections.abc
Expand Down

0 comments on commit 1d646d4

Please sign in to comment.