From 566e8ca45da78d1fd0eb813065d89b83c2af7375 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 2 Feb 2024 00:22:00 +0000 Subject: [PATCH 1/3] first attempt --- .../python/tools/symbolic_shape_infer.py | 1 + .../ortmodule/_custom_gradient_registry.py | 6 ++++- .../ortmodule/_custom_op_symbolic_registry.py | 12 +++++++++ .../python/orttraining_test_ortmodule_api.py | 26 +++++++++++++++++++ 4 files changed, 44 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 9823e8264e17b..251d41a24ccc7 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -240,6 +240,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "upsample_nearest1d": self._infer_aten_upsample, "upsample_nearest2d": self._infer_aten_upsample, "upsample_nearest3d": self._infer_aten_upsample, + "upsample_bicubic2d": self._infer_aten_upsample, } self.run_ = True self.suggested_merge_ = {} diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 77317242727b4..6c62f651fa808 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -241,7 +241,7 @@ def native_group_norm_gradient(): # are available for all versions, though they are not that convienent to use. def _upsample_gradient(backward_fn, dims): scales = ["" for _ in range(dims)] - if "bilinear" in backward_fn: + if "bicubic" in backward_fn: scales = ["I(2)", *scales] return [ ("Shape", ["I(0)"], ["Shape_X"]), @@ -271,3 +271,7 @@ def upsample_nearest2d_gradient(): @register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec") def upsample_nearest3d_gradient(): return _upsample_gradient("upsample_nearest3d_backward", 3) + +@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec") +def upsample_bicubic2d_gradient(): + return _upsample_gradient("upsample_bicubic2d_backward", 2) \ No newline at end of file diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 99e8851b6a697..9c2ea978b9b92 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -808,3 +808,15 @@ def upsample_nearest2d(g, input, output_size, scale_factors): @register_symbolic("upsample_nearest3d") def upsample_nearest3d(g, input, output_size, scale_factors): return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d") + +@register_symbolic("upsample_bicubic2d") +def upsample_bicubic2d(g, input, output_size, align_corners, scale_factors): + return g.op( + "org.pytorch.aten::ATen", + input, + output_size, + align_corners, + scale_factors, + operator_s="upsample_bicubic2d", + overload_name_s="vec", + ) \ No newline at end of file diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 938d33cc9a714..af84fb62eb4ad 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1804,6 +1804,32 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) +def test_aten_upsample_bicubic(): + class _NeuralNetUpsampleBicubic(torch.nn.Module): + def __init__(self): + super(_NeuralNetUpsampleBicubic, self).__init__() + + def forward(self, input): + return torch.nn.functional.interpolate(input, size=(8, 12), mode="bicubic") + + device = "cuda" + pt_model = _NeuralNetUpsampleBicubic().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, input): + prediction = model(input) + prediction.sum().backward() + return prediction + + # reset manual seed to reset the generator + torch.manual_seed(2333) + pt_input = torch.randn([2, 4, 6, 8], dtype=torch.float, device=device, requires_grad=True) + ort_input = copy.deepcopy(pt_input) + pt_prediction = run_step(pt_model, pt_input) + ort_prediction = run_step(ort_model, ort_input) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) def test_gradient_correctness_cast_chain(): class NeuralNetCast(torch.nn.Module): From 6ef9d444e868b2159c76aa8a1a34d2712b546c7b Mon Sep 17 00:00:00 2001 From: root Date: Fri, 2 Feb 2024 18:32:26 +0000 Subject: [PATCH 2/3] lint --- .../python/training/ortmodule/_custom_gradient_registry.py | 3 ++- .../python/training/ortmodule/_custom_op_symbolic_registry.py | 3 ++- .../orttraining/test/python/orttraining_test_ortmodule_api.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 6c62f651fa808..4883075112dcb 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -272,6 +272,7 @@ def upsample_nearest2d_gradient(): def upsample_nearest3d_gradient(): return _upsample_gradient("upsample_nearest3d_backward", 3) + @register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec") def upsample_bicubic2d_gradient(): - return _upsample_gradient("upsample_bicubic2d_backward", 2) \ No newline at end of file + return _upsample_gradient("upsample_bicubic2d_backward", 2) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 9c2ea978b9b92..9288027f0188c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -809,6 +809,7 @@ def upsample_nearest2d(g, input, output_size, scale_factors): def upsample_nearest3d(g, input, output_size, scale_factors): return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d") + @register_symbolic("upsample_bicubic2d") def upsample_bicubic2d(g, input, output_size, align_corners, scale_factors): return g.op( @@ -819,4 +820,4 @@ def upsample_bicubic2d(g, input, output_size, align_corners, scale_factors): scale_factors, operator_s="upsample_bicubic2d", overload_name_s="vec", - ) \ No newline at end of file + ) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index af84fb62eb4ad..a64c298cc1fbd 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1804,6 +1804,7 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + def test_aten_upsample_bicubic(): class _NeuralNetUpsampleBicubic(torch.nn.Module): def __init__(self): @@ -1831,6 +1832,7 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + def test_gradient_correctness_cast_chain(): class NeuralNetCast(torch.nn.Module): def __init__(self, D): From e840a0f3a2c0f2d005a7862ae333af316f51112f Mon Sep 17 00:00:00 2001 From: root Date: Fri, 2 Feb 2024 19:02:12 +0000 Subject: [PATCH 3/3] more lint --- .../orttraining/test/python/orttraining_test_ortmodule_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index a64c298cc1fbd..6a6832e06330a 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1808,7 +1808,7 @@ def run_step(model, input): def test_aten_upsample_bicubic(): class _NeuralNetUpsampleBicubic(torch.nn.Module): def __init__(self): - super(_NeuralNetUpsampleBicubic, self).__init__() + super().__init__() def forward(self, input): return torch.nn.functional.interpolate(input, size=(8, 12), mode="bicubic")