diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
index f73efcddcedd4..24a5dcab225c4 100644
--- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
+++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -24,16 +24,16 @@ void ComputeJob(
     const T* bias_data,
     const ptrdiff_t task_idx,
     const int64_t norm_size,
-    IAllocatorUniquePtr<float>& scale_float_uptr,
-    IAllocatorUniquePtr<float>& bias_float_uptr,
+    const float* scale_float_ptr,
+    const float* bias_float_ptr,
     float epsilon,
     bool simplified,
     T* Y_data,
     U* mean_data,
     U* inv_std_dev_data,
     AllocatorPtr alloc) {
-  ORT_UNUSED_PARAMETER(scale_float_uptr);  // only used in MLFloat16 overload
-  ORT_UNUSED_PARAMETER(bias_float_uptr);   // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(scale_float_ptr);  // only used in MLFloat16 overload
+  ORT_UNUSED_PARAMETER(bias_float_ptr);   // only used in MLFloat16 overload
   ORT_UNUSED_PARAMETER(alloc);
 
   const T* p_input = X_data + task_idx * norm_size;
@@ -82,14 +82,17 @@ void ComputeJob(
     const MLFloat16* bias_data,
     const ptrdiff_t task_idx,
     const int64_t norm_size,
-    IAllocatorUniquePtr<float>& scale_float_uptr,
-    IAllocatorUniquePtr<float>& bias_float_uptr,
+    const float* scale_float_ptr,
+    const float* bias_float_ptr,
     float epsilon,
     bool simplified,
     MLFloat16* Y_data,
     U* mean_data,
     U* inv_std_dev_data,
     AllocatorPtr alloc) {
+  ORT_UNUSED_PARAMETER(scale_data);  // only used in float/double overload
+  ORT_UNUSED_PARAMETER(bias_data);   // only used in float/double overload
+
   const MLFloat16* p_input = X_data + task_idx * norm_size;
   MLFloat16* p_output = Y_data + task_idx * norm_size;
 
@@ -117,22 +120,10 @@ void ComputeJob(
     mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
   }
 
-  if (!scale_float_uptr) {
-    scale_float_uptr = std::move(input_float_uptr);  // overwrite input with scale values, since they have the same size
-    MlasConvertHalfToFloatBuffer(scale_data, scale_float_uptr.get(), num_elems);
-  }
-
-  if (bias_data && !bias_float_uptr) {
-    bias_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
-    MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems);
-  }
-
-  const float* scale_float_ptr = scale_float_uptr.get();
-  const float* bias_float_ptr = bias_float_uptr.get();
   for (size_t h = 0; h < num_elems; h++) {
     if (simplified) {
       output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h];
-    } else if (nullptr == bias_data) {
+    } else if (nullptr == bias_float_ptr) {
       output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h];
     } else {
       output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h];
@@ -166,7 +157,13 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I
 }  // namespace
 
 LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified, bool contrib_op)
-    : OpKernel(op_kernel_info), simplified_{simplified}, contrib_op_{contrib_op}, scale_fp32_(nullptr), bias_fp32_(nullptr) {
+    : OpKernel(op_kernel_info),
+      simplified_{simplified},
+      contrib_op_{contrib_op},
+      prepacked_scale_fp32_data_(nullptr),
+      prepacked_scale_fp32_size_(0),
+      prepacked_bias_fp32_data_(nullptr),
+      prepacked_bias_fp32_size_(0) {
   ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("axis", &axis_).IsOK());
   ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
 }
@@ -175,15 +172,15 @@ template <typename T, typename U>
 Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) const {
   // Inputs
   const Tensor* X = p_ctx->Input<Tensor>(0);
-  const Tensor* scale = p_ctx->Input<Tensor>(1);
-  const Tensor* bias = p_ctx->Input<Tensor>(2);
+  const Tensor* scale = prepacked_scale_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
+  const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
   const T* X_data = X->Data<T>();
-  const T* scale_data = scale->Data<T>();
+  const T* scale_data = scale ? scale->Data<T>() : nullptr;
   const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();
 
   const TensorShape& x_shape = X->Shape();
-  const TensorShape& scale_shape = scale->Shape();
-  const TensorShape& bias_shape = bias->Shape();
+  size_t scale_size = scale ? static_cast<size_t>(scale->Shape().Size()) : prepacked_scale_fp32_size_;
+  size_t bias_size = bias ? static_cast<size_t>(bias->Shape().Size()) : prepacked_bias_fp32_size_;
 
   Tensor* Y = p_ctx->Output(0, x_shape);
   T* Y_data = Y->MutableData<T>();
@@ -218,7 +215,7 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo
 
   AllocatorPtr alloc;
   ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
-  return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data,
+  return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data,
                                      inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc);
 }
 
@@ -237,9 +234,11 @@ Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
 
   is_packed = false;
   if (input_idx == 1) {  // scale
-    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, scale_fp32_, is_packed);
+    prepacked_scale_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed);
   } else if (input_idx == 2) {  // bias
-    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
+    prepacked_bias_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
+    ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
   }
 
   return Status::OK();
