From a43c57f59db35d8cceef1ff8f44985d745eb94b4 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Fri, 20 Oct 2023 11:39:57 -0700 Subject: [PATCH 01/24] ResizeGrad CUDA/ROCM kernel implementation (#17772) --- .../python/tools/symbolic_shape_infer.py | 1 - .../core/graph/gradient_builder.cc | 8 + .../orttraining/core/graph/gradient_builder.h | 1 + .../core/graph/gradient_builder_registry.cc | 1 + .../core/graph/training_op_defs.cc | 20 ++ .../ortmodule/_custom_gradient_registry.py | 5 - .../ortmodule/_custom_op_symbolic_registry.py | 13 - .../test/gradient/gradient_ops_test.cc | 35 +++ .../python/orttraining_test_ortmodule_api.py | 8 +- .../training_ops/cuda/resize_grad_test.cc | 227 ++++++++++++++++++ .../cuda/cuda_training_kernels.cc | 12 +- .../training_ops/cuda/tensor/resize_grad.cc | 81 +++++++ .../training_ops/cuda/tensor/resize_grad.h | 41 ++++ .../cuda/tensor/resize_grad_impl.cu | 151 ++++++++++++ .../cuda/tensor/resize_grad_impl.h | 20 ++ .../rocm/rocm_training_kernels.cc | 6 + 16 files changed, 605 insertions(+), 25 deletions(-) create mode 100644 orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu create mode 100644 orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 6d954bd540718..67e9f1b55e9ae 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -230,7 +230,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "upsample_nearest1d": self._infer_aten_upsample, "upsample_nearest2d": self._infer_aten_upsample, "upsample_nearest3d": self._infer_aten_upsample, - "upsample_bilinear2d": self._infer_aten_upsample, } self.run_ = True self.suggested_merge_ = {} diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 133cab71f2b1c..6547f53a3c2ae 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -2147,5 +2147,13 @@ IMPLEMENT_GRADIENT_BUILDER(GetScaledSumGradient) { ORT_THROW("ScaledSum gradient builder does not support ", input_count, " inputs"); } +IMPLEMENT_GRADIENT_BUILDER(GetResizeGradient) { + return std::vector{ + NodeDef(OpDef{"ResizeGrad", kMSDomain, 1}, + {GO(0), I(0), I(1), I(2)}, + {GI(0)}, + SrcNodeAttributes())}; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index a517e8af13fcc..28a316261e2f6 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -90,6 +90,7 @@ DECLARE_GRADIENT_BUILDER(GetGRUGradient) DECLARE_GRADIENT_BUILDER(GetReciprocalGradient) DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient) DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient) +DECLARE_GRADIENT_BUILDER(GetResizeGradient) DECLARE_GRADIENT_BUILDER(GetExternalGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 4062b5d097394..4b8c68aef078a 100755 --- 
a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -122,6 +122,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Reciprocal", GetReciprocalGradient); REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient); REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient); + REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient); REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient); }; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index cfc79455c43ed..c90acfdb7bb78 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -5001,6 +5001,26 @@ Return true if all elements are true and false otherwise. "T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(ResizeGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "dY", "Gradient of output Y.", "T") + .Input(1, "X", "Input tensor to the Resize operator.", "T") + .Input(2, "roi", "The roi input to the Resize operator.", "T", OpSchema::Optional) + .Input(3, "scales", "The scales input to the Resize operator.", "tensor(float)", OpSchema::Optional) + .Output(0, "dX", "Gradient of the input X.", "T") + .AllowUncheckedAttributes() + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 1, 0); + if (hasInputShape(ctx, 1)) { + propagateShapeFromInputToOutput(ctx, 1, 0); + } + }); } } // namespace training diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 156c3e001d88f..77317242727b4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -271,8 +271,3 @@ def upsample_nearest2d_gradient(): @register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec") def upsample_nearest3d_gradient(): return _upsample_gradient("upsample_nearest3d_backward", 3) - - -@register_gradient("org.pytorch.aten", "ATen", "upsample_bilinear2d", "vec") -def upsample_bilinear2d_gradient(): - return _upsample_gradient("upsample_bilinear2d_backward", 2) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 64c7abe1c9386..6e694dcdf2e39 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -808,16 +808,3 @@ def upsample_nearest2d(g, input, output_size, scale_factors): @register_symbolic("upsample_nearest3d") def upsample_nearest3d(g, input, output_size, scale_factors): return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d") - - -@register_symbolic("upsample_bilinear2d") -def upsample_bilinear2d(g, input, output_size, align_corners, scale_factors): - return g.op( - "org.pytorch.aten::ATen", - input, - output_size, - align_corners, - 
scale_factors, - operator_s="upsample_bilinear2d", - overload_name_s="vec", - ) diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 597801f4030c1..890a1bbccbc92 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -3298,6 +3298,41 @@ TEST(GradientCheckerTest, ConvTransposeGrad) { execution_providers.push_back(DefaultCudaExecutionProvider()); ConvTransposeGradientCheckerTest(&execution_providers); } + +// TODO: Enable test for ROCM +TEST(GradientCheckerTest, ResizeGrad) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + const std::vector attributes = { + MakeAttribute("coordinate_transformation_mode", "half_pixel"), + MakeAttribute("cubic_coeff_a", -0.75f), + MakeAttribute("exclude_outside", static_cast(0)), + MakeAttribute("extrapolation_value", 0.0f), + MakeAttribute("mode", "linear"), + MakeAttribute("nearest_mode", "floor")}; + + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"Resize", kOnnxDomain, 18}; + + TensorInfo x_info({1, 2, 4, 4}, true); + TensorInfo roi_info({4}, false, nullptr, DataTypeImpl::GetTensorType()); + TensorInfo scales_info({4}, false, nullptr, DataTypeImpl::GetTensorType()); + + TensorInfo y_info({1, 2, 8, 8}, true); + + std::vector> x_datas = {{0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f, + 0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f, + 0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f, + 0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 2.0f, 2.0f}}; + + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, roi_info, scales_info}, + {y_info}, &max_error, x_datas, attributes, true, false, &execution_providers)); + EXPECT_IS_TINY(max_error); +} + #endif // USE_CUDA } // namespace test diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 643d47b0d043e..c8ec2e52f3078 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1773,13 +1773,17 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) -def test_aten_upsample_bilinear(): +@pytest.mark.parametrize("interpolate_size_scale", ({"size": (8, 12)}, {"scale_factor": 4.7})) +@pytest.mark.parametrize("align_corners", (True, False)) +def test_resize_grad_correctness_bilinear_2d(interpolate_size_scale, align_corners): class _NeuralNetUpsampleBilinear(torch.nn.Module): def __init__(self): super().__init__() def forward(self, input): - return torch.nn.functional.interpolate(input, size=(8, 12), mode="bilinear") + return torch.nn.functional.interpolate( + input, align_corners=align_corners, mode="bilinear", **interpolate_size_scale + ) device = "cuda" pt_model = _NeuralNetUpsampleBilinear().to(device) diff --git a/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc new file mode 100644 index 0000000000000..8fc13af8816be --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "test/providers/compare_provider_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" + +namespace onnxruntime::test { + +#if defined(USE_CUDA) || defined(USE_ROCM) + +namespace { + +void AddResizeGradAttributes(OpTester& test, const std::string& coordinate_transformation_mode) { + test.AddAttribute("mode", "linear"); + test.AddAttribute("coordinate_transformation_mode", coordinate_transformation_mode); +} + +} // namespace + +TEST(ResizeGradTest, ResizeGradWithSizes) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(128, 1.0f); + std::vector dY_shape = {1, 2, 8, 8}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX(32, 4.0f); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithSizesHalf) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(128, 1.0f); + std::vector dY_half(dY.size()); + ConvertFloatToMLFloat16(dY.data(), dY_half.data(), static_cast(dY.size())); + std::vector dY_shape = {1, 2, 8, 8}; + + std::vector X(32, 1.0f); + std::vector X_half(X.size()); + ConvertFloatToMLFloat16(X.data(), X_half.data(), static_cast(X.size())); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX(32, 4.0f); + std::vector dX_half(dX.size()); + ConvertFloatToMLFloat16(dX.data(), dX_half.data(), static_cast(dX.size())); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY_half); + test.AddInput("X", X_shape, X_half); + + test.AddOutput("dX", dX_shape, dX_half); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithSizesAndAlignCorners) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "align_corners"); + + std::vector dY(128, 1.0f); + std::vector dY_shape = {1, 2, 8, 8}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({2.9388f, 3.9184f, 3.9184f, 2.9388f, 3.9184f, 5.2245f, 5.2245f, 3.9184f, + 3.9184f, 5.2245f, 5.2245f, 3.9184f, 2.9388f, 3.9184f, 3.9184f, 2.9388f, + 2.9388f, 3.9184f, 3.9184f, 2.9388f, 3.9184f, 5.2245f, 5.2245f, 3.9184f, + 3.9184f, 5.2245f, 5.2245f, 3.9184f, 2.9388f, 3.9184f, 3.9184f, 2.9388f}); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithScales) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + 
providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(72, 1.0f); + std::vector dY_shape = {1, 2, 6, 6}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f, + 2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f}); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + test.AddInput("", {0}, {}); + test.AddInput("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f}); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithScalesHalf) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(72, 1.0f); + std::vector dY_half(dY.size()); + ConvertFloatToMLFloat16(dY.data(), dY_half.data(), static_cast(dY.size())); + std::vector dY_shape = {1, 2, 6, 6}; + + std::vector X(32, 1.0f); + std::vector X_half(X.size()); + ConvertFloatToMLFloat16(X.data(), X_half.data(), static_cast(X.size())); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f, + 2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f}); + std::vector dX_half(dX.size()); + ConvertFloatToMLFloat16(dX.data(), dX_half.data(), static_cast(dX.size())); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY_half); + test.AddInput("X", X_shape, X_half); + test.AddInput("", {0}, {}); + test.AddInput("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f}); + + test.AddOutput("dX", dX_shape, dX_half); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithScalesAndAlignCorners) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "align_corners"); + + std::vector dY(72, 1.0f); + std::vector dY_shape = {1, 2, 6, 6}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({1.9600f, 2.2400f, 2.2400f, 1.9600f, 2.2400f, 2.5600f, 2.5600f, 2.2400f, + 2.2400f, 2.5600f, 2.5600f, 2.2400f, 1.9600f, 2.2400f, 2.2400f, 1.9600f, + 1.9600f, 2.2400f, 2.2400f, 1.9600f, 2.2400f, 2.5600f, 2.5600f, 2.2400f, + 2.2400f, 2.5600f, 2.5600f, 2.2400f, 1.9600f, 2.2400f, 2.2400f, 1.9600f}); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + test.AddInput("", {0}, {}); + test.AddInput("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f}); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +#endif // 
defined(USE_CUDA) || defined(USE_ROCM) + +} // namespace onnxruntime::test diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 8e61dbee506f2..ae4f48b6b49a2 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -207,6 +207,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ResizeGrad); // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training @@ -453,13 +456,14 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training #ifdef ENABLE_TRAINING diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc new file mode 100644 index 0000000000000..a5e8f7cd35d88 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "orttraining/training_ops/cuda/tensor/resize_grad.h" +#include "orttraining/training_ops/cuda/tensor/resize_grad_impl.h" + +namespace onnxruntime::cuda { + +#define REGISTER_RESIZEGRAD_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ResizeGrad, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) /* Keep roi on CPU */ \ + .InputMemoryType(OrtMemTypeCPUInput, 3) /* Keep scales on CPU */ \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ResizeGrad); + +REGISTER_RESIZEGRAD_KERNEL_TYPED(MLFloat16) +REGISTER_RESIZEGRAD_KERNEL_TYPED(float) +REGISTER_RESIZEGRAD_KERNEL_TYPED(double) + +template +Status ResizeGrad::ComputeInternal(OpKernelContext* context) const { + typedef typename ToCudaType::MappedType CudaT; + + const Tensor* dY = context->Input(0); + const Tensor* X = context->Input(1); + const Tensor* scales = context->Input(3); + + ORT_ENFORCE(X->Shape().NumDimensions() == 4, "Expected input tensor to have 4 dimensions. 
Actual: ", + X->Shape().NumDimensions()); + + const auto get_scales_from_input = [](const Tensor* scales) { + if (nullptr == scales) { + return std::make_pair(std::optional{}, std::optional{}); + } + + ORT_ENFORCE(scales->Shape().Size() == 4, "There must be a scale for each dimension."); + + const auto* scales_data = scales->Data(); + return std::make_pair(std::optional{scales_data[2]}, std::optional{scales_data[3]}); + }; + + std::pair, std::optional> scale_factors = get_scales_from_input(scales); + + Tensor* dX = context->Output(0, X->Shape()); + + const int64_t batch_size = X->Shape()[0]; + const int64_t num_channels = X->Shape()[1]; + const int64_t output_height = dY->Shape()[2]; + const int64_t output_width = dY->Shape()[3]; + const int64_t input_height = X->Shape()[2]; + const int64_t input_width = X->Shape()[3]; + + if (dX->Shape() == dY->Shape()) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dX->MutableDataRaw(), dY->DataRaw(), dY->SizeInBytes(), cudaMemcpyDeviceToDevice)); + return Status::OK(); + } + + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(dX->MutableDataRaw(), 0, dX->SizeInBytes(), Stream(context))); + + const bool align_corners = coordinate_transform_mode_ == ResizeCoordinateTransformationMode::ALIGN_CORNERS; + const CudaT* dy_data = reinterpret_cast(dY->Data()); + CudaT* dx_data = reinterpret_cast(dX->MutableData()); + + ResizeGradImpl(Stream(context), input_height, input_width, output_height, + output_width, batch_size, num_channels, align_corners, + scale_factors.first, scale_factors.second, + dy_data, dx_data); + + return Status::OK(); +} + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h new file mode 100644 index 0000000000000..53f8d5f0d71f5 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cpu/tensor/upsamplebase.h" + +namespace onnxruntime::cuda { + +template +class ResizeGrad final : public UpsampleBase, public CudaKernel { + public: + ResizeGrad(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { + ORT_ENFORCE(!antialias_, "Antialiasing is not supported in ResizeGrad yet."); + + ORT_ENFORCE(axes_.empty(), "ReizeGrad does not support the `axes` attribute yet."); + + std::string coordinate_transform_mode = + info.GetAttrOrDefault("coordinate_transformation_mode", "half_pixel"); + coordinate_transform_mode_ = StringToCoordinateTransformationMode(coordinate_transform_mode); + ORT_ENFORCE(coordinate_transform_mode_ == ResizeCoordinateTransformationMode::HALF_PIXEL || + coordinate_transform_mode_ == ResizeCoordinateTransformationMode::ALIGN_CORNERS, + "ReizeGrad only supports the `HALF_PIXEL` and `ALIGN_CORNERS` coordinate_transform_mode ", + coordinate_transform_mode, " is not supported yet."); + + ORT_ENFORCE(keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH, + "ReizeGrad only supports the `STRETCH` policy."); + + std::string mode; + ORT_ENFORCE(info.GetAttr("mode", &mode).IsOK()); + ORT_ENFORCE((UpsampleMode::LINEAR == mode_), + "ReizeGrad only supports the `LINEAR` mode. 
", mode, " mode is not supported yet."); + } + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu new file mode 100644 index 0000000000000..0507cda62390b --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Contents of this file are derived from the pytorch cuda implementation of +// the upsample_bilinear2d_backward implementation at: +// https://github.com/pytorch/pytorch/blob/ce50132748f652ed6079c3db8008a6817594dbae/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu + +#include "orttraining/training_ops/cuda/tensor/resize_grad_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/atomic/common.cuh" + +namespace onnxruntime::cuda { + +namespace { + +constexpr int NumThreadsPerBlock = GridDim::maxThreadsPerBlock; + +} // namespace + +__device__ __forceinline__ size_t +idx(const size_t nc, + const size_t height, + const size_t width, + const size_t h, + const size_t w) { + return (nc * height + h) * width + w; +} + +template +__device__ __forceinline__ static T AreaPixelComputeSourceIndex( + T scale, + int dst_index, + bool align_corners, + bool cubic) { + if (align_corners) { + return scale * dst_index; + } else { + T src_idx = scale * (dst_index + static_cast(0.5)) - + static_cast(0.5); + return (!cubic && src_idx < static_cast(0)) + ? static_cast(0) + : src_idx; + } +} + +template +__global__ void UpsampleGrad(const int64_t nc, const int64_t input_height, + const int64_t input_width, const int64_t output_height, + const int64_t output_width, const AccT rheight, + const AccT rwidth, const bool align_corners, + const T* dY_data, T* dX_data) { + const size_t dy_numel = nc * output_width * output_height; + const size_t dx_numel = nc * input_width * input_height; + for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; + index < dy_numel; + index += blockDim.x * gridDim.x) { + size_t index_temp = index; + const int w2 = index_temp % output_width; // 0:width2-1 + index_temp /= output_width; + const int h2 = index_temp % output_height; // 0:height2-1 + const size_t nc = index_temp / output_height; + + const AccT h1r = AreaPixelComputeSourceIndex( + rheight, h2, align_corners, /*cubic=*/false); + const int h1 = h1r; + const int h1p = (h1 < input_height - 1) ? 1 : 0; + const AccT h1lambda = h1r - h1; + const AccT h0lambda = static_cast(1) - h1lambda; + + const AccT w1r = AreaPixelComputeSourceIndex( + rwidth, w2, align_corners, /*cubic=*/false); + const int w1 = w1r; + const int w1p = (w1 < input_width - 1) ? 
1 : 0; + const AccT w1lambda = w1r - w1; + const AccT w0lambda = static_cast(1) - w1lambda; + + const T d2val = dY_data[index]; + AtomicAdd( + dX_data, + idx(nc, input_height, input_width, h1, w1), + dx_numel, + static_cast(h0lambda * w0lambda) * d2val); + AtomicAdd( + dX_data, + idx(nc, input_height, input_width, h1, w1 + w1p), + dx_numel, + static_cast(h0lambda * w1lambda) * d2val); + AtomicAdd( + dX_data, + idx(nc, input_height, input_width, h1 + h1p, w1), + dx_numel, + static_cast(h1lambda * w0lambda) * d2val); + AtomicAdd( + dX_data, + idx(nc, input_height, input_width, h1 + h1p, w1 + w1p), + dx_numel, + static_cast(h1lambda * w1lambda) * d2val); + } +} + +template +T AreaPixelComputeScale(int64_t input_size, int64_t output_size, bool align_corners, + const std::optional& scale) { + if (align_corners) { + if (output_size <= 1) { + return T{0}; + } + return static_cast(input_size - 1) / static_cast(output_size - 1); + } else { + if (scale.has_value()) { + return static_cast(T{1.0} / *scale); + } else { + return static_cast(input_size) / static_cast(output_size); + } + } +} + +template +void ResizeGradImpl(cudaStream_t stream, int64_t input_height, + int64_t input_width, int64_t output_height, + int64_t output_width, int64_t batch_size, + int64_t channels, bool align_corners, + const std::optional& scale_height, + const std::optional& scale_width, + const T* dY_data, T* dX_data) { + float rheight = AreaPixelComputeScale(input_height, output_height, align_corners, scale_height); + float rwidth = AreaPixelComputeScale(input_width, output_width, align_corners, scale_width); + + const size_t output_numel = batch_size * channels * output_height * output_width; + int blocks_per_grid = (int)(ceil(static_cast(output_numel) / NumThreadsPerBlock)); + UpsampleGrad<<>>( + batch_size * channels, input_height, input_width, output_height, output_width, + rheight, rwidth, align_corners, dY_data, dX_data); +} + +#define SPECIALIZED_RESIZEGRAD_IMPL(T) \ + template void ResizeGradImpl(cudaStream_t stream, int64_t input_height, \ + int64_t input_width, int64_t output_height, \ + int64_t output_width, int64_t batch_size, \ + int64_t channels, bool align_corners, \ + const std::optional& scale_height, \ + const std::optional& scale_width, \ + const T* dY_data, T* dX_data); + +SPECIALIZED_RESIZEGRAD_IMPL(half) +SPECIALIZED_RESIZEGRAD_IMPL(float) +SPECIALIZED_RESIZEGRAD_IMPL(double) + +#undef SPECIALIZED_RESIZEGRAD_IMPL + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h new file mode 100644 index 0000000000000..3e917f9071e30 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include + +namespace onnxruntime::cuda { + +template +void ResizeGradImpl(cudaStream_t stream, int64_t input_height, + int64_t input_width, int64_t output_height, + int64_t output_width, int64_t batch_size, + int64_t channels, bool align_corners, + const std::optional& scale_height, + const std::optional& scale_width, + const T* dY_data, T* dX_data); + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index 2321aa23dd6eb..e0749c2fb4d0d 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -187,6 +187,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad); #if defined(ORT_USE_NCCL) || defined(USE_MPI) // P2P communication operators. @@ -387,6 +390,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // P2P communication operators. 
#if defined(ORT_USE_NCCL) || defined(USE_MPI) From 020824ed509893b13aaf6fdc3e651c7a341f7273 Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Fri, 20 Oct 2023 15:08:25 -0700 Subject: [PATCH 02/24] Update ONNX to 1.15.0rc1 (#17914) --- cgmanifests/generated/cgmanifest.json | 32 ++++++++++++++++++- cmake/deps.txt | 2 +- cmake/external/onnx | 2 +- js/web/docs/webgl-operators.md | 10 ++++-- onnxruntime/test/onnx/TestCase.cc | 15 +++++++++ .../onnx_backend_test_series_filters.jsonc | 16 +++++++++- .../templates/download-deps.yml | 4 +-- 7 files changed, 73 insertions(+), 8 deletions(-) diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 08ca90d7c3b7f..f9f2fbdab7b10 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -2,6 +2,36 @@ "$schema": "https://json.schemastore.org/component-detection-manifest.json", "Version": 1, "Registrations": [ + { + "component": { + "type": "git", + "git": { + "commitHash": "a896e3d066448b3530dbcaa48869fafefd738f57", + "repositoryUrl": "https://github.com/emscripten-core/emsdk.git" + }, + "comments": "git submodule at cmake/external/emsdk" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7a2ed51a6b682a83e345ff49fc4cfd7ca47550db", + "repositoryUrl": "https://github.com/google/libprotobuf-mutator.git" + }, + "comments": "git submodule at cmake/external/libprotobuf-mutator" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "0c296085f9f65f0f8ef7aec7b9eed55faf37dc40", + "repositoryUrl": "https://github.com/onnx/onnx.git" + }, + "comments": "git submodule at cmake/external/onnx" + } + }, { "component": { "type": "git", @@ -166,7 +196,7 @@ "component": { "type": "git", "git": { - "commitHash": "fdefbe85ed9c362b95b9b401cd19db068a76141f", + "commitHash": "6a20ba82b439ea1fd650da4d389e96b60a1dd828", "repositoryUrl": "https://github.com/onnx/onnx.git" }, "comments": "onnx" diff --git a/cmake/deps.txt b/cmake/deps.txt index 7cf49f02333a4..26fd35075c4b9 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -24,7 +24,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/14303de049144035dfd94ace5f7a3b44773b1aad.zip;250eab9690392b248d75b56e605fb49eca373442 +onnx;https://github.com/onnx/onnx/archive/6a20ba82b439ea1fd650da4d389e96b60a1dd828.zip;179a22ad4cd67109c60031ae4b6cf2f434d8bd7e #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/0462dc31ae78f48744b6141ae376df1f96d3f459.zip;5ff086361956cceb81ed17453a1fd8db2aa4328d protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa diff --git a/cmake/external/onnx b/cmake/external/onnx index e2525550194ce..6a20ba82b439e 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit e2525550194ce3d8a2c4a3af451c9d9b3ae6650e +Subproject commit 6a20ba82b439ea1fd650da4d389e96b60a1dd828 diff --git a/js/web/docs/webgl-operators.md 
b/js/web/docs/webgl-operators.md index de84134ddbb3f..7c129b66bfa3d 100644 --- a/js/web/docs/webgl-operators.md +++ b/js/web/docs/webgl-operators.md @@ -12,6 +12,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Acos](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Acos) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Acos-7) | | [Acosh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Acosh) | | | [Add](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Add) | [7-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-7), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-14) | +| [AffineGrid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#AffineGrid) | | | [And](https://github.com/onnx/onnx/blob/main/docs/Operators.md#And) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#And-7) | | [ArgMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ArgMax) | | | [ArgMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ArgMin) | | @@ -67,6 +68,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Gather](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13) | | [GatherElements](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements) | | | [GatherND](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND) | | +| [Gelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gelu) | | | [Gemm](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-7), [9-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-9), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-13) | | [GlobalAveragePool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalAveragePool) | [1+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GlobalAveragePool-1) | | [GlobalLpPool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalLpPool) | | @@ -82,6 +84,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Hardmax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Hardmax) | | | [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19) | | [If](https://github.com/onnx/onnx/blob/main/docs/Operators.md#If) | | +| [ImageDecoder](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ImageDecoder) | | | [InstanceNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#InstanceNormalization) | [6+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#InstanceNormalization-6) | | [IsInf](https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf) | | | 
[IsNaN](https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsNaN) | | @@ -137,12 +140,13 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [ReduceL2](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2) | | | [ReduceLogSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceLogSum) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-18) | | [ReduceLogSumExp](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceLogSumExp) | | -| [ReduceMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMax) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-18) | +| [ReduceMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMax) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-13), [18-19](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-18), [20+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-20) | | [ReduceMean](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMean) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-18) | -| [ReduceMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMin) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-18) | +| [ReduceMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMin) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-13), [18-19](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-18), [20+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-20) | | [ReduceProd](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceProd) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-11), 
[13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-18) | | [ReduceSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-11) | | [ReduceSumSquare](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSumSquare) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-18) | +| [RegexFullMatch](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RegexFullMatch) | | | [Relu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Relu) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-6), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-14) | | [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19) | | [Resize](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize) | [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-10), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-19) | @@ -179,7 +183,9 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [SplitToSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SplitToSequence) | | | [Sqrt](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-13) | | [Squeeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-13) | +| [StringConcat](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringConcat) | | | [StringNormalizer](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringNormalizer) | | +| [StringSplit](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringSplit) | | | [Sub](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sub) | [7-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-7), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-14) | | [Sum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sum) | [6-7](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-6), 
[8-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-8), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-13) | | [Tan](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Tan) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tan-7) | diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index bc88f69fa990f..47c3798721679 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -944,6 +944,20 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {"simple_rnn_batchwise", "type error", {}}, {"mod_float_mixed_sign_example", "fmod attribute must be true for floating point types", {}}, {"col2im_pads", "result mismatch", {"opset18"}}, + {"gridsample_volumetric_nearest_align_corners_0", "result differs", {}}, + {"gridsample_volumetric_nearest_align_corners_1", "result differs", {}}, + {"reduce_l1_empty_set", "unknown version", {}}, + {"reduce_l1_empty_set_expanded", "unknown version", {}}, + {"reduce_l2_empty_set", "unknown version", {}}, + {"reduce_l2_empty_set_expanded", "unknown version", {}}, + {"reduce_log_sum_empty_set", "unknown version", {}}, + {"reduce_log_sum_empty_set_expanded", "unknown version", {}}, + {"reduce_log_sum_exp_empty_set", "unknown version", {}}, + {"reduce_log_sum_exp_empty_set_expanded", "unknown version", {}}, + {"reduce_prod_empty_set", "unknown version", {}}, + {"reduce_sum_empty_set", "unknown version", {}}, + {"reduce_sum_square_empty_set", "unknown version", {}}, + {"reduce_sum_square_empty_set_expanded", "unknown version", {}}, #ifdef ENABLE_TRAINING_CORE {"adagrad", "not a registered function/op", {}}, // Op not registered. {"adagrad_multiple", "not a registered function/op", {}}, // Op not registered. 
@@ -1339,6 +1353,7 @@ std::unique_ptr> GetBrokenTests(const std::string& provider broken_tests->insert({"gridsample_reflection_padding", "result differs"}); broken_tests->insert({"spacetodepth", "result differs"}); } + #ifdef DISABLE_CONTRIB_OPS broken_tests->insert({"coreml_SqueezeNet_ImageNet", "This model uses contrib ops."}); broken_tests->insert({"keras2coreml_Permute_ImageNet", "This model uses contrib ops."}); diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index c142106ed506c..b3161a42bb3e5 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -290,7 +290,21 @@ "^test_isnan", "^test_isnan_float16", "^test_reduce_max_bool_inputs", - "^test_reduce_min_bool_inputs" + "^test_reduce_min_bool_inputs", + "^test_reduce_min_empty_set", + "^test_reduce_l1_empty_set", + "^test_reduce_l1_empty_set_expanded", + "^test_reduce_l2_empty_set", + "^test_reduce_l2_empty_set_expanded", + "^test_reduce_log_sum_empty_set", + "^test_reduce_log_sum_empty_set_expanded", + "^test_reduce_log_sum_exp_empty_set", + "^test_reduce_log_sum_exp_empty_set_expanded", + "^test_reduce_prod_empty_set", + "^test_reduce_sum_empty_set", + "^test_reduce_sum_empty_set_non_reduced_axis_zero", + "^test_reduce_sum_square_empty_set", + "^test_reduce_sum_square_empty_set_expanded" ], "current_failing_tests_x86": [ "^test_vgg19", diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index cf73691a5eecc..9ca4a45ffcec4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.90 + version: 1.0.95 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.90 + version: 1.0.95 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. From 2f57625cb01300b538bce61ea51caffa236b4732 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Fri, 20 Oct 2023 22:09:46 +0000 Subject: [PATCH 03/24] [TensorRT EP] Add stream sync after enqueue (#18026) If the model is partitioned into TRT subgraphs and CUDA EP node, we observed cuda stream synchronization issue when multithreading. Calling stream sync API after enqueue can solve this issue without adding much performance overhead. 
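For illustration, a minimal sketch of the pattern this change adds to the TRT EP compute function (the variable names below are simplified placeholders for the state held by the EP, not the exact fields):

```cpp
// Sketch only: synchronize the CUDA stream right after the TensorRT enqueue
// when the graph is partitioned between TRT subgraphs and CUDA EP nodes.
if (!trt_context->enqueueV3(stream)) {                  // launch the TRT subgraph asynchronously
  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
}
if (sync_stream_after_enqueue) {                        // flag is set only when the graph is partitioned
  CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));  // make TRT outputs visible to downstream CUDA EP nodes
}
```

The synchronization is gated on the partitioning flag, so models that run entirely in TensorRT keep the fully asynchronous enqueue path.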
--- .../providers/tensorrt/tensorrt_execution_provider.cc | 8 +++++++- .../core/providers/tensorrt/tensorrt_execution_provider.h | 4 ++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 74d237a62f73d..d9238e41a28cc 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1869,6 +1869,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } else if (number_of_trt_nodes == number_of_ort_nodes) { LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; } else { + sync_stream_after_enqueue_ = true; LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; } @@ -2387,7 +2388,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name, &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, + input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, @@ -2415,6 +2416,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; + bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; auto fused_node_name = trt_state->fused_node_name; auto& shape_ranges = trt_state->input_shape_ranges; auto trt_builder = trt_state->builder->get(); @@ -3022,6 +3024,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector> input_info; std::vector> output_info; std::unordered_map>>> input_shape_ranges; + bool sync_stream_after_enqueue = false; OrtMutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; bool int8_enable = false; @@ -262,6 +263,9 @@ class TensorrtExecutionProvider : public IExecutionProvider { cudnnHandle_t external_cudnn_handle_ = nullptr; cublasHandle_t external_cublas_handle_ = nullptr; + // Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3() + mutable bool sync_stream_after_enqueue_ = false; + CUDAGraph cuda_graph_; bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; From 009cd4ea2e0621459806010cea7d7533d0acb39d Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Fri, 20 Oct 2023 16:12:21 -0700 Subject: [PATCH 04/24] Allow cuda custom ops allocate deferred cpu mem (#17893) Expose a new allocator from cuda stream. The allocator manages deferred cpu memory which only get recycled before stream destruction. 
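As a usage illustration, a minimal sketch of how a CUDA custom op consumes the new deferred CPU memory helpers on CudaContext (this mirrors the test kernel updated in this patch; the kernel name, tensor types, and includes are assumptions taken from that test file):

```cpp
#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT

#include "core/providers/cuda/cuda_context.h"
#include "onnxruntime_lite_custom_op.h"

// Sketch only: CPU scratch memory whose release is deferred to the owning CUDA stream.
void MyCustomKernel(const Ort::Custom::CudaContext& cuda_ctx,
                    const Ort::Custom::Tensor<float>& X,
                    Ort::Custom::Tensor<float>& Y) {
  void* staging = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t));  // served by the stream's DeferredCpuAllocator
  // ... enqueue asynchronous work on cuda_ctx.cuda_stream that reads/writes the staging buffer ...
  cuda_ctx.FreeDeferredCpuMem(staging);  // buffer is queued on the stream, not freed immediately
}
```

Because release is deferred (the buffer is only recycled before stream destruction), asynchronous work already enqueued on the same stream can still safely reference the buffer after the op returns.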
--------- Co-authored-by: Randy Shuai --- .../core/providers/cuda/cuda_context.h | 31 +++++++++++++++++++ .../core/providers/cuda/cuda_resource.h | 5 +-- .../core/providers/cuda/cuda_stream_handle.cc | 25 ++++++++++++++- .../core/providers/cuda/cuda_stream_handle.h | 10 ++++++ .../custom_op_library/cuda/cuda_ops.cc | 9 +++--- .../custom_op_library/cuda/cuda_ops.h | 10 +++++- .../custom_op_library/custom_op_library.cc | 9 +++++- .../custom_op_library/rocm/rocm_ops.cc | 8 ++--- .../custom_op_library/rocm/rocm_ops.h | 10 +++++- 9 files changed, 100 insertions(+), 17 deletions(-) diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 13c176dad3cc5..646f33ed952a4 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -19,6 +19,7 @@ struct CudaContext : public CustomOpContext { cudaStream_t cuda_stream = {}; cudnnHandle_t cudnn_handle = {}; cublasHandle_t cublas_handle = {}; + OrtAllocator* deferred_cpu_allocator = {}; void Init(const OrtKernelContext& kernel_ctx) override { const auto& ort_api = Ort::GetApi(); @@ -44,6 +45,36 @@ struct CudaContext : public CustomOpContext { ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); } cublas_handle = reinterpret_cast(resource); + + resource = {}; + status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource); + if (status) { + ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + deferred_cpu_allocator = reinterpret_cast(resource); + } + + void* AllocDeferredCpuMem(size_t size) const { + if (0 == size) { + return {}; + } + const auto& ort_api = Ort::GetApi(); + void* mem = {}; + auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem); + if (status) { + ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return mem; + } + + void FreeDeferredCpuMem(void* mem) const { + if (mem) { + const auto& ort_api = Ort::GetApi(); + auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem); + if (status) { + ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + } } }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index e46fc5b4219dd..8c3ed46ade6a1 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -3,10 +3,11 @@ #include "core/providers/resource.h" -#define ORT_CUDA_RESOUCE_VERSION 1 +#define ORT_CUDA_RESOUCE_VERSION 2 enum CudaResource : int { cuda_stream_t = cuda_resource_offset, cudnn_handle_t, - cublas_handle_t + cublas_handle_t, + deferred_cpu_allocator_t, }; \ No newline at end of file diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index e855a515f445a..5f1dbd30f6a3e 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -7,6 +7,25 @@ namespace onnxruntime { +DeferredCpuAllocator::DeferredCpuAllocator(CudaStream& cuda_stream) : cuda_stream_(cuda_stream) { + OrtAllocator::version = ORT_API_VERSION; + OrtAllocator::Alloc = + [](OrtAllocator* this_, size_t size) { + auto self = reinterpret_cast(this_); + return 
self->cuda_stream_.GetCpuAllocator()->Alloc(size); + }; + OrtAllocator::Free = + [](OrtAllocator* this_, void* p) { + auto self = reinterpret_cast(this_); + self->cuda_stream_.EnqueDeferredCPUBuffer(p); + }; + OrtAllocator::Info = + [](const OrtAllocator* this_) { + auto self = reinterpret_cast(this_); + return &self->cuda_stream_.GetCpuAllocator()->Info(); + }; +} + struct CudaNotification : public synchronize::Notification { CudaNotification(Stream& s) : Notification(s) { CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); @@ -46,7 +65,8 @@ CudaStream::CudaStream(cudaStream_t stream, cublasHandle_t external_cublas_handle) : Stream(stream, device), own_stream_(own_flag), cpu_allocator_(cpu_allocator), - release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream) { + release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream), + deferred_cpu_allocator_(*this) { if (own_flag) { CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_)); CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream)); @@ -162,6 +182,9 @@ void* CudaStream::GetResource(int version, int id) const { case CudaResource::cublas_handle_t: return reinterpret_cast(cublas_handle_); break; + case CudaResource::deferred_cpu_allocator_t: + return const_cast(&deferred_cpu_allocator_); + break; default: break; } diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h b/onnxruntime/core/providers/cuda/cuda_stream_handle.h index 9c62b029b7a36..917702fae08f1 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h @@ -9,6 +9,13 @@ namespace onnxruntime { +struct CudaStream; + +struct DeferredCpuAllocator : public OrtAllocator { + DeferredCpuAllocator(CudaStream&); + CudaStream& cuda_stream_; +}; + struct CudaStream : Stream { CudaStream(cudaStream_t stream, const OrtDevice& device, @@ -36,10 +43,13 @@ struct CudaStream : Stream { void* GetResource(int version, int id) const override; + onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); } + private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; bool release_cpu_buffer_on_cuda_stream_{true}; + DeferredCpuAllocator deferred_cpu_allocator_; }; void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, diff --git a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc index aba35b33b75c6..3d561d378cb8c 100644 --- a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#ifdef USE_CUDA +#if defined(USE_CUDA) && !defined(ENABLE_TRAINING) #define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" @@ -32,6 +32,9 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx, CUSTOM_ENFORCE(cuda_ctx.cuda_stream, "failed to fetch cuda stream"); CUSTOM_ENFORCE(cuda_ctx.cudnn_handle, "failed to fetch cudnn handle"); CUSTOM_ENFORCE(cuda_ctx.cublas_handle, "failed to fetch cublas handle"); + void* deferred_cpu_mem = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t)); + CUSTOM_ENFORCE(deferred_cpu_mem, "failed to allocate deferred cpu allocator"); + cuda_ctx.FreeDeferredCpuMem(deferred_cpu_mem); auto z_raw = Z.Allocate(input_shape); cuda_add(Z.NumberOfElement(), z_raw, X.Data(), Y.Data(), cuda_ctx.cuda_stream); } @@ -43,8 +46,4 @@ void RegisterOps(Ort::CustomOpDomain& domain) { } // namespace Cuda -#else - -void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {} - #endif \ No newline at end of file diff --git a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h index c0287c4932c98..35cd36fcd4cb7 100644 --- a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h +++ b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h @@ -5,6 +5,14 @@ namespace Cuda { +#if defined(USE_CUDA) && !defined(ENABLE_TRAINING) + void RegisterOps(Ort::CustomOpDomain& domain); -} \ No newline at end of file +#else + +void RegisterOps(Ort::CustomOpDomain&) {} + +#endif + +} // namespace Cuda \ No newline at end of file diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc index 40fb127eb0b8f..2d5ffc3c81b0f 100644 --- a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +++ b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc @@ -13,6 +13,8 @@ #include "core/framework/ortdevice.h" #include "core/framework/ortmemoryinfo.h" #include "cpu/cpu_ops.h" +#include "cuda/cuda_ops.h" +#include "rocm/rocm_ops.h" #include "onnxruntime_lite_custom_op.h" static const char* c_OpDomain = "test.customop"; @@ -31,10 +33,15 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA ORT_TRY { Ort::CustomOpDomain domain{c_OpDomain}; Cpu::RegisterOps(domain); - Ort::CustomOpDomain domain_v2{"v2"}; Cpu::RegisterOps(domain_v2); + Cuda::RegisterOps(domain); + Cuda::RegisterOps(domain_v2); + + Rocm::RegisterOps(domain); + Rocm::RegisterOps(domain_v2); + Ort::UnownedSessionOptions session_options(options); session_options.Add(domain); session_options.Add(domain_v2); diff --git a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc index 113bfb85454a2..069246b4201e7 100644 --- a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc @@ -19,7 +19,7 @@ using namespace Ort::Custom; throw std::runtime_error(msg); \ } -namespace Cuda { +namespace Rocm { void KernelOne(const Ort::Custom::RocmContext& rocm_ctx, const Ort::Custom::Tensor& X, @@ -38,10 +38,6 @@ void RegisterOps(Ort::CustomOpDomain& domain) { domain.Add(c_CustomOpOne.get()); } -} // namespace Cuda - -#else - -void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {} +} // namespace Rocm #endif \ No newline at end of file diff --git a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h index 4e8958cd9dae0..d3e9e4040a5c3 100644 --- 
a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h +++ b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h @@ -5,6 +5,14 @@ namespace Rocm { +#ifdef USE_ROCM + void RegisterOps(Ort::CustomOpDomain& domain); -} \ No newline at end of file +#else + +inline void RegisterOps(Ort::CustomOpDomain&) {} + +#endif + +} // namespace Rocm From 444a0eda309e0fadf51c63790b6da78258f96a10 Mon Sep 17 00:00:00 2001 From: pengwa Date: Sat, 21 Oct 2023 19:45:45 +0800 Subject: [PATCH 05/24] Avoid one time clone to save memory peak (#17934) ### Avoid one more time clone to save memory peak --- .../_custom_autograd_function_runner.py | 55 +++++++++++-------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index b9318033a3d53..dd32e2aced561 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -245,6 +245,8 @@ def _process_inplace_outputs( if not copied: # Only need a copy once. + # Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False. + raw_input_tensor.requires_grad = False raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index]) _log_warning( f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}. " @@ -449,7 +451,8 @@ def call_python_forward_function( try: func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name # If this is the first time run, collect runtime tensor reuse mapping. - if kernel_invoke_id not in _GlobalOpKernelInfoMap: + is_first_time_run = kernel_invoke_id not in _GlobalOpKernelInfoMap + if is_first_time_run: kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id) _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info @@ -473,6 +476,11 @@ def call_python_forward_function( if tensor_input_index in inplace_map: raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg + # Only requires gradient when running under training mode + # and the associated tensor has grad_flag=True (i.e., + # "requires_grad=True" in the original PyTorch script). + wrapped_arg.requires_grad = is_training_mode and grad_flag + # Note1: # If it's first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, we do the # copy for all tensors. Otherwise, we only copy the tensors whose indices are in @@ -480,29 +488,30 @@ def call_python_forward_function( # Note2: # For inference mode, we don't need to do the copy because ctx will be None, # so nothing will be saved for ctx. - if is_training_mode and ( - tensor_input_indices_to_save_in_ctx is None - or tensor_input_index in tensor_input_indices_to_save_in_ctx - ): - wrapped_arg = wrapped_arg.detach().clone() - - # Only requires gradient when running under training mode - # and the associated tensor has grad_flag=True (i.e., - # "requires_grad=True" in the original PyTorch script). - wrapped_arg.requires_grad = is_training_mode and grad_flag - # Note3: - # If it's not first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the - # mul for all tensors. Otherwise, we only mul by one for the tensors whose indices are in - # tensor_input_indices_for_mark_dirty. 
- if is_training_mode and ( - tensor_input_indices_for_mark_dirty is None - or tensor_input_index in tensor_input_indices_for_mark_dirty - ): - # To fix this issue: - # "a leaf Variable that requires grad has been used in an in-place operation." - with torch.set_grad_enabled(True): - wrapped_arg = wrapped_arg.clone() + # To fix this issue: + # "a leaf Variable that requires grad has been used in an in-place operation." + # If it's first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the + # copy for all tensors to generate grad for it. Otherwise, we only clone (to generate grad) for + # the tensors whose indices are in tensor_input_indices_for_mark_dirty. + if is_training_mode: + if is_first_time_run: + with torch.set_grad_enabled(True): + wrapped_arg = wrapped_arg.clone() + else: + is_input_index_saved_in_ctx = ( + tensor_input_indices_to_save_in_ctx is None + or tensor_input_index in tensor_input_indices_to_save_in_ctx + ) + is_input_index_marked_dirty = ( + tensor_input_indices_for_mark_dirty is None + or tensor_input_index in tensor_input_indices_for_mark_dirty + ) + if is_input_index_saved_in_ctx or is_input_index_marked_dirty: + # when with grad, the leaf tensor after clone will not be leaf. + with torch.set_grad_enabled(is_input_index_marked_dirty): + wrapped_arg = wrapped_arg.clone() + wrapped_arg.requires_grad = is_training_mode and grad_flag wrapped_args.append(wrapped_arg) input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg From b7ae293be05c89a0cb623feec4d2d2cbf006e4c3 Mon Sep 17 00:00:00 2001 From: JiCheng Date: Sun, 22 Oct 2023 23:33:29 +0800 Subject: [PATCH 06/24] Support large model export using multi-gpu (#17990) ### Description This PR implements an exporter that works for large language models (LLMs). It works for models like Llama2-70b or gpt-175. The main idea is to utilize multiple GPUs and dispatch different layers to different GPUs; in short, it simply implements automatic pipeline parallelism. For example, to export Llama2-70b, you need 8x V100-32GB or 4x A100-80G, or more GPU memory. It is expected to export decoder-only models; encoder-decoder architectures have not been tested yet. ### Motivation and Context --------- Co-authored-by: Justin Chu --- .../transformers/large_model_exporter.py | 385 ++++++++++++++++++ 1 file changed, 385 insertions(+) create mode 100644 onnxruntime/python/tools/transformers/large_model_exporter.py diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py new file mode 100644 index 0000000000000..3b344d6dc9342 --- /dev/null +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -0,0 +1,385 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License.
+# -------------------------------------------------------------------------- + +""" +Export LLM to onnx +""" +import argparse +import inspect +import math +import os +import tempfile +from pathlib import Path +from typing import Optional + +import onnx +import torch +import transformers +from torch import nn + + +def disable_huggingface_init(): + """do not init model twice as it slow initialization""" + + torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.normal_ = lambda x, *args, **kwargs: x + torch.nn.init.constant_ = lambda x, *args, **kwargs: x + torch.nn.init.xavier_uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.xavier_normal_ = lambda x, *args, **kwargs: x + torch.nn.init.kaiming_normal_ = lambda x, *args, **kwargs: x + torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x + + +def get_model_parameter_size(model: nn.Module): + """to calculate how much memory this model needs""" + param_size = 0 + param_sum = 0 + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + param_sum += param.nelement() + buffer_size = 0 + buffer_sum = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + buffer_sum += buffer.nelement() + all_size = (param_size + buffer_size) / 1024 / 1024 + return all_size + + +def initialize_model_and_sample_inputs(hf_model: str, cache_dir: Optional[str], tokenizer=None): + """ + get the pretrained torch model from hugginface, + and sample model-inputs + """ + + disable_huggingface_init() + + model = transformers.AutoModelForCausalLM.from_pretrained( # type: ignore + hf_model, torch_dtype=torch.float16, cache_dir=cache_dir, trust_remote_code=True + ) + if tokenizer is None: + tokenizer = hf_model + tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) # type: ignore + + sample_inputs = tuple(tokenizer("Hello, my dog is cute", return_tensors="pt").values()) + return model, sample_inputs + + +def auto_pipeline_parallel(model: nn.Module, gpulist: list, sample_inputs: tuple): + """Make the model executable across multiple GPUs.""" + + def input_gpu_device_hook(mod, inputs, kwargs): + modifyed_inputs = [] + first_dev = None + for layer_input in inputs: + if type(layer_input) is not torch.Tensor: + modifyed_inputs.append(layer_input) + elif hasattr(mod, "weight"): + modifyed_inputs.append(layer_input.to(mod.weight.device)) + elif hasattr(mod, "parameters"): + device = next(mod.parameters(), layer_input).device + modifyed_inputs.append(layer_input.to(device)) + elif hasattr(next(mod.children(), None), "weight"): + modifyed_inputs.append(layer_input.to(next(mod.children()).weight.device)) + elif first_dev is not None and layer_input.device != first_dev: + modifyed_inputs.append(layer_input.to(first_dev)) + else: + modifyed_inputs.append(layer_input) + if first_dev is None: + first_dev = modifyed_inputs[0].device + for key, value in kwargs.items(): + if type(value) is torch.Tensor: + kwargs[key] = value.to(first_dev) + + return (tuple(modifyed_inputs), kwargs) + + def move_layer_to_device_rurc(mod, dev): + mod.to(dev) + for layer in mod.named_children(): + move_layer_to_device_rurc(layer[1], dev) + + model = model.half() + all_hooks = [] + all_hooks.append(model.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + pre_fix = next(iter(model.named_children()))[0] + for top_name, top_module in model.named_children(): + for name, module in top_module.named_children(): + 
all_hooks.append(module.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + if type(module) in [torch.nn.ModuleList]: + num_layers_on_each_gpu = math.floor(len(module) / len(gpulist)) + for idx, attn_layer in enumerate(module): + all_hooks.append(attn_layer.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + + to_dev = gpulist[min(idx // num_layers_on_each_gpu, len(gpulist))] + attn_layer.to(to_dev) + move_layer_to_device_rurc(attn_layer, to_dev) + print(f"move {pre_fix}.{name}.{idx} to {to_dev}") + else: + module.to(gpulist[0]) + print(f"move {pre_fix}.{name} to {gpulist[0]}") + if len(list(top_module.named_children())) == 0: + top_module.to(gpulist[0]) + print(f"move {top_name} to {gpulist[0]}") + + with torch.no_grad(): + model(sample_inputs[0], attention_mask=sample_inputs[1]) + return model + + +def retrieve_onnx_inputs(model: nn.Module, sample_inputs: tuple, with_past: bool): + """ + auto retrieve onnx inputs from torch model as we can't enumlate all possibilities + for all models + """ + user_inputs = [] + + def hook_for_inputs(_, inputs, kwargs): + user_inputs.append((inputs, kwargs)) + return user_inputs[0] + + hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True) + + forward_params = inspect.signature(model.forward).parameters + input_keys = list(forward_params.keys()) + default_values = [forward_params.get(key).default for key in input_keys] + out = model(sample_inputs[0], attention_mask=sample_inputs[1]) + hook_handle.remove() + user_inputs = user_inputs[0] + onnx_inputs = default_values + for idx, _val in enumerate(user_inputs[0]): + onnx_inputs[idx] = user_inputs[0][idx] + for key, value in user_inputs[1].items(): + idx = input_keys.index(key) + onnx_inputs[idx] = value + for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)): + if type(value) is torch.Tensor: + value.to(model.device) + # Didn't touch past_key_value now, please change it if you want + if "use_cache" in key: + onnx_inputs[idx] = with_past + + return input_keys, onnx_inputs, out.past_key_values + + +def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module: + """ + According to the model size, we will upload it to + CPU if has no GPU or enough GPU memory, + Single GPU if has only one GPU in local or model size is enough to fit one GPU + Multiple GPU if there is more than one gpu in local and model is too large + """ + total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 + + print(f"Model_Size = {get_model_parameter_size(model)/1024} GB") + print(f"total_mem_per_cpu = {total_mem_per_cpu/1024} GB") + if get_model_parameter_size(model) > total_mem_per_cpu * 0.45: + device_collection = [torch.device(i) for i in range(torch.cuda.device_count())] + if len(device_collection) > 1: + print( + f"{len(device_collection)} GPUs are used to export onnx, \ + Please set CUDA_VISIBLE_DEVICES to use specific GPU group" + ) + model = auto_pipeline_parallel(model, device_collection, sample_inputs_tp) + else: + print("!!!! 
convert model to float and export onnx using CPU") + model = model.cpu().float() + else: + print("Export model on a single GPU") + model = model.cuda().half() + return model + + +def adapt_inputs_to_device(sample_inputs: tuple, device: torch.device) -> tuple: + """move inputs to device""" + sample_inputs_ = [] + for sample_int in sample_inputs: + if isinstance(sample_int, torch.Tensor): + sample_inputs_.append(sample_int.to(device)) + else: + sample_inputs_.append(sample_int) + return tuple(sample_inputs_) + + +def fetch_onnx_inputs_outputs_name( + model: nn.Module, + onnx_inputs: list, + torch_input_names: tuple, + past_key_values: tuple, + with_past: bool, + input_with_past: bool, +): + """fetch onnx inputs and outputs name""" + num_of_past_key = 0 + kv_cache_axis = {0: "batch_size"} + # try get num_of_past_key and shape of past_key_value + if past_key_values is not None: + num_of_past_key = len(past_key_values) + seq_index = (torch.tensor(past_key_values[0][0].shape) == onnx_inputs[0].shape[-1]).nonzero().view(-1) + assert seq_index.numel() == 1 + kv_cache_axis = {0: "batch_size", seq_index.item(): "seq_len"} + + if not num_of_past_key: + num_of_past_key = model.config.num_hidden_layers + + onnx_inp_names = ("input_ids", "attention_mask") + onnx_out_names = ("logits",) + onnx_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "attention_mask": {0: "batch_size", 1: "seq_len"}, + } + if input_with_past: + for i in range(num_of_past_key): + onnx_inp_names += (f"present_key.{i}",) + onnx_inp_names += (f"present_values.{i}",) + + onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis + onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis + + if with_past or input_with_past: + for i in range(num_of_past_key): + onnx_out_names += (f"past_key.{i}",) + onnx_out_names += (f"past_values.{i}",) + onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis + onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis + + for idx, name in enumerate(torch_input_names): + if input_with_past: + if name == "past_key_values": + onnx_inputs[idx] = past_key_values + elif name == "attention_mask": + attn_mask = onnx_inputs[idx] + onnx_inputs[idx] = torch.cat( + (attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device)), dim=1 + ) + elif name == "input_ids": + input_ids = onnx_inputs[idx] + onnx_inputs[idx] = input_ids[:, -1:] + + return onnx_inp_names, onnx_out_names, onnx_dynamic_axes + + +def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tuple, onnx_path: Path, opset: int): + """do export with torch.onnx.export""" + onnx_model_name = onnx_path.name + onnx_inp_names, onnx_out_names, onnx_dynamic_axes = onnx_io_tuple + # two step to export onnx + # 1. export onnx with lots of pieces of weights + # 2. 
save all weights to external data + with tempfile.TemporaryDirectory() as tmpdirname: + tmp_onnx = os.path.join(tmpdirname, "tmp.onnx") + + torch.onnx.export( + model=model, + args=tuple(onnx_inputs), + f=tmp_onnx, + verbose=False, + opset_version=opset, + input_names=onnx_inp_names, + output_names=onnx_out_names, + dynamic_axes=onnx_dynamic_axes, + ) + + onnx_path.unlink(missing_ok=True) + (onnx_path.parent / f"{onnx_model_name}_ext.data").unlink(missing_ok=True) + + onnx_model = onnx.load(str(tmp_onnx)) + onnx.save_model( + onnx_model, + str(onnx_path), + save_as_external_data=(len(os.listdir(tmpdirname)) > 1), + all_tensors_to_one_file=True, + location=f"{onnx_model_name}_ext.data", + size_threshold=1024, + convert_attribute=False, + ) + + +@torch.no_grad() +def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int): + """ + do export + model: torch model + onnx_path: where the onnx model saved to + sample_inputs_tp: inputs for torch model + """ + model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir) + + model = move_to_approprate_device(model, sample_inputs_tp) + + sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device) + + # input_keys would be usesful if the model has some special inputs + input_keys, onnx_inputs, past_key_value = retrieve_onnx_inputs(model, sample_inputs, with_past) + + onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, False) + + onnx_model_name = "model.onnx" + onnx_path: Path = Path(onnx_path_str).absolute() + if onnx_path.suffix != ".onnx": + onnx_path = onnx_path / onnx_model_name + + do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset) + if not with_past: + return + + onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, True) + + onnx_model_name = "model_with_past.onnx" + onnx_path = onnx_path.parent / onnx_model_name + + do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset) + + +def parse_arguments(): + """arguments parsing.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--model", + required=True, + type=str, + default=["meta-llama/Llama-2-70b-hf"], + help="Pre-trained models in huggingface model hub", + ) + parser.add_argument( + "-s", + "--saved_path", + required=False, + type=str, + default="./onnx_models/", + help="where the onnx model will be saved", + ) + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default=None, + help=("cache directy of huggingface, by setting this to avoid useless downloading if you have one"), + ) + parser.add_argument( + "--with_past", + action="store_true", + default=False, + help=("The tool will export onnx without past-key-value by default"), + ) + parser.add_argument( + "--opset", + required=False, + type=int, + default=17, + help=( + "the opset to save onnx model, \ + try to increase it if this opset doens't have new features you want" + ), + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + + export_onnx(args.model, args.cache_dir, args.saved_path, args.with_past, args.opset) From f0d5ea5930bee7e87c1d93e14d4d28c8af3c8cde Mon Sep 17 00:00:00 2001 From: Hector Li Date: Mon, 23 Oct 2023 09:01:29 -0700 Subject: [PATCH 07/24] [QNN EP] Disable flaky test QnnCPUBackendTests.MatMulOp_Broadcast (#18033) Disable flaky test QnnCPUBackendTests.MatMulOp_Broadcast. 
The test failed on Linux randomly. --- onnxruntime/test/providers/qnn/matmul_test.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index e721ccbcb45a9..3073dde9d8e4c 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -112,12 +112,13 @@ TEST_F(QnnCPUBackendTests, MatMulOp) { } // Test MatMul broadcasting -// Note slight inaccuracy in CPU backend: +// Failed randomly on Linux +// Value of: expected_tensor.DataAsSpan() // Expected: contains 896 values, where each value and its corresponding value in 16-byte object -// <80-03 00-00 00-00 00-00 40-00 34-DD F7-01 00-00> are an almost-equal pair -// Actual: 16-byte object <80-03 00-00 00-00 00-00 40-00 23-DD F7-01 00-00>, -// where the value pair (73.68116, 73.680809) at index #80 don't match, which is -0.000350952 from 73.6812 -TEST_F(QnnCPUBackendTests, MatMulOp_Broadcast) { +// <80-03 00-00 00-00 00-00 40-B8 53-08 CC-7F 00-00> are an almost-equal pair +// Actual: 16-byte object <80-03 00-00 00-00 00-00 C0-B7 43-08 CC-7F 00-00>, where the value pair +// (-5.19657087, 0) at index #29 don't match, which is 5.19657 from -5.19657 +TEST_F(QnnCPUBackendTests, DISABLED_MatMulOp_Broadcast) { // Create two matrices with element values in the range [-10.0, 10.0]. std::vector input_a = GetFloatDataInRange(-10.0f, 10.0f, 28 * 64); std::vector input_b = GetFloatDataInRange(-10.0f, 10.0f, 64 * 32); From 8a12b2cea6c80f312045f4ac74b818cb5b53fa35 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 24 Oct 2023 02:02:19 +0800 Subject: [PATCH 08/24] [js/webgpu] Fix the transpose error when dims > 4D (#18027) ### Description Currently, the uniform support has bugs when dims rank is larger than 4. See https://github.com/microsoft/onnxruntime/issues/17860 item 1. So this PR only enables shapes uniforms when shape rank is <= 4 for transpose. Otherwise, below compilation errors are thrown: ``` 1 error(s) generated while compiling the shader: :3:50 error: uniform storage requires that array elements are aligned to 16 bytes, but array element of type 'u32' has a stride of 4 bytes. Consider using a vector or struct as the element type instead. 
struct Uniforms { output_size:u32, a_shape:array, a_strides:array, output_shape:array, output_strides:array }; ^^^^^^^^^^^^^ :3:7 note: see layout of struct: /* align(4) size(84) */ struct Uniforms { /* offset( 0) align(4) size( 4) */ output_size : u32; /* offset( 4) align(4) size(20) */ a_shape : array; /* offset(24) align(4) size(20) */ a_strides : array; /* offset(44) align(4) size(20) */ output_shape : array; /* offset(64) align(4) size(20) */ output_strides : array; /* */ }; struct Uniforms { output_size:u32, a_shape:array, a_strides:array, output_shape:array, output_strides:array }; ^^^^^^ :4:42 note: 'Uniforms' used in address space 'uniform' here @group(0) @binding(2) var uniforms: Uniforms; ^^^^^^^^ ``` --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 3 ++ .../wasm/jsep/webgpu/ops/conv-transpose.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 51 +++++++++++-------- js/web/test/data/ops/transpose.jsonc | 24 +++++++++ 5 files changed, 59 insertions(+), 25 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 0a64d1ad1792a..1d3fc78fe368a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -803,3 +803,6 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly } return dims; }; + +// TODO: remove this limitation once >4D dims are supported by uniform. +export const enableShapesUniforms = (rank: number): boolean => rank <= 4; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index d241b8b92a669..e880afe09a5d8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -232,7 +232,7 @@ const convTranspose2d = // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposePerm), + createTransposeProgramInfo(inputs[1], weightTransposePerm), {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index b323a36cee5c8..c7ea0cffe51c3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -168,7 +168,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut if (isChannelsLast) { const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute), + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; @@ -208,7 +208,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute), + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {inputs: [1], outputs: [attributes.wIsConst ? 
-2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index fe556a7fd8552..c4d43e9f466f5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface TransposeAttributes extends AttributeWithCacheKey { readonly perm: number[]; @@ -35,13 +35,18 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou return reverseFunc.join('\n'); }; -export const createTransposeProgramInfo = - (inputDataType: number, inputRank: number, permAttr: number[]): ProgramInfo => { - const perm = getAdjustedPerm(inputRank, permAttr); - const output = outputVariable('output', inputDataType, (permAttr && permAttr.length) || inputRank); - const input = inputVariable('a', inputDataType, inputRank); +export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => { + const inputDataType = inputTensor.dataType; + const inputRank = inputTensor.dims.length; + const perm = getAdjustedPerm(inputRank, permAttr); + const useShapesUniforms = enableShapesUniforms(inputRank); + const outputShape = getOutputShape(inputTensor.dims, perm); + const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape; + const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims; + const output = outputVariable('output', inputDataType, outShapeOrRank); + const input = inputVariable('a', inputDataType, inShapeOrRank); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} ${permFunctionBody(perm, inputRank, input, output)} @@ -54,30 +59,32 @@ export const createTransposeProgramInfo = ${output.setByOffset('global_idx', input.getByIndices('aIndices'))} }`; + return { + name: 'Transpose', + shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']}, + getRunData: (inputs) => { + const outputSize = ShapeUtil.size(outputShape); return { - name: 'Transpose', - shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']}, - getRunData: (inputs) => { - const outputShape = getOutputShape(inputs[0].dims, perm); - const outputSize = ShapeUtil.size(outputShape); - return { - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: useShapesUniforms ? 
+ [ {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape), + ] : + [ + {type: 'uint32', data: outputSize}, ], - }; - }, - getShaderSource, }; - }; + }, + getShaderSource, + }; +}; export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => { validateInputs(context.inputs); - context.compute( - createTransposeProgramInfo(context.inputs[0].dataType, context.inputs[0].dims.length, attributes.perm)); + context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm)); }; export const parseTransposeAttributes = (attributes: Record): TransposeAttributes => diff --git a/js/web/test/data/ops/transpose.jsonc b/js/web/test/data/ops/transpose.jsonc index 285d14018e74d..e1edfa7e41513 100644 --- a/js/web/test/data/ops/transpose.jsonc +++ b/js/web/test/data/ops/transpose.jsonc @@ -166,5 +166,29 @@ ] } ] + }, + { + "name": "Transpose 5D - perms:[4, 3, 1, 0, 2]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [4, 3, 1, 0, 2], "type": "ints" }], + "cases": [ + { + "name": "T[3, 1, 2, 1, 4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [3, 1, 2, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24], + "dims": [4, 1, 1, 3, 2], + "type": "float32" + } + ] + } + ] } ] From 2a17d5cf32900fa0100959eace6e303c82f86bdc Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Mon, 23 Oct 2023 13:00:56 -0700 Subject: [PATCH 09/24] LLaMA Model Optimization (#18021) ### Description This PR contains fusion-level and kernel-level optimizations for [Meta's LLaMA-2](https://blogs.microsoft.com/blog/2023/07/18/microsoft-and-meta-expand-their-ai-partnership-with-llama-2-on-azure-and-windows/). 
Some of the added optimizations include: - SimplifiedLayerNorm changes - Fusions for multiple variants - SkipSimplifiedLayerNorm changes - Kernel support for CPU - Rotary embeddings (previously did not exist) - Fusions for multiple variants - CPU and CUDA kernels - Supports interleaving and non-interleaving in the same kernels - Optimized cache that requires half of its originally exported sizes - Reduced from `(max_sequence_length, head_size)` to `(max_sequence_length, head_size / 2)` - Multi-head attention - Support for 2D and 3D attention masks - Group query attention (for FP16 CUDA and INT4 CUDA) - Integration with flash attention v2 and past-present buffer sharing - Removes need for `attention_mask` input as it is supported in the kernel - 4 bit quantization - `block_size` parameter is available for customizing - Support the new changes for [Microsoft version](https://github.com/microsoft/Llama-2-Onnx) - Support combinations of the below variants (ex: export ORT version and run with Optimum) Supported variants of LLaMA-2 include: - [ORT version](https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/llama) - Produces one ONNX file that is already optimized (and quantized if requested) - Integrates with Optimum - [Another Microsoft version](https://github.com/microsoft/Llama-2-Onnx) - Already exported and available off-the-shelf - Faster versions of those models will be uploaded there soon - [Hugging Face version](https://huggingface.co/meta-llama) - Models that end with `-hf` - Some older and current versions of [`transformers`](https://github.com/huggingface/transformers) and [`optimum`](https://github.com/huggingface/optimum) that export the model to ONNX differently - Note that while some older versions are supported, it is recommended to use the latest package versions. ### Usage To use the optimizations, please see `README.md` for details. Please note the various `requirements.txt` files for the package versions recommended in order to use these changes. To run the ORT transformer optimizer separately, run the script as follows: ``` $ cd onnxruntime/onnxruntime/python/tools/transformers/ $ python3 optimizer.py --input .onnx --output .onnx --model_type gpt2 --num_heads --hidden_size