Skip to content

Commit

Permalink
Fix conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani committed May 1, 2024
1 parent 2ceaf84 commit df51783
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 10 deletions.
1 change: 1 addition & 0 deletions benchmark/c/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ void RunBenchmark(const benchmark::Options& opts) {
} // namespace

int main(int argc, char** argv) {
OgaHandle handle;
try {
const auto opts = benchmark::ParseOptionsFromCommandLine(argc, argv);
RunBenchmark(opts);
Expand Down
10 changes: 7 additions & 3 deletions examples/c/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ void CXX_API(const char* model_path) {
auto tokenizer = OgaTokenizer::Create(*model);

const char* prompt = "def is_prime(num):";
std::cout << "Prompt: " << std::endl << prompt << std::endl;
std::cout << "Prompt: " << std::endl
<< prompt << std::endl;

auto sequences = OgaSequences::Create();
tokenizer->Encode(prompt, *sequences);
Expand All @@ -21,14 +22,15 @@ void CXX_API(const char* model_path) {
auto output_sequences = model->Generate(*params);
auto out_string = tokenizer->Decode(output_sequences->Get(0));

std::cout << "Output: " << std::endl << out_string << std::endl;
std::cout << "Output: " << std::endl
<< out_string << std::endl;
}

// C API Example

void CheckResult(OgaResult* result) {
if (result) {
std::string string=OgaResultGetError(result);
std::string string = OgaResultGetError(result);
OgaDestroyResult(result);
throw std::runtime_error(string);
}
Expand Down Expand Up @@ -84,6 +86,8 @@ int main(int argc, char** argv) {
return -1;
}

// Responsible for cleaning up the library during shutdown
OgaHandle handle;

std::cout << "-------------" << std::endl;
std::cout << "Hello, Phi-2!" << std::endl;
Expand Down
2 changes: 2 additions & 0 deletions examples/csharp/HelloPhi2/Program.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// See https://aka.ms/new-console-template for more information
using Microsoft.ML.OnnxRuntimeGenAI;

OgaHandle ogaHandle = new OgaHandle();

Console.WriteLine("-------------");
Console.WriteLine("Hello, Phi-2!");
Console.WriteLine("-------------");
Expand Down
3 changes: 3 additions & 0 deletions src/csharp/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,8 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGetCurrentGpuDeviceId(out IntPtr /* int32_t */ device_id);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaShutdown();
}
}
8 changes: 8 additions & 0 deletions src/csharp/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@

namespace Microsoft.ML.OnnxRuntimeGenAI
{
public class OgaHandle
{
~OgaHandle()
{
NativeMethods.OgaShutdown();
}
}

public class Utils
{
public static void SetCurrentGpuDeviceId(int device_id)
Expand Down
7 changes: 7 additions & 0 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,10 @@ struct OgaGenerator : OgaAbstract {

static void operator delete(void* p) { OgaDestroyGenerator(reinterpret_cast<OgaGenerator*>(p)); }
};

struct OgaHandle {
OgaHandle() = default;
~OgaHandle() noexcept {
OgaShutdown();
}
};
5 changes: 1 addition & 4 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,8 @@ extern "C" {
return reinterpret_cast<OgaResult*>(std::make_unique<Generators::Result>(e.what()).release()); \
}

OgaResult* OGA_API_CALL OgaShutdown() {
OGA_TRY
void OGA_API_CALL OgaShutdown() {
Generators::Shutdown();
return nullptr;
OGA_CATCH
}

const char* OGA_API_CALL OgaResultGetError(const OgaResult* result) {
Expand Down
4 changes: 1 addition & 3 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ typedef struct OgaTokenizer OgaTokenizer;
typedef struct OgaTokenizerStream OgaTokenizerStream;

/* \brief Call this on process exit to cleanly shutdown the genai library & its onnxruntime usage
* \return Error message contained in the OgaResult. The const char* is owned by the OgaResult
* and can will be freed when the OgaResult is destroyed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaShutdown();
OGA_EXPORT void OGA_API_CALL OgaShutdown();

/*
* \param[in] result OgaResult that contains the error message.
Expand Down

0 comments on commit df51783

Please sign in to comment.