@@ -250,9 +249,9 @@ Status LayerNormImpl::ComputeWithoutContext(
     const T* X_data,
     const TensorShape& x_shape,
     const T* scale_data,
-    const TensorShape& scale_shape,
+    size_t scale_size,
     const T* bias_data,
-    const TensorShape& bias_shape,
+    size_t bias_size,
     T* Y_data,
     U* mean_data,
     U* inv_std_dev_data,
@@ -264,19 +263,34 @@ Status LayerNormImpl::ComputeWithoutContext(
   int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
   int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
 
-  const auto scale_size = scale_shape.Size();
-  const auto bias_size = (bias_data) ? bias_shape.Size() : 0;
-  if (scale_size != norm_size || (bias_data && bias_size != norm_size)) {
+  if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "Size of X.shape()[axis:] == ", norm_size,
                            ". Size of scale and bias (if provided) must match this. Got scale size of ",
                           scale_size, " and bias size of ", bias_size);
   }
 
+  IAllocatorUniquePtr<float> scale_fp32;
+  IAllocatorUniquePtr<float> bias_fp32;
+  if constexpr (std::is_same_v<T, MLFloat16>) {
+    if (prepacked_scale_fp32_data_ == nullptr) {
+      const size_t num_elems = static_cast<size_t>(norm_size);
+      scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+      MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems);
+    }
+    if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
+      const size_t num_elems = static_cast<size_t>(norm_size);
+      bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+      MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
+    }
+  }
+
   concurrency::ThreadPool::TryBatchParallelFor(
       thread_pool, static_cast<int32_t>(norm_count),
       [&](ptrdiff_t task_idx) {
-        ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, scale_fp32_, bias_fp32_,
+        ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size,
+                   prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(),
+                   prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
                    epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
       },
       0);
diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
index f6325c31cc71a..f8b528b398cba 100644
--- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
+++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
@@ -24,9 +24,9 @@ class LayerNormImpl : public OpKernel {
       const T* X_data,
       const TensorShape& x_shape,
       const T* scale_data,
-      const TensorShape& scale_shape,
+      size_t scale_size,
       const T* bias_data,
-      const TensorShape& bias_shape,
+      size_t bias_size,
       T* Y_data,
       U* mean_data,
       U* inv_std_dev,
@@ -63,8 +63,10 @@ class LayerNormImpl : public OpKernel {
   float epsilon_;
   const bool simplified_;
   const bool contrib_op_;
-  mutable IAllocatorUniquePtr<float> scale_fp32_;
-  mutable IAllocatorUniquePtr<float> bias_fp32_;
+  IAllocatorUniquePtr<float> prepacked_scale_fp32_data_;
+  size_t prepacked_scale_fp32_size_;
+  IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
+  size_t prepacked_bias_fp32_size_;
 };
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
index 655c4951f262d..9ecaa16a2ab24 100644
--- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc
@@ -151,6 +151,20 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) {
            kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider});
 }
 
+TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput_Initializers) {
+  OpTester test("LayerNormalization");
+  test.AddAttribute<float>("epsilon", 1e-05f);
+
+  std::vector<int64_t> dims{2, 2, 2};
+  test.AddInput<MLFloat16>("x", dims, ToFloat16({-10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f}));
+  test.AddInput<MLFloat16>("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}), true);
+  test.AddOutput<MLFloat16>("output", dims, ToFloat16({0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f}));
+  // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider,
+            kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider});
+}
+
 TEST(LayerNormTest, LayerNorm_Scale_Bias) {
   OpTester test("LayerNormalization");
   test.AddAttribute<float>("epsilon", 1e-05f);
@@ -211,6 +225,21 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) {
            kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider});
 }
 
+TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput_Initializers) {
+  OpTester test("LayerNormalization");
+  test.AddAttribute<float>("epsilon", 1e-05f);
+
+  std::vector<int64_t> dims{1, 3, 2};
+  test.AddInput<MLFloat16>("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f}));
+  test.AddInput<MLFloat16>("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}), true);
+  test.AddInput<MLFloat16>("bias", {2}, ToFloat16({0.6435f, -0.3964f}), true);
+  test.AddOutput<MLFloat16>("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f}));
+  // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider,
+            kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider});
+}
+
 // LayerNormalization became an ONNX operator in opset 17. It uses the same implementation so this is a sanity check.
 TEST(LayerNormTest, LayerNorm17_float) {
   OpTester test("LayerNormalization", 17);
diff --git a/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc b/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc
index 75ce7b77acd4e..f6158d8cbc12b 100644
--- a/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc
+++ b/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc
@@ -111,9 +111,20 @@ static void BM_LayerNormalization(benchmark::State& state) {
   OrtMemoryInfo memory_info(onnxruntime::CPU, OrtAllocatorType::OrtArenaAllocator);
   AllocatorPtr alloc = std::make_shared<CPUAllocator>(memory_info);
   for (auto _ : state) {
-    auto status = layer_norm_impl.ComputeWithoutContext(x_data, x_shape, scale_data, scale_shape, bias_data, bias_shape,
-                                                        Y_data, mean_data, inv_std_dev_data, thread_pool.get(), axis,
-                                                        epsilon, simplified, alloc);
+    auto status = layer_norm_impl.ComputeWithoutContext(x_data,
+                                                        x_shape,
+                                                        scale_data,
+                                                        static_cast<size_t>(scale_shape.Size()),
+                                                        bias_data,
+                                                        static_cast<size_t>(bias_shape.Size()),
+                                                        Y_data,
+                                                        mean_data,
+                                                        inv_std_dev_data,
+                                                        thread_pool.get(),
+                                                        axis,
+                                                        epsilon,
+                                                        simplified,
+                                                        alloc);
     if (!status.IsOK()) {
       std::cout << "ComputeWithoutContext status not OK: " << status.ErrorMessage() << std::endl;
       break;