Skip to content

Commit

Permalink
Add a Shutdown() method
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanUnderhill committed Apr 24, 2024
1 parent 27deb2f commit 1859cad
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 6 deletions.
10 changes: 7 additions & 3 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@ static bool _ = (Ort::InitApi(), false);

OrtGlobals::OrtGlobals() : env_{OrtEnv::Create()} {}

OrtGlobals& GetOrtGlobals() {
std::unique_ptr<OrtGlobals>& GetOrtGlobals() {
static auto globals = std::make_unique<OrtGlobals>();
return *globals;
return globals;
}

void Shutdown() {
GetOrtGlobals().reset();
}

OrtEnv& GetOrtEnv() {
return *GetOrtGlobals().env_;
return *GetOrtGlobals()->env_;
}

// IEEE 752-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
Expand Down
3 changes: 2 additions & 1 deletion src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ struct OrtGlobals {
void operator=(const OrtGlobals&) = delete;
};

OrtGlobals& GetOrtGlobals();
std::unique_ptr<OrtGlobals>& GetOrtGlobals();
void Shutdown(); // Do this once at exit, Ort code will fail after this call
OrtEnv& GetOrtEnv();

std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path);
Expand Down
2 changes: 1 addition & 1 deletion src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ std::vector<std::string> Tokenizer::DecodeBatch(std::span<const int32_t> sequenc
// has been destroyed. Without this, we will crash in the Onnxruntime BFCArena code when deleting tensors due to the
// arena already being destroyed.
Ort::Allocator* GetCudaAllocator(OrtSession& session) {
auto& globals = GetOrtGlobals();
auto& globals = *GetOrtGlobals();
if (!globals.allocator_cuda_) {
globals.memory_info_cuda_ = OrtMemoryInfo::Create("Cuda", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
globals.allocator_cuda_ = Ort::Allocator::Create(session, *globals.memory_info_cuda_);
Expand Down
7 changes: 7 additions & 0 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ extern "C" {
return reinterpret_cast<OgaResult*>(std::make_unique<Generators::Result>(e.what()).release()); \
}

OgaResult* OGA_API_CALL OgaShutdown() {
OGA_TRY
Generators::Shutdown();
return nullptr;
OGA_CATCH
}

const char* OGA_API_CALL OgaResultGetError(const OgaResult* result) {
return reinterpret_cast<const Generators::Result*>(result)->what_.c_str();
}
Expand Down
6 changes: 6 additions & 0 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ typedef struct OgaSequences OgaSequences;
typedef struct OgaTokenizer OgaTokenizer;
typedef struct OgaTokenizerStream OgaTokenizerStream;

/* \brief Call this on process exit to cleanly shutdown the genai library & its onnxruntime usage
* \return Error message contained in the OgaResult. The const char* is owned by the OgaResult
* and can will be freed when the OgaResult is destroyed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaShutdown();

/*
* \param[in] result OgaResult that contains the error message.
* \return Error message contained in the OgaResult. The const char* is owned by the OgaResult
Expand Down
7 changes: 7 additions & 0 deletions src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
)pbdoc";

// Add a cleanup call to happen before global variables are destroyed
static int unused{}; // The capsule needs something to reference
pybind11::capsule cleanup(&unused, [](PyObject*) {
Generators::Shutdown();
});
m.add_object("_cleanup", cleanup);

// So that python users can catch OrtExceptions specifically
pybind11::register_exception<Ort::Exception>(m, "OrtException");

Expand Down
2 changes: 2 additions & 0 deletions test/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ int main(int argc, char** argv) {
std::cout << "done" << std::endl;
::testing::InitGoogleTest(&argc, argv);
int result = RUN_ALL_TESTS();
std::cout << "Shutting down OnnxRuntime... ";
Generators::Shutdown();
std::cout << "done" << std::endl;
return result;
} catch (const std::exception& e) {
Expand Down
1 change: 0 additions & 1 deletion test/python/test_onnxruntime_genai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def test_tokenizer_stream(device, phi2_for):
sysconfig.get_platform().endswith("arm64") or sys.version_info.minor < 8,
reason="Python 3.8 is required for downloading models.",
)
@pytest.mark.skip(reason="skipping to test memory issue")
@pytest.mark.parametrize("device", devices)
def test_batching(device, phi2_for):
model = og.Model(phi2_for(device))
Expand Down

0 comments on commit 1859cad

Please sign in to comment.