From 727770870a0bad5bf62f750d6c2252ed78ff0128 Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Wed, 20 Dec 2023 23:59:54 +0000
Subject: [PATCH 01/17] first try

---
 onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc   | 6 ++++++
 onnxruntime/core/optimizer/layer_norm_fusion.cc        | 2 +-
 .../training_ops/cuda/cuda_training_kernels.cc         | 8 ++++++++
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 7875ac75b8188..47f02ad27af41 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -118,6 +118,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_BFloat16, SimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_float, SimplifiedLayerNormalization);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits);
@@ -314,6 +317,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_BFloat16, SimplifiedLayerNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_float, SimplifiedLayerNormalization)>,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,

diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc
index 159e3b23d1ab0..b6ad4fde6c1f7 100644
--- a/onnxruntime/core/optimizer/layer_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -13,7 +13,7 @@ using namespace onnxruntime::common;
 namespace onnxruntime {
 
 // LayerNorm supports limited data types.
-static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)"};
+static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"};
 
 // Default epsilon
 static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f;

diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
index dcf733153bdad..1eb074a119a74 100644
--- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
@@ -154,6 +154,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalizationGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_BFloat16, SimplifiedLayerNormalizationGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_float, SimplifiedLayerNormalizationGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_float, InvertibleLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double_double, InvertibleLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16, InvertibleLayerNormalizationGrad);
@@ -196,6 +199,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MixedPrecisionScale);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, LayerNormalizationGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2);
@@ -410,6 +414,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_BFloat16, SimplifiedLayerNormalizationGrad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_float, SimplifiedLayerNormalizationGrad)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad)>,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
@@ -452,6 +459,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad)>,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
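Background for readers outside this code: SimplifiedLayerNormalization is ONNX Runtime's RMS-style layer normalization (no mean subtraction, no bias), and the type triples in the kernel class names above appear to encode the element types of the input (T), the accumulator / saved inv_std_var (U), and the scale/output (V). A minimal host-side sketch of the math, assuming row-major data normalized over the last axis; the function name is made up for illustration and this is not the CUDA kernel:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// RMS-style normalization: y = x / sqrt(mean(x^2) + epsilon) * scale.
std::vector<float> SimplifiedLayerNormRef(const std::vector<float>& x,
                                          const std::vector<float>& scale,
                                          float epsilon = 1e-5f) {
  const size_t n = scale.size();  // length of the normalized (last) axis
  std::vector<float> y(x.size());
  for (size_t row = 0; row < x.size() / n; ++row) {
    float sum_sq = 0.0f;  // accumulate in float, like the *_float_* kernels
    for (size_t i = 0; i < n; ++i) {
      const float v = x[row * n + i];
      sum_sq += v * v;
    }
    const float inv_std_var = 1.0f / std::sqrt(sum_sq / n + epsilon);
    for (size_t i = 0; i < n; ++i) {
      y[row * n + i] = x[row * n + i] * inv_std_var * scale[i];
    }
  }
  return y;
}
```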
From b52c67ff05a2ee36f5c0904ec246009ee87737a9 Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Thu, 21 Dec 2023 20:30:03 +0000
Subject: [PATCH 02/17] add unit test

---
 onnxruntime/test/common/tensor_op_test_utils.h     |  9 +++++++++
 onnxruntime/test/contrib_ops/layer_norm_op_test.cc | 11 +++++++++++
 2 files changed, 20 insertions(+)

diff --git a/onnxruntime/test/common/tensor_op_test_utils.h b/onnxruntime/test/common/tensor_op_test_utils.h
index 6917aa15777a2..277fdd71f1d85 100644
--- a/onnxruntime/test/common/tensor_op_test_utils.h
+++ b/onnxruntime/test/common/tensor_op_test_utils.h
@@ -156,6 +156,15 @@ inline std::vector ToFloat16(const std::vector& data) {
   return result;
 }
 
+inline std::vector<BFloat16> ToBFloat16(const std::vector<float>& data) {
+  std::vector<BFloat16> result;
+  result.reserve(data.size());
+  for (size_t i = 0; i < data.size(); i++) {
+    result.push_back(BFloat16(data[i]));
+  }
+  return result;
+}
+
 inline void CheckTensor(const Tensor& expected_tensor, const Tensor& output_tensor,
                         double rtol, double atol) {
   ORT_ENFORCE(expected_tensor.Shape() == output_tensor.Shape(),
               "Expected output shape [" + expected_tensor.Shape().ToString() +

diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
index 84bbee35eed5a..edbe94904da30 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
@@ -75,6 +75,17 @@ TEST(LayerNormTest, LayerNorm) {
   test.Run();
 }
 
+TEST(LayerNormTest, LayerNorm_BFloat16Input) {
+  OpTester test("LayerNormalization");
+  test.AddAttribute("epsilon", 1e-05f);
+
+  std::vector<int64_t> dims{1, 2, 3};
+  test.AddInput<BFloat16>("x", dims, ToBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
+  test.AddInput<float>("gamma", {3}, {1.0f, 1.0f, 1.0f});
+  test.AddOutput<float>("output", dims, {-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f});
+  test.Run();
+}
+
 TEST(LayerNormTest, LayerNorm_Scale) {
   OpTester test("LayerNormalization");
   test.AddAttribute("epsilon", 1e-05f);
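The expected values in the new test follow directly from the layer-norm formula. Each length-3 row of the input ({1,2,3} and {4,5,6}) has population variance 2/3, so every normalized row is (x - mean)/sqrt(2/3 + eps), which is approximately {-1.2247, 0, +1.2247}. A standalone check of that arithmetic:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const float row[3] = {1.0f, 2.0f, 3.0f};  // the {4, 5, 6} row normalizes identically
  const float eps = 1e-05f;                 // matches the test's epsilon attribute
  const float mean = (row[0] + row[1] + row[2]) / 3.0f;  // 2.0
  float var = 0.0f;
  for (float x : row) var += (x - mean) * (x - mean);
  var /= 3.0f;  // population variance = 2/3
  for (float x : row) {
    std::printf("%+.4f ", (x - mean) / std::sqrt(var + eps));
  }
  std::printf("\n");  // prints: -1.2247 +0.0000 +1.2247
  return 0;
}
```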
From 472c0dc962e38dd589c6d5dfbb392a29f7e2fffd Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Fri, 22 Dec 2023 17:06:18 +0000
Subject: [PATCH 03/17] remove redundant declarations and fix bug in unit test

---
 onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc | 4 ----
 onnxruntime/test/contrib_ops/layer_norm_op_test.cc   | 4 ++--
 .../training_ops/cuda/cuda_training_kernels.cc       | 4 ----
 3 files changed, 2 insertions(+), 10 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 47f02ad27af41..0364fc87eb1a0 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -119,8 +119,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_BFloat16, SimplifiedLayerNormalization);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_float, SimplifiedLayerNormalization);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits);
@@ -318,8 +316,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_BFloat16, SimplifiedLayerNormalization)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_float, SimplifiedLayerNormalization)>,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,

diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
index edbe94904da30..488513684c3c5 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
@@ -81,8 +81,8 @@ TEST(LayerNormTest, LayerNorm_BFloat16Input) {
 
   std::vector<int64_t> dims{1, 2, 3};
   test.AddInput<BFloat16>("x", dims, ToBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
-  test.AddInput<float>("gamma", {3}, {1.0f, 1.0f, 1.0f});
-  test.AddOutput<float>("output", dims, {-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f});
+  test.AddInput<BFloat16>("gamma", {3}, ToBFloat16({1.0f, 1.0f, 1.0f}));
+  test.AddOutput<BFloat16>("output", dims, ToBFloat16({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f}));
   test.Run();
 }

diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
index 1eb074a119a74..788ff76490c9e 100644
--- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
@@ -154,8 +154,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_BFloat16, SimplifiedLayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_float, SimplifiedLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_float, InvertibleLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double_double, InvertibleLayerNormalizationGrad);
@@ -414,8 +412,6 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_BFloat16, SimplifiedLayerNormalizationGrad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_float, SimplifiedLayerNormalizationGrad)>,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
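Routing the expected outputs through ToBFloat16, as patch 03 does, matters because bfloat16 keeps float32's 8-bit exponent but only 7 mantissa bits (roughly 2-3 significant decimal digits), so both sides of the comparison must be rounded the same way. A self-contained sketch of the conversion; this is a hypothetical helper for illustration, not ORT's BFloat16 implementation (NaN handling omitted):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Keep the top 16 bits of the IEEE-754 float encoding, rounding the
// discarded low half to nearest-even.
static uint16_t FloatToBFloat16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}

static float BFloat16BitsToFloat(uint16_t b) {
  const uint32_t bits = static_cast<uint32_t>(b) << 16;  // zero-fill mantissa
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  const float x = 1.2247f;
  // Round-trips to ~1.226562, the nearest bfloat16-representable value.
  std::printf("%f -> %f\n", x, BFloat16BitsToFloat(FloatToBFloat16Bits(x)));
  return 0;
}
```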
From 822066a5d0c6017508522bd14dbc549f2eb99fcf Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Fri, 22 Dec 2023 18:29:56 +0000
Subject: [PATCH 04/17] remove redundant grad registration

---
 .../orttraining/training_ops/cuda/cuda_training_kernels.cc | 2 --
 1 file changed, 2 deletions(-)

diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
index 788ff76490c9e..8b2bc7e2ef2b3 100644
--- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
@@ -154,7 +154,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalizationGrad);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_float, InvertibleLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double_double, InvertibleLayerNormalizationGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float_MLFloat16, InvertibleLayerNormalizationGrad);
@@ -412,7 +411,6 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad)>,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,

From 0ae59f051aaa796ffbc3733fe8653d6a18b92b5b Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Fri, 22 Dec 2023 18:43:42 +0000
Subject: [PATCH 05/17] restrict test to cuda

---
 onnxruntime/test/contrib_ops/layer_norm_op_test.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
index 488513684c3c5..e887b0ecda200 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
@@ -83,7 +83,7 @@ TEST(LayerNormTest, LayerNorm_BFloat16Input) {
   test.AddInput<BFloat16>("x", dims, ToBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
   test.AddInput<BFloat16>("gamma", {3}, ToBFloat16({1.0f, 1.0f, 1.0f}));
   test.AddOutput<BFloat16>("output", dims, ToBFloat16({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f}));
-  test.Run();
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider});
 }
 
 TEST(LayerNormTest, LayerNorm_Scale) {

From 56f5c42bb68b693b6264d9d380c03cbdc678631c Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Fri, 22 Dec 2023 19:02:39 +0000
Subject: [PATCH 06/17] check for bf16 inside test

---
 onnxruntime/test/contrib_ops/layer_norm_op_test.cc | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
index e887b0ecda200..15f9789114d7f 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
@@ -76,6 +76,13 @@ TEST(LayerNormTest, LayerNorm) {
 }
 
 TEST(LayerNormTest, LayerNorm_BFloat16Input) {
+  #ifdef USE_CUDA
+  int min_cuda_architecture = 530;
+  if (!HasCudaEnvironment(min_cuda_architecture)) {
+    LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
+    return;
+  }
+  #endif
   OpTester test("LayerNormalization");
   test.AddAttribute("epsilon", 1e-05f);
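One caution on patch 05: in OpTester, the third parameter of Run is the set of execution providers to *exclude*, so `{kCudaExecutionProvider}` as written skips CUDA rather than pinning the test to it; patch 13 below reworks the call to exclude the providers that lack this datatype combination, which matches the parameter's actual semantics. Paraphrased from the test utilities as of this series (verify against the current provider_test_utils.h):

```cpp
// OpTester::Run, abridged -- note the name of the third argument.
void Run(ExpectResult expect_result = ExpectResult::kExpectSuccess,
         const std::string& expected_failure_string = "",
         const std::unordered_set<std::string>& excluded_provider_types = {},
         const RunOptions* run_options = nullptr,
         std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers = nullptr);
```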
From 1b7927e504cfb9befdb507ec20a4538f30ad907c Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Mon, 8 Jan 2024 14:50:52 +0000
Subject: [PATCH 07/17] formatting

---
 onnxruntime/test/contrib_ops/layer_norm_op_test.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
index 15f9789114d7f..ec19af3f5d1f5 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
@@ -82,7 +82,7 @@ TEST(LayerNormTest, LayerNorm_BFloat16Input) {
     LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
     return;
   }
-  #endif 
+  #endif
   OpTester test("LayerNormalization");
   test.AddAttribute("epsilon", 1e-05f);

From e49b8592d5702ddf1b2053629320750f34c9ce96 Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Mon, 8 Jan 2024 17:45:46 +0000
Subject: [PATCH 08/17] ToBFloat16 -> MakeBFloat16

---
 onnxruntime/test/common/tensor_op_test_utils.h     |  9 ---------
 onnxruntime/test/contrib_ops/layer_norm_op_test.cc | 13 +++----------
 2 files changed, 3 insertions(+), 19 deletions(-)

diff --git a/onnxruntime/test/common/tensor_op_test_utils.h b/onnxruntime/test/common/tensor_op_test_utils.h
index 277fdd71f1d85..6917aa15777a2 100644
--- a/onnxruntime/test/common/tensor_op_test_utils.h
+++ b/onnxruntime/test/common/tensor_op_test_utils.h
@@ -156,15 +156,6 @@ inline std::vector ToFloat16(const std::vector& data) {
   return result;
 }
 
-inline std::vector<BFloat16> ToBFloat16(const std::vector<float>& data) {
-  std::vector<BFloat16> result;
-  result.reserve(data.size());
-  for (size_t i = 0; i < data.size(); i++) {
-    result.push_back(BFloat16(data[i]));
-  }
-  return result;
-}
-
 inline void CheckTensor(const Tensor& expected_tensor, const Tensor& output_tensor,
                         double rtol, double atol) {
   ORT_ENFORCE(expected_tensor.Shape() == output_tensor.Shape(),
               "Expected output shape [" + expected_tensor.Shape().ToString() +

diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
index ec19af3f5d1f5..4b412516b804c 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
@@ -76,20 +76,13 @@ TEST(LayerNormTest, LayerNorm) {
 }
 
 TEST(LayerNormTest, LayerNorm_BFloat16Input) {
-  #ifdef USE_CUDA
-  int min_cuda_architecture = 530;
-  if (!HasCudaEnvironment(min_cuda_architecture)) {
-    LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
-    return;
-  }
-  #endif
   OpTester test("LayerNormalization");
   test.AddAttribute("epsilon", 1e-05f);
 
   std::vector<int64_t> dims{1, 2, 3};
-  test.AddInput<BFloat16>("x", dims, ToBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
-  test.AddInput<BFloat16>("gamma", {3}, ToBFloat16({1.0f, 1.0f, 1.0f}));
-  test.AddOutput<BFloat16>("output", dims, ToBFloat16({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f}));
+  test.AddInput<BFloat16>("x", dims, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
+  test.AddInput<BFloat16>("gamma", {3}, MakeBFloat16({1.0f, 1.0f, 1.0f}));
+  test.AddOutput<BFloat16>("output", dims, MakeBFloat16({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f}));
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider});
 }
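Patch 08 drops the hand-rolled ToBFloat16 in favor of MakeBFloat16, which already exists in the shared test utilities, so the helper added in patch 02 was redundant. Its presumed shape, reconstructed by analogy with the surviving helpers in tensor_op_test_utils.h (an assumption, not a quote of the header):

```cpp
#include <algorithm>
#include <initializer_list>
#include <iterator>
#include <vector>

inline std::vector<BFloat16> MakeBFloat16(const std::initializer_list<float>& input) {
  std::vector<BFloat16> output;
  std::transform(input.begin(), input.end(), std::back_inserter(output),
                 [](float f) { return BFloat16(f); });
  return output;
}
```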
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypesIRv4()) .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Shape); From bf24ce768f46a790f382d67ac3ec0cf5d34d4415 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 8 Jan 2024 20:45:23 +0000 Subject: [PATCH 10/17] revert shape change --- onnxruntime/core/providers/cuda/tensor/shape_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cuda/tensor/shape_op.cc b/onnxruntime/core/providers/cuda/tensor/shape_op.cc index a109d451afd5f..0d5da81fe256b 100644 --- a/onnxruntime/core/providers/cuda/tensor/shape_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/shape_op.cc @@ -40,7 +40,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( (*KernelDefBuilder::Create()) // properly force CPU/GPU synch inside the kernel .OutputMemoryType(OrtMemTypeCPUInput, 0) - .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypesIRv4()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Shape); From faf79e9a3e73714d7ee9a3bf1139a144688ccc06 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 8 Jan 2024 20:53:58 +0000 Subject: [PATCH 11/17] check for bf16 support before running test --- onnxruntime/test/contrib_ops/layer_norm_op_test.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 4b412516b804c..a767067fe9af0 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -7,6 +7,7 @@ #include "core/session/inference_session.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" #include "test/framework/test_utils.h" #include "test/util/include/default_providers.h" #include "test/providers/provider_test_utils.h" @@ -76,6 +77,13 @@ TEST(LayerNormTest, LayerNorm) { } TEST(LayerNormTest, LayerNorm_BFloat16Input) { + #ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16"; + return; + } + #endif OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); From 17e72dc8a3e287d35a0ae38837d4f34f7d295673 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 8 Jan 2024 21:16:19 +0000 Subject: [PATCH 12/17] lint --- onnxruntime/test/contrib_ops/layer_norm_op_test.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index a767067fe9af0..660720437bddb 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -77,13 +77,13 @@ TEST(LayerNormTest, LayerNorm) { } TEST(LayerNormTest, LayerNorm_BFloat16Input) { - #ifdef USE_CUDA - int min_cuda_architecture = 530; - if (!HasCudaEnvironment(min_cuda_architecture)) { - LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16"; - return; - } - #endif +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16"; + return; + } +#endif OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); From 0bdab2fc8bab7a696d160c5dc1ead2e23265ac7b Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 8 Jan 2024 22:23:41 +0000 Subject: [PATCH 13/17] limit 
From 0bdab2fc8bab7a696d160c5dc1ead2e23265ac7b Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Mon, 8 Jan 2024 22:23:41 +0000
Subject: [PATCH 13/17] limit tests to cuda

---
 onnxruntime/test/contrib_ops/layer_norm_op_test.cc | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
index 660720437bddb..d7df43e121e55 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
@@ -91,7 +91,10 @@ TEST(LayerNormTest, LayerNorm_BFloat16Input) {
   test.AddInput<BFloat16>("x", dims, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
   test.AddInput<BFloat16>("gamma", {3}, MakeBFloat16({1.0f, 1.0f, 1.0f}));
   test.AddOutput<BFloat16>("output", dims, MakeBFloat16({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f}));
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider});
+  // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider,
+            kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider});
 }

From 746e7d593bbdeda711ca65fac821ad0280d4b787 Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Wed, 17 Jan 2024 22:40:20 +0000
Subject: [PATCH 14/17] formatting

---
 docs/OperatorKernels.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index f985cf10ded60..e528b0c19826f 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -355,7 +355,7 @@ Do not modify directly.*
 |||[6, 12]|**T** = tensor(double), tensor(float)|
 |Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[9, 12]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float)<br> **U** = tensor(double), tensor(float)<br> **V** = tensor(double), tensor(float)|
+|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float)<br> **U** = tensor(double), tensor(float)<br> **V** = tensor(double), tensor(float)|
 |Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float)|
 |Sinh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float)|
 |Size|*in* data:**T**<br> *out* size:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T1** = tensor(int64)|
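For reference, the row edited above encodes the op's signature: with input x of type T, scale g of type V, output y of type V, and the saved inverse standard deviation of type U, the simplified (RMS-style) normalization over a last axis of size n is:

```latex
\mathrm{inv\_std\_var} = \Big(\tfrac{1}{n}\sum_{i=1}^{n} x_i^{2} + \varepsilon\Big)^{-1/2},
\qquad
y_i = x_i \cdot \mathrm{inv\_std\_var} \cdot g_i .
```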
From fbe8b1c8641e2f5705c00564b9a319b6a593d3e5 Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Tue, 13 Feb 2024 19:38:33 +0000
Subject: [PATCH 15/17] docs update

---
 docs/OperatorKernels.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index e528b0c19826f..9c5460e2787a5 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -355,7 +355,7 @@ Do not modify directly.*
 |||[6, 12]|**T** = tensor(double), tensor(float)|
 |Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[9, 12]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float)<br> **U** = tensor(double), tensor(float)<br> **V** = tensor(double), tensor(float)|
+|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float)<br> **U** = tensor(bfloat16), tensor(double), tensor(float)<br> **V** = tensor(double), tensor(float)|
 |Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float)|
 |Sinh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float)|
 |Size|*in* data:**T**<br> *out* size:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T1** = tensor(int64)|

From 97a440b381a802ff210201f2d0d404570c720824 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 13 Feb 2024 21:47:36 +0000
Subject: [PATCH 16/17] add comment

---
 onnxruntime/test/contrib_ops/layer_norm_op_test.cc | 1 +
 1 file changed, 1 insertion(+)

diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
index d7df43e121e55..98fb62e435f31 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
@@ -77,6 +77,7 @@ TEST(LayerNormTest, LayerNorm) {
 }
 
 TEST(LayerNormTest, LayerNorm_BFloat16Input) {
+// prevents test from running on non-BF16-supporting hardware
 #ifdef USE_CUDA
   int min_cuda_architecture = 530;
   if (!HasCudaEnvironment(min_cuda_architecture)) {
From 647789d506c7ded5a9b3149686355c8fe5b76b60 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 13 Feb 2024 22:49:22 +0000
Subject: [PATCH 17/17] update docs

---
 docs/OperatorKernels.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 9c5460e2787a5..789b6e4cce901 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -355,7 +355,7 @@ Do not modify directly.*
 |||[6, 12]|**T** = tensor(double), tensor(float)|
 |Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[9, 12]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float)<br> **U** = tensor(bfloat16), tensor(double), tensor(float)<br> **V** = tensor(double), tensor(float)|
+|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float)<br> **U** = tensor(double), tensor(float)<br> **V** = tensor(double), tensor(float)|
 |Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float)|
 |Sinh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float)|
 |Size|*in* data:**T**<br> *out* size:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T1** = tensor(int64)|
@@ -758,7 +758,7 @@
 |Sigmoid|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
 |Sign|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)<br> **U** = tensor(double), tensor(float)<br> **V** = tensor(double), tensor(float), tensor(float16)|
+|SimplifiedLayerNormalization|*in* X:**T**<br> *in* scale:**V**<br> *out* Y:**V**<br> *out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **U** = tensor(double), tensor(float)<br> **V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |Sin|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Size|*in* data:**T**<br> *out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T1** = tensor(int64)|
 |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T1** = tensor(int64)|