[CUDA] Add use_tf32 cuda provider option (for FP32 Conv) (#19426)
Follow-up to #19357: apply the use_tf32 option to FP32 cuDNN convolution.

When use_tf32 = 0, we will disable TF32 in cuDNN convolution for FP32 inputs.

https://docs.nvidia.com/deeplearning/cudnn/api/cudnn-graph-library.html#cudnnmathtype-t
**CUDNN_FMA_MATH**
- Restricted to only kernels that use FMA instructions.
- On pre-NVIDIA A100 GPU devices, CUDNN_DEFAULT_MATH and CUDNN_FMA_MATH have the same behavior: Tensor Core kernels will not be selected.
- With NVIDIA Ampere architecture and CUDA toolkit 11, CUDNN_DEFAULT_MATH permits TF32 Tensor Core operation and CUDNN_FMA_MATH does not.
- The TF32 behavior for CUDNN_DEFAULT_MATH and the other Tensor Core math types can be explicitly disabled by the environment variable NVIDIA_TF32_OVERRIDE=0.
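
For context, a minimal usage sketch (not part of this commit) of how a user could disable TF32 through this option via the public ONNX Runtime C++ API; the `use_tf32` provider option comes from #19357, and the model path is a placeholder:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "tf32_demo"};

  // Create CUDA EP options and set use_tf32=0 so FP32 Conv runs with
  // CUDNN_FMA_MATH (FMA kernels) instead of TF32 Tensor Core kernels.
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&cuda_options));
  const char* keys[] = {"use_tf32"};
  const char* values[] = {"0"};
  Ort::ThrowOnError(Ort::GetApi().UpdateCUDAProviderOptions(cuda_options, keys, values, 1));

  Ort::SessionOptions session_options;
  session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);
  Ort::GetApi().ReleaseCUDAProviderOptions(cuda_options);  // options are copied into the session

  Ort::Session session{env, "model.onnx", session_options};  // placeholder model path
  return 0;
}
```

Note that NVIDIA_TF32_OVERRIDE=0 (above) disables TF32 globally at the library level, while use_tf32 scopes the choice to the CUDA execution provider.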
tianleiwu authored Feb 21, 2024
1 parent e5ce81a commit 3afb38c
Showing 7 changed files with 35 additions and 12 deletions.
17 changes: 14 additions & 3 deletions onnxruntime/core/providers/cuda/nn/conv.cc
@@ -326,7 +326,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 
   ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
                                        gsl::narrow_cast<int>(conv_attrs_.group),
-                                       CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>()));
+                                       CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>(),
+                                       UseTF32()));
 
   if (context->InputCount() >= 3) {
     const Tensor* B = context->Input<Tensor>(2);
@@ -351,8 +352,13 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
 
   if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) {
     // set math type to tensor core before algorithm search
-    if constexpr (std::is_same<T, MLFloat16>::value)
+    if constexpr (std::is_same<T, MLFloat16>::value) {
       CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
+    } else if constexpr (std::is_same<T, float>::value) {
+      if (!UseTF32()) {
+        CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
+      }
+    }
 
     cudnnConvolutionFwdAlgoPerf_t perf;
     int algo_count = 1;
@@ -399,6 +405,8 @@ Status Conv<T, NHWC>::UpdateState(OpKernelContext* context, bool bias_expected)
     CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory));
     if (std::is_same<T, MLFloat16>::value) {
       perf.mathType = CUDNN_TENSOR_OP_MATH;
+    } else if (std::is_same<T, float>::value && !UseTF32()) {
+      perf.mathType = CUDNN_FMA_MATH;
     } else {
       perf.mathType = CUDNN_DEFAULT_MATH;
     }
@@ -480,7 +488,8 @@ Status CudnnConvolutionDescriptor::Set(
     const gsl::span<const int64_t>& dilations,
     int groups,
     cudnnConvolutionMode_t mode,
-    cudnnDataType_t data_type) {
+    cudnnDataType_t data_type,
+    bool use_tf32) {
   if (!desc_)
     CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_));
 
@@ -513,6 +522,8 @@ Status CudnnConvolutionDescriptor::Set(
   CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH));
   if (data_type == CUDNN_DATA_HALF) {
     CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH));
+  } else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) {
+    CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH));
   }
 
   return Status::OK();
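
Taken together, the conv.cc changes apply a single rule for choosing the cuDNN math type. A standalone sketch of that rule for reference; `PickMathType` is a hypothetical helper name, not a function in ONNX Runtime:

```cpp
#include <cudnn.h>

// Sketch of the math-type rule applied above (hypothetical helper).
static cudnnMathType_t PickMathType(cudnnDataType_t data_type, bool use_tf32) {
  if (data_type == CUDNN_DATA_HALF) {
    return CUDNN_TENSOR_OP_MATH;  // FP16: always allow Tensor Core kernels
  }
  if (data_type == CUDNN_DATA_FLOAT && !use_tf32) {
    return CUDNN_FMA_MATH;  // FP32 with use_tf32=0: restrict to FMA kernels
  }
  return CUDNN_DEFAULT_MATH;  // otherwise let cuDNN choose (TF32 permitted on Ampere+)
}
```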
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/nn/conv.h
@@ -29,7 +29,8 @@ class CudnnConvolutionDescriptor final {
              const gsl::span<const int64_t>& dilations,
              int groups,
              cudnnConvolutionMode_t mode,
-             cudnnDataType_t data_type);
+             cudnnDataType_t data_type,
+             bool use_tf32);
 
   operator cudnnConvolutionDescriptor_t() const { return desc_; }
 
10 changes: 8 additions & 2 deletions onnxruntime/core/providers/cuda/nn/conv_transpose.cc
@@ -167,7 +167,8 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
     cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
     ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations,
                                          gsl::narrow_cast<int>(conv_transpose_attrs_.group), mode,
-                                         CudnnTensor::GetDataType<CudaT>()));
+                                         CudnnTensor::GetDataType<CudaT>(),
+                                         UseTF32()));
 
     if (has_bias) {
       const auto& b_shape = p.B->Shape();
@@ -187,8 +188,13 @@ Status ConvTranspose<T, NHWC>::DoConvTranspose(OpKernelContext* context, bool dy
         GetScratchBuffer<void>(AlgoSearchWorkspaceSize, context->GetComputeStream());
 
     // set math type to tensor core before algorithm search
-    if constexpr (std::is_same<T, MLFloat16>::value)
+    if constexpr (std::is_same<T, MLFloat16>::value) {
       CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
+    } else if constexpr (std::is_same<T, float>::value) {
+      if (!UseTF32()) {
+        CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH));
+      }
+    }
 
     cudnnConvolutionBwdDataAlgoPerf_t perf;
     int algo_count = 1;
3 changes: 2 additions & 1 deletion orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc
@@ -114,7 +114,8 @@ Status ConvGrad<T>::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor&
   ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type));
   ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
                                           gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
-                                          args_.params.data_type));
+                                          args_.params.data_type,
+                                          UseTF32()));
 
   if (dB) {
     const TensorShape& db_shape = dB->Shape();
6 changes: 4 additions & 2 deletions orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc
@@ -233,11 +233,13 @@ bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const
 }
 
 template <typename T_Perf>
-Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results) {
+Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results, bool use_tf32) {
   perf_results.resize(1);
   perf_results[0].algo = AlgoSearch<T_Perf>::DEFAULT_ALGO;
   if (args.params.data_type == CUDNN_DATA_HALF) {
     perf_results[0].mathType = CUDNN_TENSOR_OP_MATH;
+  } else if (args.params.data_type == CUDNN_DATA_FLOAT && !use_tf32) {
+    perf_results[0].mathType = CUDNN_FMA_MATH;
   } else {
     perf_results[0].mathType = CUDNN_DEFAULT_MATH;
   }
@@ -256,7 +258,7 @@ Status AlgoIterator<T_Perf>::TryAll(const CUDAExecutionProvider* provider, const
 
   std::vector<T_Perf> perf_results;
   ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault
-                          ? OnlyDefaultAlgorithm(args_, perf_results)
+                          ? OnlyDefaultAlgorithm(args_, perf_results, provider->UseTF32())
                           : AlgoSearch<T_Perf>::FindAlgorithms(args_, provider, allocator, perf_results));
   for (auto& algo_perf : perf_results) {
     if (f(algo_perf) == Status::OK()) {
2 changes: 1 addition & 1 deletion orttraining/orttraining/training_ops/cuda/nn/conv_shared.h
@@ -75,7 +75,7 @@ class AlgoIterator {
   Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
                 std::function<Status(const T_Perf& perf)> f);
 
-  static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results);
+  static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results, bool use_tf32);
 
  private:
   const ConvArgs& args_;
6 changes: 4 additions & 2 deletions orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc
@@ -182,7 +182,8 @@ Status ConvTransposeGrad<T>::PrepareConvForwardArgs(const Tensor& X, const Tenso
     ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type));
     ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
                                            gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
-                                           args.params.data_type));
+                                           args.params.data_type,
+                                           UseTF32()));
   }
 
   return Status::OK();
@@ -287,7 +288,8 @@ Status ConvTransposeGrad<T>::PrepareConvBackwardFilterArgs(const Tensor& X, cons
   ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type));
   ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
                                          gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
-                                         args.params.data_type));
+                                         args.params.data_type,
+                                         UseTF32()));
 
   if (dB) {
     const auto& b_shape = dB->Shape();
