
Commit 5ab249b: Unify GetOrtEnv()

RyanUnderhill committed Apr 24, 2024 (1 parent: 6c5efe3)
Showing 6 changed files with 9 additions and 17 deletions.
src/generators.cpp (2 additions, 0 deletions)

@@ -11,6 +11,8 @@
 namespace Generators {
 
+static bool _=(Ort::InitApi(), false);
+
 OrtGlobals::OrtGlobals() : env_{OrtEnv::Create()} {}
 
 OrtGlobals& GetOrtGlobals() {
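The added line uses a namespace-scope static initializer to run Ort::InitApi() when the library is loaded, before any other code can touch the ORT API. A minimal standalone sketch of the idiom, with a hypothetical Init() standing in for Ort::InitApi():

// Sketch of the comma-operator static-initializer idiom (illustrative only).
#include <iostream>

static void Init() { std::cout << "Init ran first\n"; }

// The comma operator evaluates Init() for its side effect, then yields
// false, which becomes the (unused) value of the dummy bool.
static bool g_dummy = (Init(), false);

int main() {
  std::cout << "main\n";  // prints after "Init ran first"
}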
src/generators.h (3 additions, 0 deletions)

@@ -122,6 +122,9 @@ struct OrtGlobals {
   std::unique_ptr<OrtMemoryInfo> memory_info_cuda_;
   std::unique_ptr<Ort::Allocator> allocator_cuda_;
 #endif
+ private:
+  OrtGlobals(const OrtGlobals&) = delete;
+  void operator=(const OrtGlobals&) = delete;
 };
 
 OrtGlobals& GetOrtGlobals();
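Deleting the copy constructor and copy assignment makes OrtGlobals non-copyable, so the single instance can only be reached through the GetOrtGlobals() accessor. A minimal sketch of the pattern, with illustrative names rather than the repo's code:

// Non-copyable type exposed via a function-local static accessor.
struct Globals {
  Globals() = default;

 private:
  Globals(const Globals&) = delete;
  void operator=(const Globals&) = delete;
};

Globals& GetGlobals() {
  static Globals g;  // constructed once, on first call; thread-safe since C++11
  return g;
}

int main() {
  Globals& g = GetGlobals();       // fine: a reference to the single instance
  // Globals copy = GetGlobals();  // compile error: copy constructor deleted
  (void)g;
}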
src/python/python.cpp (0 additions, 3 deletions)

@@ -194,9 +194,6 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
       .def("set_search_options", &PyGeneratorParams::SetSearchOptions)  // See config.h 'struct Search' for the options
       .def("try_use_cuda_graph_with_max_batch_size", &PyGeneratorParams::TryUseCudaGraphWithMaxBatchSize);
 
-  // We need to init the OrtApi before we can use it
-  Ort::InitApi();
-
   pybind11::class_<TokenizerStream>(m, "TokenizerStream")
       .def("decode", [](TokenizerStream& t, int32_t token) { return t.Decode(token); });
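With the static initializer added in src/generators.cpp, Ort::InitApi() now runs as soon as the library is loaded, so the explicit call in the Python module is redundant; the same deletion appears in test/main.cpp and test/sampling_benchmark.cpp below.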
src/smartptrs.h (1 addition, 1 deletion)

@@ -115,7 +115,7 @@ struct cuda_stream_holder {
 #else
 struct cuda_stream_holder {
   void Create() {
-    assert(false);
+    throw std::runtime_error("Trying to create a cuda stream in a non cuda build");
   }
 
   operator cudaStream_t() const { return v_; }
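assert(false) is compiled out when NDEBUG is defined, so in a release build the misuse would continue silently; throwing reports the error in every build configuration. A small standalone illustration of the difference (not the repo's code):

#include <cassert>
#include <iostream>
#include <stdexcept>

void CreateWithAssert() {
  assert(false);  // no-op when NDEBUG is defined: execution just continues
}

void CreateWithThrow() {
  // Fails loudly in debug and release builds alike.
  throw std::runtime_error("Trying to create a cuda stream in a non cuda build");
}

int main() {
  CreateWithAssert();  // aborts in a debug build, silently returns in release
  try {
    CreateWithThrow();
  } catch (const std::runtime_error& e) {
    std::cout << "caught: " << e.what() << '\n';  // caller can diagnose the misuse
  }
}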
test/main.cpp (0 additions, 2 deletions)

@@ -12,8 +12,6 @@ int main(int argc, char** argv) {
   std::cout << "Initializing OnnxRuntime... ";
   std::cout.flush();
   try {
-    Ort::InitApi();
     g_ort_env = OrtEnv::Create();
     std::cout << "done" << std::endl;
     ::testing::InitGoogleTest(&argc, argv);
     int result = RUN_ALL_TESTS();
test/sampling_benchmark.cpp (3 additions, 11 deletions)

@@ -176,9 +176,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPCuda) {
 
 TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCuda) {
-  std::unique_ptr<OrtEnv> g_ort_env;
-  Ort::InitApi();
-  g_ort_env = OrtEnv::Create();
-  auto model = Generators::CreateModel(*g_ort_env, MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
+  auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
   int vocab_size = 32000;  // vocab size of llama
   int batch_size = 1;
   int k = 5;

@@ -218,10 +216,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopKCuda) {
 }
 
 TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCuda) {
-  std::unique_ptr<OrtEnv> g_ort_env;
-  Ort::InitApi();
-  g_ort_env = OrtEnv::Create();
-  auto model = Generators::CreateModel(*g_ort_env, MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
+  auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
   int vocab_size = 32000;  // vocab size of llama
   int batch_size = 1;
   float p = 0.95f;

@@ -266,10 +261,7 @@ TEST(Benchmarks, BenchmarkRandomizedSamplingTopPAndKCuda) {
 }
 
 TEST(Benchmarks, BenchmarkRandomizedSelectTopCuda) {
-  std::unique_ptr<OrtEnv> g_ort_env;
-  Ort::InitApi();
-  g_ort_env = OrtEnv::Create();
-  auto model = Generators::CreateModel(*g_ort_env, MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
+  auto model = Generators::CreateModel(Generators::GetOrtEnv(), MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
   int vocab_size = 32000;  // vocab size of llama
   int batch_size = 12;
   std::vector<int32_t> input_ids{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};  // Needs to match batch_size
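Each benchmark previously built a private OrtEnv (and had to call Ort::InitApi() first); now they all share the process-wide environment. The diff does not show GetOrtEnv() itself, but given the env_ member initialized in src/generators.cpp, a plausible reconstruction with stand-in types looks like:

// Hypothetical sketch only; the real definitions are not part of this diff.
#include <memory>

struct OrtEnvStub {};  // stand-in for OrtEnv

struct OrtGlobalsStub {
  std::unique_ptr<OrtEnvStub> env_{new OrtEnvStub};  // mirrors env_{OrtEnv::Create()}
};

OrtGlobalsStub& GetOrtGlobals() {
  static OrtGlobalsStub globals;
  return globals;
}

OrtEnvStub& GetOrtEnv() {
  return *GetOrtGlobals().env_;  // every caller shares one environment
}

int main() {
  OrtEnvStub& a = GetOrtEnv();
  OrtEnvStub& b = GetOrtEnv();
  return &a == &b ? 0 : 1;  // same instance on every call
}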
