Skip to content

Commit

Permalink
adjust default test tolerance; disable tf32; fix A100 failed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Mar 18, 2024
1 parent 28ad6c3 commit a16da8a
Show file tree
Hide file tree
Showing 34 changed files with 340 additions and 182 deletions.
20 changes: 6 additions & 14 deletions onnxruntime/test/contrib_ops/attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ static void RunAttentionTest(
tester.AddOptionalInputEdge<int32_t>();
}

if (use_float16) {
tester.SetOutputTolerance(0.005f, 0.005f);
} else {
tester.SetOutputTolerance(0.001f, 0.001f);
}

if (enable_cuda) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
Expand Down Expand Up @@ -2013,13 +2019,6 @@ TEST(AttentionTest, AttentionMaskIndexOutOfRange) {
#if !defined(__wasm__)
// TODO: fix in web assembly
TEST(AttentionTest, AttentionPastState_dynamic) {
// ORT enables TF32 in GEMM for A100. TF32 will cause precsion loss and fail this test.
// Do not run this test unless TF32 is disabled explicitly.
if (HasCudaEnvironment(800) && ParseEnvironmentVariableWithDefault<int>("NVIDIA_TF32_OVERRIDE", 1) != 0) {
GTEST_SKIP() << "Skipping AttentionPastState_dynamic in A100 since TF32 is enabled";
return;
}

// create rand inputs
RandomValueGenerator random{};

Expand Down Expand Up @@ -2101,13 +2100,6 @@ static void RunModelWithRandomInput(
std::vector<int32_t>& mask_index_data,
std::string& onnx_model,
bool is_float16) {
// ORT enables TF32 in GEMM for A100. TF32 will cause precsion loss and fail this test.
// Do not run this test unless TF32 is disabled explicitly.
if (HasCudaEnvironment(800) && ParseEnvironmentVariableWithDefault<int>("NVIDIA_TF32_OVERRIDE", 1) != 0) {
GTEST_SKIP() << "Skipping RunModelWithRandomInput in A100 since TF32 is enabled";
return;
}

RandomValueGenerator random{234};

constexpr int hidden_size = 768;
Expand Down
20 changes: 16 additions & 4 deletions onnxruntime/test/contrib_ops/beam_search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
#include "core/session/onnxruntime_cxx_api.h"
#include "test/common/cuda_op_test_utils.h"

#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_options.h"
#endif

extern std::unique_ptr<Ort::Env> ort_env;

namespace onnxruntime {
Expand Down Expand Up @@ -70,7 +74,9 @@ TEST(BeamSearchTest, GptBeamSearchFp32) {

Ort::SessionOptions session_options;
#ifdef USE_CUDA
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
OrtCUDAProviderOptionsV2 cuda_options;
cuda_options.use_tf32 = false;
session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
#endif

#ifdef USE_ROCM
Expand Down Expand Up @@ -161,7 +167,9 @@ TEST(BeamSearchTest, GptBeamSearchFp16) {
if (enable_cuda || enable_rocm) {
Ort::SessionOptions session_options;
#ifdef USE_CUDA
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
OrtCUDAProviderOptionsV2 cuda_options;
cuda_options.use_tf32 = false;
session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
#endif

#ifdef USE_ROCM
Expand Down Expand Up @@ -254,7 +262,9 @@ TEST(BeamSearchTest, GptBeamSearchWithInitDecoderFp16) {
if (enable_cuda || enable_rocm) {
Ort::SessionOptions session_options;
#ifdef USE_CUDA
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
OrtCUDAProviderOptionsV2 cuda_options;
cuda_options.use_tf32 = false;
session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
#endif

#ifdef USE_ROCM
Expand Down Expand Up @@ -346,7 +356,9 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) {
if (enable_cuda || enable_rocm) {
Ort::SessionOptions session_options;
#ifdef USE_CUDA
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
OrtCUDAProviderOptionsV2 cuda_options;
cuda_options.use_tf32 = false;
session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
#endif

#ifdef USE_ROCM
Expand Down
7 changes: 3 additions & 4 deletions onnxruntime/test/contrib_ops/decoder_attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ static void RunAttentionTest(
const std::vector<float>* new_value_cache = nullptr,
const std::vector<float>* key_cache = nullptr,
const std::vector<float>* value_cache = nullptr,
const std::initializer_list<bool>* key_padding_mask_data = nullptr,
bool use_float16 = false) {
int min_cuda_architecture = use_float16 ? 530 : 0;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
const std::initializer_list<bool>* key_padding_mask_data = nullptr) {
bool enable_cuda = HasCudaEnvironment(0);
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
bool enable_cpu = false;

Expand Down Expand Up @@ -99,6 +97,7 @@ static void RunAttentionTest(
tester.AddOutput<float>("new_key_cache", output_cache_dims, *new_key_cache);
tester.AddOutput<float>("new_value_cache", output_cache_dims, *new_value_cache);
}
tester.SetOutputTolerance(0.001f, 0.001f);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (enable_cuda) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -754,9 +754,10 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) {

// Output(s)
tester.AddOutput<float>("output", input_dims, output);

tester.AddOutput<float>("present", past_dims, present);

tester.SetOutputTolerance(0.001f, 0.001f);

// Run - Regular kernel execution path
{
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
Expand Down Expand Up @@ -897,9 +898,10 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) {

// Output(s)
tester.AddOutput<MLFloat16>("output", input_dims, output);

tester.AddOutput<MLFloat16>("present", past_dims, present);

tester.SetOutputTolerance(0.005f, 0.001f);

// Run - Regular kernel execution path
{
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/test/contrib_ops/fft_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ TEST(ContribOpTest, Rfft) {
// Target values conputed using PyTorch torch.fft.rfft(X, dim=-1, norm="backward")
test.AddInput<float>("X", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f});
test.AddOutput<float>("Y", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f});
test.SetOutputTolerance(0.0001f, 0.0001f);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

Expand All @@ -45,6 +46,7 @@ TEST(ContribOpTest, Irfft) {
test.AddAttribute("normalized", static_cast<int64_t>(0));
test.AddInput<float>("X", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f});
test.AddOutput<float>("Y", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f});
test.SetOutputTolerance(0.0001f, 0.0001f);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
} // namespace test
Expand Down
14 changes: 12 additions & 2 deletions onnxruntime/test/contrib_ops/greedy_search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
#include "core/session/onnxruntime_cxx_api.h"
#include "test/common/cuda_op_test_utils.h"

#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_options.h"
#endif

extern std::unique_ptr<Ort::Env> ort_env;

namespace onnxruntime {
Expand Down Expand Up @@ -64,9 +68,13 @@ TEST(GreedySearchTest, GptGreedySearchFp16_VocabPadded) {

if (is_cuda || is_rocm) {
Ort::SessionOptions session_options;
#ifdef USE_CUDA
if (is_cuda) {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
OrtCUDAProviderOptionsV2 cuda_options;
cuda_options.use_tf32 = false;
session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
}
#endif
if (is_rocm) {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0));
}
Expand Down Expand Up @@ -146,7 +154,9 @@ TEST(GreedySearchTest, GptGreedySearchFp32) {
if (is_cuda || is_rocm) {
Ort::SessionOptions session_options;
if (is_cuda) {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
OrtCUDAProviderOptionsV2 cuda_options;
cuda_options.use_tf32 = false;
session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
}
if (is_rocm) {
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0));
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/contrib_ops/gridsample_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ TEST(GridsampleContribOpTest, gridsample_mode_bicubic) {
0.5000f, 0.5000f, 1.0000f, 1.0000f});
test.AddAttribute("mode", "bicubic");
test.AddOutput<float>("Y", {1, 1, 2, 4}, {-0.1406f, 0.3828f, 1.7556f, 2.9688f, 2.9688f, 1.7556f, 5.1445f, 1.3906f});
test.SetOutputTolerance(0.0001f, 0.0001f);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
}

Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/test/contrib_ops/layer_norm_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias) {
test.AddInput<float>("gamma", {2}, {-0.6953f, 5.1824f});
test.AddInput<float>("bias", {2}, {0.6435f, -0.3964f});
test.AddOutput<float>("output", dims, {-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f});
test.SetOutputTolerance(0.0001f, 0.0001f);
test.Run();
}

Expand All @@ -172,6 +173,8 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16Input) {
test.AddInput<float>("gamma", {2}, {-0.6953f, 5.1824f});
test.AddInput<float>("bias", {2}, {0.6435f, -0.3964f});
test.AddOutput<float>("output", dims, {-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f});
test.SetOutputTolerance(0.0001f, 0.0001f);

// TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kTensorrtExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
Expand Down Expand Up @@ -228,6 +231,9 @@ TEST(LayerNormTest, LayerNorm17_double) {
test.AddInput<double>("x", dims, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
test.AddInput<double>("gamma", {3}, {1.0, 1.0, 1.0});
test.AddOutput<double>("output", dims, {-1.2247, 0.0, 1.2247, -1.2247, 0.0, 1.2247});

test.SetOutputTolerance(0.0001f, 0.0001f);

// DNNL does not support double
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDnnlExecutionProvider});
}
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/test/contrib_ops/moe_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ static void RunMoETest(
tester.AddInput<MLFloat16>("fc1_experts_bias", fc1_experts_bias_dims, ToFloat16(fc1_experts_bias));
tester.AddInput<MLFloat16>("fc2_experts_bias", fc2_experts_bias_dims, ToFloat16(fc2_experts_bias));
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
tester.SetOutputTolerance(0.005f, 0.005f);
} else {
tester.AddInput<float>("input", input_dims, input);
tester.AddInput<float>("router_probs", router_probs_dims, router_probs);
Expand All @@ -55,6 +56,7 @@ static void RunMoETest(
tester.AddInput<float>("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias);
tester.AddInput<float>("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias);
tester.AddOutput<float>("output", output_dims, output_data);
tester.SetOutputTolerance(0.001f, 0.001f);
}

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/test/contrib_ops/packed_attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,7 @@ static void RunModelWithRandomInput(
std::vector<int64_t> token_offset_dims{batch_size, sequence_length};
std::vector<int64_t> cum_seq_len_dims{batch_size + 1};

// TF32 in SM >= 80 is enabled by default, need larger threshold for float when TF32 is enabled.
float gpu_threshold = is_float16 ? 0.15f : (HasCudaEnvironment(800) ? 0.05f : 0.005f);
float gpu_threshold = is_float16 ? 0.15f : 0.005f;
gpu_threshold *= sequence_length > 1024 ? 4.0f : 1.0f; // threshold should increase with sequence length
bool enable_cuda = HasCudaEnvironment(is_float16 ? 530 : 0);
if (enable_cuda) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ static void RunPackedMultiHeadAttentionTest(
}

tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
tester.SetOutputTolerance(0.005f, 0.005f);
} else {
if (is_packed_qkv) {
tester.AddInput<float>("query", packed_qkv_dims, query_data);
Expand All @@ -131,6 +132,7 @@ static void RunPackedMultiHeadAttentionTest(
}

tester.AddOutput<float>("output", output_dims, output_data);
tester.SetOutputTolerance(0.001f, 0.001f);
}

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,13 @@ void RunQAttention(const std::vector<float>& input_data,
tester.AddInput<MLFloat16>("input_scale", {1}, ToFloat16({input_quant_params.scale}));
tester.AddInput<MLFloat16>("weight_scale", {1}, ToFloat16({weight_quant_params.scale}));
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
tester.SetOutputTolerance(0.01f, 0.01f);
} else {
tester.AddInput<float>("bias", bias_dims, bias_data);
tester.AddInput<float>("input_scale", {1}, {input_quant_params.scale});
tester.AddInput<float>("weight_scale", {1}, {weight_quant_params.scale});
tester.AddOutput<float>("output", output_dims, output_data);
tester.SetOutputTolerance(0.005f, 0.005f);
}

if (mask_index_data.size() > 0) {
Expand Down
9 changes: 8 additions & 1 deletion onnxruntime/test/contrib_ops/sampling_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
#include "core/session/onnxruntime_cxx_api.h"
#include "test/common/cuda_op_test_utils.h"

#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_options.h"
#endif

extern std::unique_ptr<Ort::Env> ort_env;

namespace onnxruntime {
Expand Down Expand Up @@ -65,7 +69,10 @@ TEST(SamplingTest, Gpt2Sampling_GPU) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support current architecture";
return;
}
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));

OrtCUDAProviderOptionsV2 cuda_options;
cuda_options.use_tf32 = false;
session_options.AppendExecutionProvider_CUDA_V2(cuda_options);
#else // USE_ROCM
OrtROCMProviderOptions rocm_options;
// TODO - verify the default settings
Expand Down
14 changes: 10 additions & 4 deletions onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "nlohmann/json.hpp"

#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_options.h"
#endif

using namespace onnxruntime;

namespace {
Expand Down Expand Up @@ -401,12 +405,13 @@ int real_main(int argc, char* argv[], Ort::Env& env) {

if (enable_tensorrt) {
#ifdef USE_TENSORRT
OrtCUDAProviderOptions cuda_options;
OrtCUDAProviderOptionsV2 cuda_options;
cuda_options.device_id = device_id;
cuda_options.do_copy_in_default_stream = true;
cuda_options.use_tf32 = false;
// TODO: Support arena configuration for users of test runner
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(sf, device_id));
sf.AppendExecutionProvider_CUDA(cuda_options);
sf.AppendExecutionProvider_CUDA_V2(cuda_options);
#else
fprintf(stderr, "TensorRT is not supported in this build");
return -1;
Expand All @@ -424,10 +429,11 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
if (enable_cuda) {
#ifdef USE_CUDA
OrtCUDAProviderOptions cuda_options;
OrtCUDAProviderOptionsV2 cuda_options;
cuda_options.do_copy_in_default_stream = true;
cuda_options.use_tf32 = false;
// TODO: Support arena configuration for users of test runner
sf.AppendExecutionProvider_CUDA(cuda_options);
sf.AppendExecutionProvider_CUDA_V2(cuda_options);
#else
fprintf(stderr, "CUDA is not supported in this build");
return -1;
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/test/providers/base_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@ void BaseTester::SetOutputRelErr(const char* name, float v) {
it->validation_params.relative_error = optional<float>(v);
}

void BaseTester::SetOutputTolerance(float abs_error, float rel_error) {
for (auto& output : output_data_) {
if (output.def.Exists()) {
output.validation_params.absolute_error = optional<float>(abs_error);
output.validation_params.relative_error = optional<float>(rel_error);
}
}
}

std::vector<int64_t> BaseTester::GetDimsForProto(gsl::span<const int64_t> dims) {
std::vector<int64_t> dims_for_proto{dims.begin(), dims.end()};
if (add_symbolic_dim_to_tensor_data_ >= 0 &&
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/test/providers/base_tester.h
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,9 @@ class BaseTester {
void SetOutputAbsErr(const char* name, float v);
void SetOutputRelErr(const char* name, float v);

// Set absolute and relative error for added outputs.
void SetOutputTolerance(float abs_error, float rel_error);

// Number of times to call InferenceSession::Run. The same feeds are used each time.
// e.g. used to verify the generator ops behave as expected
void SetNumRunCalls(int n) {
Expand Down
Loading

0 comments on commit a16da8a

Please sign in to comment.