Fix CUDA BatchNorm bugs and add support for NHWC (#19742)
### Description
- Fix incorrect running_mean / running_var in training mode, caused by an
incorrect momentum argument and the missing input mean/var. running_var
could appear correct, but only with a too-high epsilon (see the sketch
after this list)
- Fix incorrect checks when using NHWC
- Pass the NHWC flag to NormalizeDims to get the correct new dimensions from
x_shape
- Register missing double kernels to reach parity between NHWC and NCHW
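The momentum fix in batch_norm.cc comes down to a semantics mismatch between ONNX and cuDNN. A minimal sketch; the helper functions and the 0.9 value are illustrative assumptions, not part of the commit:

```cpp
// ONNX BatchNormalization (training): momentum weights the OLD running stat:
//   running_new = momentum * running_old + (1 - momentum) * batch_stat
// cuDNN's exponentialAverageFactor weights the NEW batch stat instead:
//   running_new = (1 - factor) * running_old + factor * batch_stat
double onnx_update(double running, double batch, double momentum) {
  return momentum * running + (1.0 - momentum) * batch;
}
double cudnn_update(double running, double batch, double factor) {
  return (1.0 - factor) * running + factor * batch;
}
// onnx_update(r, b, m) == cudnn_update(r, b, 1.0 - m); with the ONNX default
// momentum of 0.9, passing momentum_ straight through made the running stats
// track 90% of each new batch instead of 10%, hence the 1.0 - momentum_ change.
```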
mtavenrath authored Mar 5, 2024
1 parent cd56ea4 commit bdf678d
Showing 5 changed files with 66 additions and 21 deletions.
41 changes: 28 additions & 13 deletions onnxruntime/core/providers/cpu/nn/batch_norm_helper.h
@@ -25,6 +25,8 @@ class BatchNormHelper {
const Tensor* var,
bool is_spatial = true,
bool is_nhwc = false) {
// NHWC dependent shape: X
// All other shapes are assumed to be in NCHW layout?
const auto& x_dims = X->Shape().GetDims();

// If x_dims size < 2, num_channels defaults to 1.
@@ -48,67 +50,80 @@ class BatchNormHelper {
// validate 'scales' shape
const auto& scale_dims = scale->Shape().GetDims();
if (static_cast<int>(scale_dims.size()) != kNumInputScaleDimensions) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions);
}
if (scale_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: 0th dimension != ", num_channels);
}
// N & C do not belong to features
// skip the first element for NHWC and the first two elements for NCHW.
int feature_offset = is_nhwc ? 1 : 2;

// in non-spatial cases - the other dims of 'scale' must be validated
if (!is_spatial) {
for (int feature = 0; feature < num_feature_dims; ++feature) {
if (scale_dims[1 + feature] != x_dims[2 + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
if (scale_dims[1 + feature] != x_dims[feature_offset + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature),
" dimension != ", x_dims[feature_offset + feature]);
}
}
}

// validate 'B' shape
const auto& B_dims = B->Shape().GetDims();
if (static_cast<int>(B_dims.size()) != kNumInputBiasDimensions) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid input B: NumDimensions() != ", kNumInputBiasDimensions);
}
if (B_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: 0th dimension != ", num_channels);
}
// in non-spatial cases - the other dims of 'B' must be validated
if (!is_spatial) {
for (int feature = 0; feature < num_feature_dims; ++feature) {
if (B_dims[1 + feature] != x_dims[2 + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
if (B_dims[1 + feature] != x_dims[feature_offset + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature),
" dimension != ", x_dims[feature_offset + feature]);
}
}
}

// validate 'mean' shape
const auto& mean_dims = mean->Shape().GetDims();
if (static_cast<int>(mean_dims.size()) != kNumInputMeanDimensions) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions);
}
if (mean_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: 0th dimension != ", num_channels);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid input mean: 0th dimension != ", num_channels);
}
// in non-spatial cases - the other dims of 'mean' must be validated
if (!is_spatial) {
for (int feature = 0; feature < num_feature_dims; ++feature) {
if (mean_dims[1 + feature] != x_dims[2 + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
if (mean_dims[1 + feature] != x_dims[feature_offset + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature),
" dimension != ", x_dims[feature_offset + feature]);
}
}
}

// validate 'var' shape
const auto& var_dims = var->Shape().GetDims();
if (static_cast<int>(var_dims.size()) != kNumInputVarianceDimensions) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions);
}
if (var_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: 0th dimension != ", num_channels);
}
// in non-spatial cases - the other dims of 'var' must be validated
if (!is_spatial) {
for (int feature = 0; feature < num_feature_dims; ++feature) {
if (var_dims[1 + feature] != x_dims[2 + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), " dimension != ", x_dims[2 + feature]);
if (var_dims[1 + feature] != x_dims[feature_offset + feature]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature),
" dimension != ", x_dims[feature_offset + feature]);
}
}
}
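A quick sketch of the indexing behind feature_offset in the checks above (the concrete shape layouts are illustrative assumptions based on the NCHW/NHWC definitions, not lines from the commit):

// NCHW: x = {N, C, D1, D2, ...}  -> feature dims start at x_dims[2 + f]
// NHWC: x = {N, D1, D2, ..., C}  -> feature dims start at x_dims[1 + f]
// Non-spatial scale/B/mean/var = {C, D1, D2, ...} in either layout -> dims[1 + f],
// which is why each check compares dims[1 + f] against x_dims[feature_offset + f].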
18 changes: 12 additions & 6 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1202,9 +1202,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin);
@@ -2107,9 +2110,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Mul)>,
16 changes: 16 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
@@ -18,10 +18,14 @@ namespace onnxruntime::cuda {

class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float,
@@ -72,10 +76,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalN
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16,
BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float,
BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, double,
BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16,
BatchNormalization);

@@ -86,18 +94,26 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider,
11 changes: 9 additions & 2 deletions onnxruntime/core/providers/cuda/nn/batch_norm.cc
@@ -87,7 +87,7 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)

CudnnTensor data_desc;
vector<int64_t> new_dims;
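// Assumed rationale for the change below (not a comment in the commit):
// NormalizeDims expands x_shape to the rank cuDNN expects; with NHWC the
// channel dim sits last, so the layout flag is needed for the normalized
// dims to land in the right positions.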
BatchNormHelper::NormalizeDims(x_shape, new_dims);
BatchNormHelper::NormalizeDims(x_shape, new_dims, NHWC);
ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType<CudaT>(), NHWC));

// For half data type, the alpha, beta, scale, B, mean, var need to be float type
@@ -137,6 +137,12 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
auto saved_mean_data = reinterpret_cast<CudaT*>(saved_mean->MutableData<T>());
auto saved_inv_var_data = reinterpret_cast<CudaT*>(saved_var->MutableData<T>());

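// Assumed rationale for the copies below (not a comment in the commit):
// cuDNN's BatchNormalizationForwardTraining treats the running mean/var
// buffers as in/out state, reading the prior values from the pointers it
// updates; the outputs here are distinct tensors, so they must be seeded
// from the input mean/var first.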
auto stream = static_cast<cudaStream_t>(p_op_kernel_context->GetComputeStream()->GetHandle());
CUDA_RETURN_IF_ERROR(
cudaMemcpyAsync(running_mean_data, mean_data, mean->SizeInBytes(), cudaMemcpyDeviceToDevice, stream));
CUDA_RETURN_IF_ERROR(
cudaMemcpyAsync(running_var_data, var_data, var->SizeInBytes(), cudaMemcpyDeviceToDevice, stream));

CUDNN_RETURN_IF_ERROR(BatchNormalizationForwardTrainingHelper(
GetCudnnHandle(p_op_kernel_context),
cudnn_batch_norm_mode_,
@@ -149,7 +155,7 @@ Status BatchNorm<T, NHWC>::ComputeInternal(OpKernelContext* p_op_kernel_context)
bn_tensor_desc,
scale_data,
b_data,
momentum_,
1.0 - momentum_,
running_mean_data,
running_var_data,
epsilon_,
@@ -186,6 +192,7 @@ SPECIALIZED_COMPUTE(MLFloat16, kOnnxDomain, false)

#ifdef ENABLE_CUDA_NHWC_OPS
SPECIALIZED_COMPUTE(float, kMSInternalNHWCDomain, true)
SPECIALIZED_COMPUTE(double, kMSInternalNHWCDomain, true)
SPECIALIZED_COMPUTE(MLFloat16, kMSInternalNHWCDomain, true)
#endif
} // namespace cuda
1 change: 1 addition & 0 deletions onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
@@ -916,6 +916,7 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) {
// exclude CUDA Execution Provider due to flakiness
// exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm()
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
// TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1
{kCudaExecutionProvider, kRocmExecutionProvider,
kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
}
