diff --git a/benchmark/c/main.cpp b/benchmark/c/main.cpp index 76d25458b..b48dbb507 100644 --- a/benchmark/c/main.cpp +++ b/benchmark/c/main.cpp @@ -231,6 +231,7 @@ void RunBenchmark(const benchmark::Options& opts) { } // namespace int main(int argc, char** argv) { + OgaHandle handle; try { const auto opts = benchmark::ParseOptionsFromCommandLine(argc, argv); RunBenchmark(opts); diff --git a/examples/c/src/main.cpp b/examples/c/src/main.cpp index e4be639f2..f78aa196e 100644 --- a/examples/c/src/main.cpp +++ b/examples/c/src/main.cpp @@ -9,7 +9,8 @@ void CXX_API(const char* model_path) { auto tokenizer = OgaTokenizer::Create(*model); const char* prompt = "def is_prime(num):"; - std::cout << "Prompt: " << std::endl << prompt << std::endl; + std::cout << "Prompt: " << std::endl + << prompt << std::endl; auto sequences = OgaSequences::Create(); tokenizer->Encode(prompt, *sequences); @@ -21,14 +22,15 @@ void CXX_API(const char* model_path) { auto output_sequences = model->Generate(*params); auto out_string = tokenizer->Decode(output_sequences->Get(0)); - std::cout << "Output: " << std::endl << out_string << std::endl; + std::cout << "Output: " << std::endl + << out_string << std::endl; } // C API Example void CheckResult(OgaResult* result) { if (result) { - std::string string=OgaResultGetError(result); + std::string string = OgaResultGetError(result); OgaDestroyResult(result); throw std::runtime_error(string); } @@ -84,6 +86,8 @@ int main(int argc, char** argv) { return -1; } + // Responsible for cleaning up the library during shutdown + OgaHandle handle; std::cout << "-------------" << std::endl; std::cout << "Hello, Phi-2!" << std::endl; diff --git a/examples/csharp/HelloPhi2/Program.cs b/examples/csharp/HelloPhi2/Program.cs index 993af8b57..fecb24ad7 100644 --- a/examples/csharp/HelloPhi2/Program.cs +++ b/examples/csharp/HelloPhi2/Program.cs @@ -1,6 +1,8 @@ // See https://aka.ms/new-console-template for more information using Microsoft.ML.OnnxRuntimeGenAI; +OgaHandle ogaHandle = new OgaHandle(); + Console.WriteLine("-------------"); Console.WriteLine("Hello, Phi-2!"); Console.WriteLine("-------------"); diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs index f2906f3df..3f6960549 100644 --- a/src/csharp/NativeMethods.cs +++ b/src/csharp/NativeMethods.cs @@ -176,5 +176,8 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern IntPtr /* OgaResult* */ OgaGetCurrentGpuDeviceId(out IntPtr /* int32_t */ device_id); + + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern void OgaShutdown(); } } diff --git a/src/csharp/Utils.cs b/src/csharp/Utils.cs index 815652a71..2a8723280 100644 --- a/src/csharp/Utils.cs +++ b/src/csharp/Utils.cs @@ -7,6 +7,14 @@ namespace Microsoft.ML.OnnxRuntimeGenAI { + public class OgaHandle + { + ~OgaHandle() + { + NativeMethods.OgaShutdown(); + } + } + public class Utils { public static void SetCurrentGpuDeviceId(int device_id) diff --git a/src/ort_genai.h b/src/ort_genai.h index fb863dae2..5ef23ecfa 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -235,3 +235,10 @@ struct OgaGenerator : OgaAbstract { static void operator delete(void* p) { OgaDestroyGenerator(reinterpret_cast(p)); } }; + +struct OgaHandle { + OgaHandle() = default; + ~OgaHandle() noexcept { + OgaShutdown(); + } +}; diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 13cae5235..46bc08f84 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -29,11 +29,8 @@ extern "C" { return reinterpret_cast(std::make_unique(e.what()).release()); \ } -OgaResult* OGA_API_CALL OgaShutdown() { - OGA_TRY +void OGA_API_CALL OgaShutdown() { Generators::Shutdown(); - return nullptr; - OGA_CATCH } const char* OGA_API_CALL OgaResultGetError(const OgaResult* result) { diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index 0939d2c36..412b44bcd 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -41,10 +41,8 @@ typedef struct OgaTokenizer OgaTokenizer; typedef struct OgaTokenizerStream OgaTokenizerStream; /* \brief Call this on process exit to cleanly shutdown the genai library & its onnxruntime usage - * \return Error message contained in the OgaResult. The const char* is owned by the OgaResult - * and can will be freed when the OgaResult is destroyed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaShutdown(); +OGA_EXPORT void OGA_API_CALL OgaShutdown(); /* * \param[in] result OgaResult that contains the error message.