Skip to content

Commit

Permalink
Improved docs and tests for (pytorch#2371)
Browse files Browse the repository at this point in the history
- RandomCrop: crop with padding using all commonly supported modes
  • Loading branch information
vfdev-5 authored and de-vri-es committed Aug 4, 2020
1 parent 8e936ba commit 06c7f0a
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 29 deletions.
56 changes: 30 additions & 26 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,29 @@ def _test_functional_geom_op(self, func, fn_kwargs):
transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
if fn_kwargs is None:
fn_kwargs = {}
def _test_class_geom_op(self, method, meth_kwargs=None):
if meth_kwargs is None:
meth_kwargs = {}

tensor, pil_img = self._create_data(height=10, width=10)
transformed_tensor = getattr(F, func)(tensor, **fn_kwargs)
transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
# test for class interface
f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f)

# set seed to reproduce the same transformation for tensor and PIL image
torch.manual_seed(12)
transformed_tensor = f(tensor)
torch.manual_seed(12)
transformed_pil_img = f(pil_img)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

scripted_fn = torch.jit.script(getattr(F, func))
transformed_tensor_script = scripted_fn(tensor, **fn_kwargs)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))

# test for class interface
f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)
def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
self._test_functional_geom_op(func, fn_kwargs)
self._test_class_geom_op(method, meth_kwargs)

def test_random_horizontal_flip(self):
self._test_geom_op('hflip', 'RandomHorizontalFlip')
Expand Down Expand Up @@ -107,21 +112,20 @@ def test_crop(self):
'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)

tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
# Test torchscript of transforms.RandomCrop with size as int
f = T.RandomCrop(size=5)
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)

# Test torchscript of transforms.RandomCrop with size as [int, ]
f = T.RandomCrop(size=[5, ], padding=[2, ])
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)

# Test torchscript of transforms.RandomCrop with size as list
f = T.RandomCrop(size=[6, 6])
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)
sizes = [5, [5, ], [6, 6]]
padding_configs = [
{"padding_mode": "constant", "fill": 0},
{"padding_mode": "constant", "fill": 10},
{"padding_mode": "constant", "fill": 20},
{"padding_mode": "edge"},
{"padding_mode": "reflect"},
]

for size in sizes:
for padding_config in padding_configs:
config = dict(padding_config)
config["size"] = size
self._test_class_geom_op("RandomCrop", config)

def test_center_crop(self):
fn_kwargs = {"output_size": (4, 5)}
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant. Only int value is supported for Tensors.
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
Only "constant" is supported for Tensors as of now.
Mode symmetric is not yet supported for Tensor inputs.
- constant: pads with a constant value, this value is specified with fill
Expand Down
3 changes: 2 additions & 1 deletion torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
list of length 1: ``[padding, ]``.
fill (int): Pixel fill value for constant fill. Default is 0.
This value is only used when the padding_mode is constant
padding_mode (str): Type of padding. Only "constant" is supported for Tensors as of now.
padding_mode (str): Type of padding. Should be: constant, edge or reflect. Default is constant.
Mode symmetric is not yet supported for Tensor inputs.
- constant: pads with a constant value, this value is specified with fill
Expand Down
3 changes: 2 additions & 1 deletion torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ class Pad(torch.nn.Module):
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
Default is constant. Only "constant" is supported for Tensors as of now.
Default is constant. Mode symmetric is not yet supported for Tensor inputs.
- constant: pads with a constant value, this value is specified with fill
Expand Down Expand Up @@ -469,6 +469,7 @@ class RandomCrop(torch.nn.Module):
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
Mode symmetric is not yet supported for Tensor inputs.
- constant: pads with a constant value, this value is specified with fill
Expand Down

0 comments on commit 06c7f0a

Please sign in to comment.