From e753008868ef1e70f57e91dc02730b49c2c6f274 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 18 Mar 2024 06:33:31 +0000 Subject: [PATCH] disable tf32 in cuda ep test --- .../test/contrib_ops/attention_op_test.cc | 14 ------------- .../test/contrib_ops/beam_search_test.cc | 20 ++++++++++++++---- .../test/contrib_ops/greedy_search_test.cc | 14 +++++++++++-- .../contrib_ops/packed_attention_op_test.cc | 3 +-- onnxruntime/test/contrib_ops/sampling_test.cc | 9 +++++++- onnxruntime/test/onnx/main.cc | 14 +++++++++---- onnxruntime/test/providers/cpu/model_tests.cc | 21 ++++++------------- onnxruntime/test/util/default_providers.cc | 6 ++++-- 8 files changed, 57 insertions(+), 44 deletions(-) diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index b652e0723f5aa..7fe70fd2d6f09 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -2013,13 +2013,6 @@ TEST(AttentionTest, AttentionMaskIndexOutOfRange) { #if !defined(__wasm__) // TODO: fix in web assembly TEST(AttentionTest, AttentionPastState_dynamic) { - // ORT enables TF32 in GEMM for A100. TF32 will cause precsion loss and fail this test. - // Do not run this test unless TF32 is disabled explicitly. - if (HasCudaEnvironment(800) && ParseEnvironmentVariableWithDefault("NVIDIA_TF32_OVERRIDE", 1) != 0) { - GTEST_SKIP() << "Skipping AttentionPastState_dynamic in A100 since TF32 is enabled"; - return; - } - // create rand inputs RandomValueGenerator random{}; @@ -2101,13 +2094,6 @@ static void RunModelWithRandomInput( std::vector& mask_index_data, std::string& onnx_model, bool is_float16) { - // ORT enables TF32 in GEMM for A100. TF32 will cause precsion loss and fail this test. - // Do not run this test unless TF32 is disabled explicitly. - if (HasCudaEnvironment(800) && ParseEnvironmentVariableWithDefault("NVIDIA_TF32_OVERRIDE", 1) != 0) { - GTEST_SKIP() << "Skipping RunModelWithRandomInput in A100 since TF32 is enabled"; - return; - } - RandomValueGenerator random{234}; constexpr int hidden_size = 768; diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 156ed3799fc22..6ce9f5de68f11 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -8,6 +8,10 @@ #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" +#ifdef USE_CUDA +#include "core/providers/cuda/cuda_provider_options.h" +#endif + extern std::unique_ptr ort_env; namespace onnxruntime { @@ -70,7 +74,9 @@ TEST(BeamSearchTest, GptBeamSearchFp32) { Ort::SessionOptions session_options; #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2 cuda_options; + cuda_options.use_tf32 = false; + session_options.AppendExecutionProvider_CUDA_V2(cuda_options); #endif #ifdef USE_ROCM @@ -161,7 +167,9 @@ TEST(BeamSearchTest, GptBeamSearchFp16) { if (enable_cuda || enable_rocm) { Ort::SessionOptions session_options; #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2 cuda_options; + cuda_options.use_tf32 = false; + session_options.AppendExecutionProvider_CUDA_V2(cuda_options); #endif #ifdef USE_ROCM @@ -254,7 +262,9 @@ TEST(BeamSearchTest, GptBeamSearchWithInitDecoderFp16) { if (enable_cuda || enable_rocm) { Ort::SessionOptions session_options; #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2 cuda_options; + cuda_options.use_tf32 = false; + session_options.AppendExecutionProvider_CUDA_V2(cuda_options); #endif #ifdef USE_ROCM @@ -346,7 +356,9 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { if (enable_cuda || enable_rocm) { Ort::SessionOptions session_options; #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2 cuda_options; + cuda_options.use_tf32 = false; + session_options.AppendExecutionProvider_CUDA_V2(cuda_options); #endif #ifdef USE_ROCM diff --git a/onnxruntime/test/contrib_ops/greedy_search_test.cc b/onnxruntime/test/contrib_ops/greedy_search_test.cc index 1baf50c1ba616..8186529f8df45 100644 --- a/onnxruntime/test/contrib_ops/greedy_search_test.cc +++ b/onnxruntime/test/contrib_ops/greedy_search_test.cc @@ -8,6 +8,10 @@ #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" +#ifdef USE_CUDA +#include "core/providers/cuda/cuda_provider_options.h" +#endif + extern std::unique_ptr ort_env; namespace onnxruntime { @@ -64,9 +68,13 @@ TEST(GreedySearchTest, GptGreedySearchFp16_VocabPadded) { if (is_cuda || is_rocm) { Ort::SessionOptions session_options; +#ifdef USE_CUDA if (is_cuda) { - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2 cuda_options; + cuda_options.use_tf32 = false; + session_options.AppendExecutionProvider_CUDA_V2(cuda_options); } +#endif if (is_rocm) { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); } @@ -146,7 +154,9 @@ TEST(GreedySearchTest, GptGreedySearchFp32) { if (is_cuda || is_rocm) { Ort::SessionOptions session_options; if (is_cuda) { - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + OrtCUDAProviderOptionsV2 cuda_options; + cuda_options.use_tf32 = false; + session_options.AppendExecutionProvider_CUDA_V2(cuda_options); } if (is_rocm) { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); diff --git a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc index 31ef62e69bb88..09baf8def05f6 100644 --- a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc @@ -433,8 +433,7 @@ static void RunModelWithRandomInput( std::vector token_offset_dims{batch_size, sequence_length}; std::vector cum_seq_len_dims{batch_size + 1}; - // TF32 in SM >= 80 is enabled by default, need larger threshold for float when TF32 is enabled. - float gpu_threshold = is_float16 ? 0.15f : (HasCudaEnvironment(800) ? 0.05f : 0.005f); + float gpu_threshold = is_float16 ? 0.15f : 0.005f; gpu_threshold *= sequence_length > 1024 ? 4.0f : 1.0f; // threshold should increase with sequence length bool enable_cuda = HasCudaEnvironment(is_float16 ? 530 : 0); if (enable_cuda) { diff --git a/onnxruntime/test/contrib_ops/sampling_test.cc b/onnxruntime/test/contrib_ops/sampling_test.cc index 733bc9f01fd11..d987a1cae427d 100644 --- a/onnxruntime/test/contrib_ops/sampling_test.cc +++ b/onnxruntime/test/contrib_ops/sampling_test.cc @@ -8,6 +8,10 @@ #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" +#ifdef USE_CUDA +#include "core/providers/cuda/cuda_provider_options.h" +#endif + extern std::unique_ptr ort_env; namespace onnxruntime { @@ -65,7 +69,10 @@ TEST(SamplingTest, Gpt2Sampling_GPU) { LOGS_DEFAULT(WARNING) << "Hardware NOT support current architecture"; return; } - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + + OrtCUDAProviderOptionsV2 cuda_options; + cuda_options.use_tf32 = false; + session_options.AppendExecutionProvider_CUDA_V2(cuda_options); #else // USE_ROCM OrtROCMProviderOptions rocm_options; // TODO - verify the default settings diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 9c2c24e3c337d..db706bf929748 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -25,6 +25,10 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "nlohmann/json.hpp" +#ifdef USE_CUDA +#include "core/providers/cuda/cuda_provider_options.h" +#endif + using namespace onnxruntime; namespace { @@ -401,12 +405,13 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (enable_tensorrt) { #ifdef USE_TENSORRT - OrtCUDAProviderOptions cuda_options; + OrtCUDAProviderOptionsV2 cuda_options; cuda_options.device_id = device_id; cuda_options.do_copy_in_default_stream = true; + cuda_options.use_tf32 = false; // TODO: Support arena configuration for users of test runner Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(sf, device_id)); - sf.AppendExecutionProvider_CUDA(cuda_options); + sf.AppendExecutionProvider_CUDA_V2(cuda_options); #else fprintf(stderr, "TensorRT is not supported in this build"); return -1; @@ -424,10 +429,11 @@ int real_main(int argc, char* argv[], Ort::Env& env) { } if (enable_cuda) { #ifdef USE_CUDA - OrtCUDAProviderOptions cuda_options; + OrtCUDAProviderOptionsV2 cuda_options; cuda_options.do_copy_in_default_stream = true; + cuda_options.use_tf32 = false; // TODO: Support arena configuration for users of test runner - sf.AppendExecutionProvider_CUDA(cuda_options); + sf.AppendExecutionProvider_CUDA_V2(cuda_options); #else fprintf(stderr, "CUDA is not supported in this build"); return -1; diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index af71fe5cf79ae..58df7763786b2 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -42,6 +42,10 @@ #include "core/providers/armnn/armnn_provider_factory.h" #endif +#ifdef USE_CUDA +#include "core/providers/cuda/cuda_provider_options.h" +#endif + #include "test/common/cuda_op_test_utils.h" // test infrastructure @@ -98,21 +102,6 @@ TEST_P(ModelTest, Run) { std::unique_ptr model_info = std::make_unique(model_path.c_str()); -#if defined(__linux__) - // ORT enables TF32 in GEMM for A100. TF32 will cause precsion loss and fail this test. - if (HasCudaEnvironment(800) && provider_name == "cuda") { - per_sample_tolerance = 1e-1; - if (model_path.find(ORT_TSTR("SSD")) > 0 || - model_path.find(ORT_TSTR("ssd")) > 0 || - model_path.find(ORT_TSTR("yolov3")) > 0 || - model_path.find(ORT_TSTR("mask_rcnn")) > 0 || - model_path.find(ORT_TSTR("FNS")) > 0) { - SkipTest("Skipping SSD test for big tolearance failure or other errors"); - return; - } - } -#endif - if (model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { SkipTest("it has the training domain. No pipeline should need to run these tests."); @@ -198,6 +187,7 @@ TEST_P(ModelTest, Run) { std::string device_id = Env::Default().GetEnvironmentVar("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); values.push_back(device_id.empty() ? "0" : device_id.c_str()); ASSERT_ORT_STATUS_OK(OrtApis::UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), 1)); + cuda_options->use_tf32 = false; ortso.AppendExecutionProvider_CUDA_V2(*cuda_options); } else if (provider_name == "rocm") { OrtROCMProviderOptions ep_options; @@ -229,6 +219,7 @@ TEST_P(ModelTest, Run) { ASSERT_ORT_STATUS_OK(OrtApis::CreateCUDAProviderOptions(&cuda_options)); std::unique_ptr rel_cuda_options( cuda_options, &OrtApis::ReleaseCUDAProviderOptions); + cuda_options->use_tf32 = false; ortso.AppendExecutionProvider_CUDA_V2(*cuda_options); } else if (provider_name == "migraphx") { OrtMIGraphXProviderOptions ep_options; diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index c12a52c4356aa..6ad2d41edb562 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -8,7 +8,7 @@ #ifdef USE_COREML #include "core/providers/coreml/coreml_provider_factory.h" #endif -#if defined(ENABLE_CUDA_NHWC_OPS) +#ifdef USE_CUDA #include #endif #include "core/session/onnxruntime_cxx_api.h" @@ -113,8 +113,9 @@ std::unique_ptr DefaultOpenVINOExecutionProvider() { std::unique_ptr DefaultCudaExecutionProvider() { #ifdef USE_CUDA - OrtCUDAProviderOptions provider_options{}; + OrtCUDAProviderOptionsV2 provider_options{}; provider_options.do_copy_in_default_stream = true; + provider_options.use_tf32 = false; if (auto factory = CudaProviderFactoryCreator::Create(&provider_options)) return factory->CreateProvider(); #endif @@ -126,6 +127,7 @@ std::unique_ptr DefaultCudaNHWCExecutionProvider() { #if defined(USE_CUDA) OrtCUDAProviderOptionsV2 provider_options{}; provider_options.do_copy_in_default_stream = true; + provider_options.use_tf32 = false; provider_options.prefer_nhwc = true; if (auto factory = CudaProviderFactoryCreator::Create(&provider_options)) return factory->CreateProvider();