Skip to content

Commit

Permalink
upate test inference
Browse files Browse the repository at this point in the history
  • Loading branch information
mszhanyi committed Jan 23, 2024
1 parent 8f77791 commit 93373c9
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion onnxruntime/test/global_thread_pools/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "test_allocator.h"
#include "../shared_lib/test_fixture.h"
#include <stdlib.h>
#include "test/common/cuda_op_test_utils.h"

struct Input {
const char* name = nullptr;
Expand Down Expand Up @@ -55,9 +56,15 @@ static void RunSession(OrtAllocator& allocator, Ort::Session& session_object,
// size_t total_len = type_info.GetElementCount();
ASSERT_EQ(values_y.size(), static_cast<size_t>(5));

auto tolerance = 1e-6f;
#ifdef USE_CUDA
if (HasCudaEnvironment(800)) {
tolerance = 1e-5f;
}
#endif
OutT* f = output_tensor->GetTensorMutableData<OutT>();
for (size_t i = 0; i != static_cast<size_t>(5); ++i) {
ASSERT_NEAR(values_y[i], f[i], 1e-6f);
ASSERT_NEAR(values_y[i], f[i], tolerance);
}
}

Expand Down

0 comments on commit 93373c9

Please sign in to comment.