From bdf678df93cb257e311de3fa82fe6409be2854ff Mon Sep 17 00:00:00 2001
From: Markus Tavenrath
Date: Tue, 5 Mar 2024 17:09:42 +0100
Subject: [PATCH] Fix CUDA BatchNorm bugs and add support for NHWC (#19742)

### Description
- Fix incorrect running_mean / running_var in training mode, caused by a
  momentum value applied with the wrong convention and by the input mean / var
  never being carried over into the running statistics (see the first sketch
  below). running_var could otherwise be correct, but had a too high epsilon.
- Fix incorrect shape checks when using NHWC.
- Pass the NHWC flag to NormalizeDims to get the correct new dimensions from
  x_shape (see the second sketch below).
- Register the missing double kernels to get parity between NHWC and NCHW.
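Two conventions collide in the running-statistics fix. The ONNX BatchNormalization spec updates the running statistics as `running_mean = momentum * running_mean + (1 - momentum) * batch_mean`, while cuDNN's `exponentialAverageFactor` weights the terms the other way around, so the value handed to cuDNN must be `1.0 - momentum_`. cuDNN also blends into the running buffers in place, but ORT's `running_mean` / `running_var` outputs are tensors distinct from the `mean` / `var` inputs, which is why the inputs are now copied into the outputs (via `cudaMemcpyAsync`) before the cuDNN call. A minimal standalone sketch of the factor conversion (illustrative only, not ORT code):

```cpp
#include <cassert>
#include <cmath>

// ONNX convention: momentum weights the *old* running value.
double onnx_update(double running, double batch, double momentum) {
  return momentum * running + (1.0 - momentum) * batch;
}

// cuDNN convention: exponentialAverageFactor weights the *new* batch value.
double cudnn_update(double running, double batch, double factor) {
  return (1.0 - factor) * running + factor * batch;
}

int main() {
  const double running = 0.5, batch = 2.0, momentum = 0.9;
  // Passing momentum_ straight through (the old behavior) weights the batch
  // statistic with 0.9 instead of 0.1; factor = 1 - momentum matches ONNX.
  assert(std::fabs(onnx_update(running, batch, momentum) -
                   cudnn_update(running, batch, 1.0 - momentum)) < 1e-12);
}
```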
---
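A note on the NHWC shape handling below: with NCHW, `X` is laid out as `[N, C, d1, ..., dn]`, so the channel count lives at index 1 and the feature dimensions that `scale`, `B`, `mean` and `var` are validated against start at index 2. With NHWC, `X` is `[N, d1, ..., dn, C]`: the channels move to the last index and the feature dimensions start at index 1. That is what the new `feature_offset = is_nhwc ? 1 : 2` in the validation loops selects, and why `NormalizeDims` now receives the layout flag. A small illustrative sketch (hypothetical helper, not the ORT implementation):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct LayoutInfo {
  size_t channel_index;   // index of C in x_dims
  size_t feature_offset;  // first x_dims index holding a feature dimension
};

LayoutInfo GetLayoutInfo(const std::vector<int64_t>& x_dims, bool is_nhwc) {
  // NCHW: [N, C, d1, ..., dn] -> C at 1, features start at 2.
  // NHWC: [N, d1, ..., dn, C] -> C at rank - 1, features start at 1.
  if (is_nhwc) return {x_dims.size() - 1, 1};
  return {1, 2};
}

int main() {
  const std::vector<int64_t> nchw{8, 64, 32, 32};  // N=8, C=64, H=W=32
  const std::vector<int64_t> nhwc{8, 32, 32, 64};  // same tensor, channels-last
  std::cout << GetLayoutInfo(nchw, false).channel_index << "\n";  // prints 1
  std::cout << GetLayoutInfo(nhwc, true).channel_index << "\n";   // prints 3
}
```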
 .../core/providers/cpu/nn/batch_norm_helper.h | 41 +++++++++++++------
 .../providers/cuda/cuda_execution_provider.cc | 18 +++++---
 .../core/providers/cuda/cuda_nhwc_kernels.cc  | 16 ++++++++
 .../core/providers/cuda/nn/batch_norm.cc      | 11 ++++-
 .../providers/cpu/nn/batch_norm_op_test.cc    |  1 +
 5 files changed, 66 insertions(+), 21 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h
index a5d46aff83b50..ccecbabfa3db3 100644
--- a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h
+++ b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h
@@ -25,6 +25,8 @@ class BatchNormHelper {
                                   const Tensor* var,
                                   bool is_spatial = true,
                                   bool is_nhwc = false) {
+    // NHWC dependent shape: X
+    // All other shapes are assumed to be in NCHW layout?
     const auto& x_dims = X->Shape().GetDims();
 
     // If x_dims size < 2, num_channels defaults to 1.
@@ -48,16 +50,22 @@ class BatchNormHelper {
     // validate 'scales' shape
     const auto& scale_dims = scale->Shape().GetDims();
     if (static_cast<int>(scale_dims.size()) != kNumInputScaleDimensions) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions);
     }
     if (scale_dims[0] != num_channels) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: 0th dimension != ", num_channels);
     }
+    // N & C do not belong to features
+    // skip the first element for NHWC and the first two elements for NCHW.
+    int feature_offset = is_nhwc ? 1 : 2;
+
     // in non-spatial cases - the other dims of 'scale' must be validated
     if (!is_spatial) {
       for (int feature = 0; feature < num_feature_dims; ++feature) {
-        if (scale_dims[1 + feature] != x_dims[2 + feature]) {
-          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
+        if (scale_dims[1 + feature] != x_dims[feature_offset + feature]) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature),
+                                 " dimension != ", x_dims[feature_offset + feature]);
         }
       }
     }
@@ -65,7 +73,8 @@ class BatchNormHelper {
     // validate 'B' shape
     const auto& B_dims = B->Shape().GetDims();
     if (static_cast<int>(B_dims.size()) != kNumInputBiasDimensions) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions);
     }
     if (B_dims[0] != num_channels) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: 0th dimension != ", num_channels);
@@ -73,8 +82,9 @@ class BatchNormHelper {
     // in non-spatial cases - the other dims of 'B' must be validated
     if (!is_spatial) {
       for (int feature = 0; feature < num_feature_dims; ++feature) {
-        if (B_dims[1 + feature] != x_dims[2 + feature]) {
-          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
+        if (B_dims[1 + feature] != x_dims[feature_offset + feature]) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature),
+                                 " dimension != ", x_dims[feature_offset + feature]);
         }
       }
     }
@@ -82,16 +92,19 @@ class BatchNormHelper {
     // validate 'mean' shape
     const auto& mean_dims = mean->Shape().GetDims();
     if (static_cast<int>(mean_dims.size()) != kNumInputMeanDimensions) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions);
     }
     if (mean_dims[0] != num_channels) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: 0th dimension != ", num_channels);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Invalid input mean: 0th dimension != ", num_channels);
     }
     // in non-spatial cases - the other dims of 'mean' must be validated
     if (!is_spatial) {
       for (int feature = 0; feature < num_feature_dims; ++feature) {
-        if (mean_dims[1 + feature] != x_dims[2 + feature]) {
-          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
+        if (mean_dims[1 + feature] != x_dims[feature_offset + feature]) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature),
+                                 " dimension != ", x_dims[feature_offset + feature]);
         }
       }
     }
@@ -99,7 +112,8 @@ class BatchNormHelper {
     // validate 'var' shape
     const auto& var_dims = var->Shape().GetDims();
     if (static_cast<int>(var_dims.size()) != kNumInputVarianceDimensions) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions);
     }
     if (var_dims[0] != num_channels) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: 0th dimension != ", num_channels);
@@ -107,8 +121,9 @@ class BatchNormHelper {
     // in non-spatial cases - the other dims of 'var' must be validated
     if (!is_spatial) {
       for (int feature = 0; feature < num_feature_dims; ++feature) {
-        if (var_dims[1 + feature] != x_dims[2 + feature]) {
-          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
+        if (var_dims[1 + feature] != x_dims[feature_offset + feature]) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature),
+                                 " dimension != ", x_dims[feature_offset + feature]);
         }
       }
     }
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 1ce089fd93044..8ba282031a5d4 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1202,9 +1202,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+    kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+    kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+    kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin);
@@ -2107,9 +2110,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin)>,
diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
index f416caecd115f..64edc319e15ac 100644
--- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
+++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
@@ -18,10 +18,14 @@ namespace onnxruntime::cuda {
 
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float,
                                                       BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double,
+                                                      BatchNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16,
                                                       BatchNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float,
                                                       BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double,
+                                                      BatchNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16,
                                                       BatchNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float,
@@ -72,10 +76,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float,
                                                       BatchNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double,
+                                                      BatchNormalization);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16,
                                                       BatchNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float,
                                             BatchNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, double,
+                                            BatchNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16,
                                             BatchNormalization);
 
@@ -86,18 +94,26 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
           kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
          kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
           kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float, BatchNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
+          kCudaExecutionProvider, kMSInternalNHWCDomain, 15, double, BatchNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.cc b/onnxruntime/core/providers/cuda/nn/batch_norm.cc
--- a/onnxruntime/core/providers/cuda/nn/batch_norm.cc
+++ b/onnxruntime/core/providers/cuda/nn/batch_norm.cc
@@ ... @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
   CudnnTensor data_desc;
   vector<int64_t> new_dims;
-  BatchNormHelper::NormalizeDims(x_shape, new_dims);
+  BatchNormHelper::NormalizeDims(x_shape, new_dims, NHWC);
   ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType<CudaT>(), NHWC));
 
   // For half data type, the alpha, beta, scale, B, mean, var need to be float type
@@ -137,6 +137,12 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
     auto saved_mean_data = reinterpret_cast<CudaT*>(saved_mean->MutableData<T>());
     auto saved_inv_var_data = reinterpret_cast<CudaT*>(saved_var->MutableData<T>());
 
+    auto stream = static_cast<cudaStream_t>(p_op_kernel_context->GetComputeStream()->GetHandle());
+    CUDA_RETURN_IF_ERROR(
+        cudaMemcpyAsync(running_mean_data, mean_data, mean->SizeInBytes(), cudaMemcpyDeviceToDevice, stream));
+    CUDA_RETURN_IF_ERROR(
+        cudaMemcpyAsync(running_var_data, var_data, var->SizeInBytes(), cudaMemcpyDeviceToDevice, stream));
+
     CUDNN_RETURN_IF_ERROR(BatchNormalizationForwardTrainingHelper(
         GetCudnnHandle(p_op_kernel_context),
         cudnn_batch_norm_mode_,
@@ -149,7 +155,7 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
         bn_tensor_desc,
         scale_data,
         b_data,
-        momentum_,
+        1.0 - momentum_,
         running_mean_data,
         running_var_data,
         epsilon_,
@@ -186,6 +192,7 @@ SPECIALIZED_COMPUTE(MLFloat16, kOnnxDomain, false)
 
 #ifdef ENABLE_CUDA_NHWC_OPS
 SPECIALIZED_COMPUTE(float, kMSInternalNHWCDomain, true)
+SPECIALIZED_COMPUTE(double, kMSInternalNHWCDomain, true)
 SPECIALIZED_COMPUTE(MLFloat16, kMSInternalNHWCDomain, true)
 #endif
 } // namespace cuda
diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
index cbb4531a50b7c..54e5c71bd753a 100644
--- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
@@ -916,6 +916,7 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
   // exclude CUDA Execution Provider due to flakiness
   // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           // TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1
            {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider,
             kDnnlExecutionProvider});
 }