Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Jan 26, 2024
1 parent d2cb637 commit bc07321
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions onnxruntime/test/unittest_main/test_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
#include "gtest/gtest.h"

#include "core/platform/env_var_utils.h"
#include "core/common/logging/sinks/clog_sink.h"
#include "core/session/environment.h"

#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/ort_env.h"
#include "core/util/thread_utils.h"
Expand Down Expand Up @@ -56,11 +59,21 @@ auto const placeholder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInf
std::unique_ptr<Ort::Env> ort_env;
static onnxruntime::Status ortenv_setup() {
OrtThreadingOptions tpo;
OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, "Default"};

onnxruntime::Status status;
OrtEnv* out = OrtEnv::GetInstance(lm_info, status, &tpo);
if (!status.IsOK()) return status;
ort_env = std::make_unique<Ort::Env>(out);

std::unique_ptr<onnxruntime::logging::LoggingManager> lmgr;
std::string name = "Default";

lmgr = std::make_unique<onnxruntime::logging::LoggingManager>(std::make_unique<onnxruntime::logging::CLogSink>(),
onnxruntime::logging::Severity::kWARNING,
false,
onnxruntime::logging::LoggingManager::InstanceType::Default,
&name);

std::unique_ptr<onnxruntime::Environment> env;
ORT_RETURN_IF_ERROR(onnxruntime::Environment::Create(std::move(lmgr), env, &tpo, true));
ort_env = std::make_unique<Ort::Env>(env.release());
return status;
}

Expand All @@ -86,10 +99,6 @@ void MainLoop(void* arg) {
printf("stage %d ...\n", (int)state.stage);
onnxruntime::Status status;
switch (state.stage) {
case EmStage::SETUP_ENV:
::testing::InitGoogleTest(&state.argc, state.argv);
state.stage = EmStage::RUN_TEST;
break;
case EmStage::INIT: {
status = ortenv_setup();
if (!status.IsOK()) {
Expand All @@ -99,6 +108,10 @@ void MainLoop(void* arg) {
state.stage = EmStage::SETUP_ENV;
}
} break;
case EmStage::SETUP_ENV:
::testing::InitGoogleTest(&state.argc, state.argv);
state.stage = EmStage::RUN_TEST;
break;
case EmStage::RUN_TEST:
state.ret = RUN_ALL_TESTS();
state.stage = EmStage::FINI;
Expand All @@ -112,12 +125,12 @@ void MainLoop(void* arg) {
return;
}

static EmState em_global_state;
int TEST_MAIN(int argc, char** argv) {
std::cout << "start: argc=" << argc << std::endl;
EmState s;
s.argc = argc;
s.argv = argv;
emscripten_set_main_loop_arg(MainLoop, &s, 0, 0);
em_global_state.argc = argc;
em_global_state.argv = argv;
emscripten_set_main_loop_arg(MainLoop, &em_global_state, 0, 0);
return s.ret;
}
#else
Expand Down

0 comments on commit bc07321

Please sign in to comment.