Skip to content

Commit

Permalink
AddOp(upsample_bicubic2d) | feat(torchlib) (#1208)
Browse files Browse the repository at this point in the history
Co-authored-by: Justin Chu <[email protected]>
  • Loading branch information
xiaowuhu and justinchuby authored Jan 3, 2024
1 parent 1fa1ed6 commit 1231cc0
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 6 deletions.
82 changes: 76 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,16 +2197,86 @@ def aten_unflatten_dense_tensors(
raise NotImplementedError()


@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bicubic2d.vec"), trace_only=True)
def aten_upsample_bicubic2d(
self: TensorType,
self: TReal,
output_size: INT64,
align_corners: bool,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TensorType:
"""upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""
scale_factors: Optional[TFloat] = None,
) -> TReal:
"""upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
"""

raise NotImplementedError()
if output_size is not None:
result = _aten_upsample_output_size(self, output_size, align_corners, "cubic")
else:
result = _aten_upsample_scales(self, scale_factors, align_corners, "cubic")
return result


@torch_op("aten::upsample_bicubic2d", private=True)
def _aten_upsample_output_size(
self: TReal,
output_size: INT64,
align_corners: bool,
str_mode: str,
) -> TReal:
self_shape = op.Shape(self)
starts = op.Constant(value_ints=[0])
ends = op.Constant(value_ints=[2])
batch_channel = op.Slice(self_shape, starts, ends)
output_size = op.Concat(batch_channel, output_size, axis=0)
if align_corners:
result = op.Resize(
self,
None,
None,
output_size,
mode=str_mode,
coordinate_transformation_mode="align_corners",
)
else:
result = op.Resize(
self,
None,
None,
output_size,
mode=str_mode,
coordinate_transformation_mode="pytorch_half_pixel",
)

return result


@torch_op("aten::upsample_bicubic2d", private=True)
def _aten_upsample_scales(
self: TReal,
scale_factors: TFloat,
align_corners: bool,
str_mode: str,
) -> TReal:
scale_factors = op.Cast(scale_factors, to=FLOAT.dtype)
scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0)
if align_corners:
result = op.Resize(
self,
None,
scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
None,
mode=str_mode,
coordinate_transformation_mode="align_corners",
)
else:
result = op.Resize(
self,
None,
scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
None,
mode=str_mode,
coordinate_transformation_mode="pytorch_half_pixel",
)
return result


def aten_upsample_bicubic2d_backward(
Expand Down
58 changes: 58 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,6 +1409,57 @@ def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs):
yield opinfo_core.SampleInput(t, args=(dimension, size, step))


def sample_inputs_upsample_bicubic2d(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs

N, C = 2, 3
D = 4
SS = 3
L = 5

align_corners_options = (True, False)
rank = 2

def shape(size, rank, with_batch_channel=True):
if with_batch_channel:
return tuple([N, C] + ([size] * rank))
return tuple([size] * rank)

make_arg = functools.partial(
torch_testing.make_tensor,
device=device,
dtype=dtype,
requires_grad=requires_grad,
low=-1,
high=1,
)

yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True)

for align_corners in align_corners_options:
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)), shape(S, rank, False), align_corners
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
shape(L, rank, False),
align_corners,
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
None, # output_size
align_corners,
(1.7, 1.7), # scaler
)
yield opinfo_core.SampleInput(
make_arg(shape(D, rank)),
None, # if this is None, the scalar must be list
align_corners,
(0.6, 0.6),
)


class _TestParamsMaxPoolEmptyStrideBase:
# Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203
def __init__(self):
Expand Down Expand Up @@ -1874,6 +1925,13 @@ def __init__(self):
sample_inputs_func=sample_inputs_unfold,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.upsample_bicubic2d",
aten_name="upsample_bicubic2d",
dtypes=common_dtype.floating_types_and(torch.bfloat16),
sample_inputs_func=sample_inputs_upsample_bicubic2d,
supports_out=False,
),
opinfo_core.OpInfo(
"nn.functional.max_pool1d_with_indices",
aten_name="max_pool1d_with_indices",
Expand Down
5 changes: 5 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2122,6 +2122,11 @@ def _where_input_wrangler(
input_wrangler=_upsample_bilinear2d_input_wrangler,
trace_only=True,
),
TorchLibOpInfo(
"ops.aten.upsample_bicubic2d",
nn_ops.aten_upsample_bicubic2d,
trace_only=True,
),
TorchLibOpInfo(
"nn.functional.upsample_nearest2d",
nn_ops.aten_upsample_nearest2d,
Expand Down

0 comments on commit 1231cc0

Please sign in to comment.