diff --git a/gluoncv/data/transforms/block.py b/gluoncv/data/transforms/block.py index 5965ed93b7..812eae9e0a 100644 --- a/gluoncv/data/transforms/block.py +++ b/gluoncv/data/transforms/block.py @@ -53,7 +53,7 @@ class RandomCrop(Block): Inputs: - **data**: input tensor with (Hi x Wi x C) shape. Outputs: - - **out**: output tensor with ((H+2*pad) x (W+2*pad) x C) shape. + - **out**: output tensor with (size[0] x size[1] x C) or (size x size x C) shape. """ def __init__(self, size, pad=None, interpolation=2): @@ -62,18 +62,13 @@ def __init__(self, size, pad=None, interpolation=2): if isinstance(size, numeric_types): size = (size, size) self._args = (size, interpolation) - if isinstance(pad, int): - self.pad = ((pad, pad), (pad, pad), (0, 0)) - else: - self.pad = pad - + self.pad = ((pad, pad), (pad, pad), (0, 0)) if isinstance(pad, int) else pad def forward(self, x): if self.pad: - x_pad = np.pad(x.asnumpy(), self.pad, - mode='constant', constant_values=0) - - return image.random_crop(nd.array(x_pad), *self._args)[0] - + return image.random_crop(nd.array( + np.pad(x.asnumpy(), self.pad, mode='constant', constant_values=0)), *self._args)[0] + else: + return image.random_crop(x, *self._args)[0] class RandomErasing(Block): """Randomly erasing the area in `src` between `s_min` and `s_max` with `probability`.