Skip to content

Commit

Permalink
Working initial version of C API (#55)
Browse files Browse the repository at this point in the history
* Working C API + C API Test
  • Loading branch information
RyanUnderhill authored Feb 3, 2024
1 parent 035f384 commit 9f0765b
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 68 deletions.
10 changes: 3 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,8 @@ endif()
file(GLOB onnxruntime_libs "${CMAKE_SOURCE_DIR}/ort/${ONNXRUNTIME_FILES}")
foreach(DLL_FILE ${onnxruntime_libs})
add_custom_command(
TARGET onnxruntime-genai POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${DLL_FILE}
${CMAKE_BINARY_DIR}/${DLL_FILE_NAME}
TARGET onnxruntime-genai POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${DLL_FILE} ${CMAKE_BINARY_DIR}/$<CONFIG>/${DLL_FILE_NAME}
)
endforeach()

Expand All @@ -172,9 +170,7 @@ if(BUILD_WHEEL)
foreach(DLL_FILE ${onnxruntime_libs})
add_custom_command(
TARGET onnxruntime-genai
COMMAND ${CMAKE_COMMAND} -E copy
${DLL_FILE}
${WHEEL_FILES_DIR}/${TARGET_NAME}/${DLL_FILE_NAME}
COMMAND ${CMAKE_COMMAND} -E copy ${DLL_FILE} ${WHEEL_FILES_DIR}/${TARGET_NAME}/${DLL_FILE_NAME}
)
endforeach()

Expand Down
112 changes: 91 additions & 21 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <cstddef>
#include "generators.h"
#include "models/model.h"
#include "search.h"

namespace Generators {

Expand All @@ -24,35 +25,104 @@ OrtEnv& GetOrtEnv() {

extern "C" {

// OGA_TRY / OGA_CATCH bracket the body of a C API entry point that returns OgaResult*.
// Any exception thrown inside the bracketed body is converted into a heap-allocated
// OgaResult (released by the caller through OgaDestroyResult) instead of propagating
// out of the extern "C" function. The catch (...) handler covers exceptions that do not
// derive from std::exception, which would otherwise escape across the C boundary.
#define OGA_TRY try {
#define OGA_CATCH                              \
  }                                            \
  catch (const std::exception& e) {            \
    return new OgaResult{e.what()};            \
  }                                            \
  catch (...) {                                \
    return new OgaResult{"Unknown exception"}; \
  }

struct OgaResult {
explicit OgaResult(const char* what) {}
// TODO: implement this constructor !!!!
explicit OgaResult(const char* what) : what_{what} {}
std::string what_;
};

OgaResult* OgaCreateModel(const char* config_path, OgaDeviceType device_type, OgaModel** out) {
try {
auto provider_options = Generators::GetDefaultProviderOptions(static_cast<Generators::DeviceType>(device_type));
*out = reinterpret_cast<OgaModel*>(Generators::CreateModel(Generators::GetOrtEnv(), config_path, &provider_options).release());
return nullptr;
} catch (const std::exception& e) {
return new OgaResult{e.what()};
}
// Returns the error text held by 'result' as a NUL-terminated string. The pointer refers to
// the std::string stored inside the OgaResult, so it is only usable while 'result' is alive.
const char* OGA_API_CALL OgaResultGetError(OgaResult* result) {
  const std::string& message = result->what_;
  return message.c_str();
}

void OgaDestroyModel(OgaModel* model) {
delete reinterpret_cast<Generators::Model*>(model);
// Creates a model from 'config_path' using the default provider options for 'device_type'
// and hands ownership to the caller through '*out' (release it with OgaDestroyModel).
// Returns nullptr on success, or an OgaResult describing the failure.
OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaDeviceType device_type, OgaModel** out) {
  OGA_TRY
  const auto generators_device = static_cast<Generators::DeviceType>(device_type);
  auto provider_options = Generators::GetDefaultProviderOptions(generators_device);
  auto created_model = Generators::CreateModel(Generators::GetOrtEnv(), config_path, &provider_options);
  // Ownership leaves the smart pointer here; the opaque handle now owns the model.
  *out = reinterpret_cast<OgaModel*>(created_model.release());
  return nullptr;
  OGA_CATCH
}

OgaResult* OgaCreateState(OgaModel* model, int32_t* sequence_lengths, size_t sequence_lengths_count, const OgaSearchParams* search_params, OgaState** out) {
try {
*out = reinterpret_cast<OgaState*>(reinterpret_cast<Generators::Model*>(model)->CreateState(Generators::cpu_span<int32_t>{sequence_lengths, sequence_lengths_count}, *reinterpret_cast<const Generators::GeneratorParams*>(search_params)).release());
return nullptr;
} catch (const std::exception& e) {
return new OgaResult{e.what()};
}
// Allocates a Generators::GeneratorParams constructed from 'model' and returns it through
// '*out' as an opaque handle (release it with OgaDestroyGeneratorParams).
// Returns nullptr on success, or an OgaResult describing the failure.
OgaResult* OGA_API_CALL OgaCreateGeneratorParams(OgaModel* model, OgaGeneratorParams** out) {
  OGA_TRY
  auto& generators_model = *reinterpret_cast<Generators::Model*>(model);
  auto* created_params = new Generators::GeneratorParams(generators_model);
  *out = reinterpret_cast<OgaGeneratorParams*>(created_params);
  return nullptr;
  OGA_CATCH
}

// Writes 'max_length' into the max_length field of the generator params.
// Always returns nullptr (no failure path).
OgaResult* OGA_API_CALL OgaGeneratorParamsSetMaxLength(OgaGeneratorParams* params, int max_length) {
  auto& generator_params = *reinterpret_cast<Generators::GeneratorParams*>(params);
  generator_params.max_length = max_length;
  return nullptr;
}

// Points the generator params at the caller's input token ids and records the batch layout.
//
// 'input_ids' / 'input_ids_count' describe the token buffer; 'sequence_length' * 'batch_size'
// must equal 'input_ids_count'. params.input_ids is a std::span, i.e. a non-owning view: the
// buffer is not copied here, so it presumably has to stay valid for as long as the params are
// used (confirm against the consumers of GeneratorParams).
//
// Returns nullptr on success, or an OgaResult describing the failure. All validation happens
// before the params are modified, so a rejected call leaves them untouched.
OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* oga_params, const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size) {
  OGA_TRY
  // The params store both dimensions as int; reject values that would not survive the narrowing.
  const auto fits_in_int = [](size_t value) {
    const int narrowed = static_cast<int>(value);
    return narrowed >= 0 && static_cast<size_t>(narrowed) == value;
  };
  if (!fits_in_int(sequence_length) || !fits_in_int(batch_size))
    throw std::runtime_error("sequence length or batch size is too large");

  // Both factors fit in int, so the product cannot overflow unsigned long long.
  if (static_cast<unsigned long long>(sequence_length) * batch_size != input_ids_count)
    throw std::runtime_error("sequence length * batch size is not equal to input_ids_count");

  auto& params = *reinterpret_cast<Generators::GeneratorParams*>(oga_params);
  params.input_ids = std::span<const int32_t>(input_ids, input_ids_count);
  params.sequence_length = static_cast<int>(sequence_length);
  params.batch_size = static_cast<int>(batch_size);
  return nullptr;
  OGA_CATCH
}

// Creates a generator for 'model' configured by 'generator_params' and returns it through
// '*out' as an opaque handle (release it with OgaDestroyGenerator).
// Returns nullptr on success, or an OgaResult describing the failure.
// OGA_API_CALL is spelled out so the definition matches its declaration in ort_genai_c.h
// and the other entry points in this file.
OgaResult* OGA_API_CALL OgaCreateGenerator(OgaModel* model, const OgaGeneratorParams* generator_params, OgaGenerator** out) {
  OGA_TRY
  *out = reinterpret_cast<OgaGenerator*>(CreateGenerator(*reinterpret_cast<Generators::Model*>(model), *reinterpret_cast<const Generators::GeneratorParams*>(generator_params)).release());
  return nullptr;
  OGA_CATCH
}

// Forwards to Generators::Generator::IsDone() on the generator behind the opaque handle.
bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator) {
  const auto& impl = *reinterpret_cast<const Generators::Generator*>(generator);
  return impl.IsDone();
}

// Calls ComputeLogits() on the generator behind the opaque handle.
// Returns nullptr on success, or an OgaResult describing the failure.
OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator) {
  OGA_TRY
  auto& impl = *reinterpret_cast<Generators::Generator*>(generator);
  impl.ComputeLogits();
  return nullptr;
  OGA_CATCH
}

// Calls GenerateNextToken_Top() on the generator behind the opaque handle.
// Returns nullptr on success, or an OgaResult describing the failure.
OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_Top(OgaGenerator* generator) {
  OGA_TRY
  auto& impl = *reinterpret_cast<Generators::Generator*>(generator);
  impl.GenerateNextToken_Top();
  return nullptr;
  OGA_CATCH
}

// Reports sequence 'index' of the generator. The token count is always written to '*count';
// when 'tokens' is non-null the tokens are also copied into it. No capacity is passed in, so
// the caller must supply a buffer with room for at least '*count' entries.
// Returns nullptr on success, or an OgaResult describing the failure.
OgaResult* OGA_API_CALL OgaGenerator_GetSequence(OgaGenerator* oga_generator, int index, int32_t* tokens, size_t* count) {
  OGA_TRY
  auto* impl = reinterpret_cast<Generators::Generator*>(oga_generator);
  // Keep the returned sequence object alive in a named local while its CPU view is in use.
  auto device_sequence = impl->GetSequence(index);
  auto cpu_tokens = device_sequence.GetCPU();
  *count = cpu_tokens.size();
  if (tokens != nullptr) {
    std::copy(cpu_tokens.begin(), cpu_tokens.end(), tokens);
  }
  return nullptr;
  OGA_CATCH
}

// Frees an OgaResult returned by one of the API calls. Passing nullptr is a no-op (plain delete).
void OGA_API_CALL OgaDestroyResult(OgaResult* p) { delete p; }

// Deletes the Generators::Model behind the opaque handle.
void OGA_API_CALL OgaDestroyModel(OgaModel* p) {
  auto* model = reinterpret_cast<Generators::Model*>(p);
  delete model;
}

// Deletes the Generators::GeneratorParams behind the opaque handle.
void OGA_API_CALL OgaDestroyGeneratorParams(OgaGeneratorParams* p) {
  auto* generator_params = reinterpret_cast<Generators::GeneratorParams*>(p);
  delete generator_params;
}

void OgaDestroyState(OgaState* state) {
delete reinterpret_cast<Generators::State*>(state);
// Deletes the Generators::Generator behind the opaque handle.
void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* p) {
  auto* generator = reinterpret_cast<Generators::Generator*>(p);
  delete generator;
}
}
62 changes: 22 additions & 40 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,54 +26,36 @@ typedef enum OgaDeviceType {
OgaDeviceTypeCUDA,
} OgaDeviceType;

typedef enum OgaDataType {
OgaDataTypeFloat32,
OgaDataTypeInt64,
} OgaDataType;

typedef struct OgaResult OgaResult;
typedef struct OgaArray OgaArray;
typedef struct OgaSearchParams OgaSearchParams;
typedef struct OgaSearch OgaSearch;
typedef struct OgaGeneratorParams OgaGeneratorParams;
typedef struct OgaGenerator OgaGenerator;
typedef struct OgaModel OgaModel;
typedef struct OgaState OgaState;
typedef struct OgaRoamingArray OgaRoamingArray;

OGA_EXPORT void OGA_API_CALL OgaDestroyArray(OgaArray*);
OGA_EXPORT size_t OGA_API_CALL OgaArrayGetSize(OgaArray*);
OGA_EXPORT OgaDataType OGA_API_CALL OgaArrayGetType(OgaArray*);
OGA_EXPORT OgaDeviceType OGA_API_CALL OgaArrayGetNativeDeviceType(OgaArray*);
OGA_EXPORT void* OGA_API_CALL OgaArrayGetData(OgaArray*, OgaDeviceType*);

OGA_EXPORT const char* OGA_API_CALL OgaResultGetError(OgaResult*);
OGA_EXPORT void OGA_API_CALL OgaDestroyResult(OgaResult*);

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaDeviceType device_type, OgaModel** out);
OGA_EXPORT void OGA_API_CALL OgaDestroyModel(OgaModel*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaModelGenerate(OgaSearchParams* search_params);

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateSearchParams(OgaModel* model, OgaSearchParams** out);
OGA_EXPORT void OGA_API_CALL OgaDestroySearchParams(OgaSearchParams*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchParamsCreateSearch(OgaSearchParams*, OgaSearch** out);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchParamsSetInputIDs(OgaSearchParams*, int32_t* input_ids, size_t input_ids_count, int num_batches);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchParamsSetWhisperInputFeatures(OgaSearchParams*, int32_t* inputs, size_t count);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchParamsSetWhisperDecoderInputIDs(OgaSearchParams*, int32_t* input_ids, size_t input_ids_count);

OGA_EXPORT void OGA_API_CALL OgaDestroySearch(OgaSearch*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchSetLogits(OgaSearch*, OgaArray*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchGetSequenceLength(OgaSearch*, size_t* sequence_length);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchGetSequenceLengths(OgaSearch*, size_t* sequence_length);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchGetNextTokens(OgaSearch*, OgaArray**);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchGetNextIndices(OgaSearch*, OgaArray**);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchIsDone(OgaSearch*, bool* out);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchSelectTop(OgaSearch*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchSampleTopK(OgaSearch*, int k, float t);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchSampleTopP(OgaSearch*, float p, float t);
OGA_EXPORT OgaResult* OGA_API_CALL OgaSearchGetSequence(OgaSearch*, int index, OgaArray**);

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateState(OgaModel* model, int32_t* sequence_lengths, size_t sequence_lengths_count, const OgaSearchParams* search_params, OgaState** out);
OGA_EXPORT void OGA_API_CALL OgaDestroyState(OgaState*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaStateRun(int current_length, int32_t* next_tokens, size_t next_tokens_count, int32_t* next_indices, size_t next_indices_count, float* logits, float** logits_count);
OGA_EXPORT OgaResult* OGA_API_CALL OgaModelGenerate(OgaGeneratorParams* generator_params);

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGeneratorParams(OgaModel* model, OgaGeneratorParams** out);
OGA_EXPORT void OGA_API_CALL OgaDestroyGeneratorParams(OgaGeneratorParams*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetMaxLength(OgaGeneratorParams*, int max_length);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams*, const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorParams*, const int32_t* inputs, size_t count);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperDecoderInputIDs(OgaGeneratorParams*, const int32_t* input_ids, size_t input_ids_count);

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out);
OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator*);
OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_Top(OgaGenerator*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopK(OgaGenerator*, int k, float t);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopP(OgaGenerator*, float p, float t);

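/* Writes the number of tokens in sequence 'index' into 'count'. If 'tokens' is not nullptr, the sequence is also copied into it;
 * the caller must make sure 'tokens' has room for at least that many entries (call once with 'tokens' == nullptr to get the count).
*/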
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GetSequence(OgaGenerator*, int index, int32_t* tokens, size_t* count);

#ifdef __cplusplus
}
Expand Down
88 changes: 88 additions & 0 deletions src/tests/c_api.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include "../generators.h"
#include "../search.h"
#include "../models/model.h"
#include <iostream>
#include "../ort_genai_c.h"

// Our working directory is generators/build so one up puts us in the root directory:
#define MODEL_PATH "../test_models/"

// Deleter functor for std::unique_ptr: one overload per opaque C API handle type, each
// forwarding to the matching OgaDestroy* function.
struct Deleters {
  void operator()(OgaResult* result) { OgaDestroyResult(result); }
  void operator()(OgaModel* model) { OgaDestroyModel(model); }
  void operator()(OgaGeneratorParams* generator_params) { OgaDestroyGeneratorParams(generator_params); }
  void operator()(OgaGenerator* generator) { OgaDestroyGenerator(generator); }
};

// Owning smart-pointer aliases for the C API handles; each one releases its handle through
// the Deleters overload for that handle type.
using OgaResultPtr = std::unique_ptr<OgaResult, Deleters>;
using OgaModelPtr = std::unique_ptr<OgaModel, Deleters>;
using OgaGeneratorParamsPtr = std::unique_ptr<OgaGeneratorParams, Deleters>;
using OgaGeneratorPtr = std::unique_ptr<OgaGenerator, Deleters>;

// Turns a C API status into a C++ exception: a null result means success and is ignored;
// otherwise the error text is thrown as std::runtime_error and the OgaResult is released.
void CheckResult(OgaResult* result) {
  if (result) {
    // Takes ownership so the result is destroyed while the exception unwinds; the message
    // has already been copied into the runtime_error by then.
    OgaResultPtr owned_result{result};
    throw std::runtime_error(OgaResultGetError(owned_result.get()));
  }
}

// Drives the C API end to end: loads the tiny random GPT-2 fp32 test model on CPU, runs
// top-token generation until the generator reports done, and compares every sequence in the
// batch against the expected token ids. Throws std::runtime_error on any API error or mismatch.
void Test_GreedySearch_Gpt_Fp32_C_API() {
  std::cout << "Test_GreedySearch_Gpt fp32 C API" << std::flush;

  // Shape is {batch, sequence}; the ids below are laid out one prompt after the other.
  std::vector<int64_t> input_ids_shape{2, 4};
  std::vector<int32_t> input_ids{0, 0, 0, 52, 0, 0, 195, 731};

  // One row of max_length tokens per prompt.
  std::vector<int32_t> expected_output{
      0, 0, 0, 52, 204, 204, 204, 204, 204, 204,
      0, 0, 195, 731, 731, 114, 114, 114, 114, 114};

  const auto batch_count = input_ids_shape[0];
  const auto prompt_length = input_ids_shape[1];
  const int max_length = 10;

  // To generate this file:
  // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20
  // And copy the resulting gpt2_init_past_fp32.onnx file into these two files (as it's the same for gpt2)

  OgaModel* model = nullptr;
  CheckResult(OgaCreateModel(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32", OgaDeviceTypeCPU, &model));
  OgaModelPtr model_owner{model};

  OgaGeneratorParams* params = nullptr;
  CheckResult(OgaCreateGeneratorParams(model, &params));
  OgaGeneratorParamsPtr params_owner{params};
  CheckResult(OgaGeneratorParamsSetMaxLength(params, max_length));
  CheckResult(OgaGeneratorParamsSetInputIDs(params, input_ids.data(), input_ids.size(), prompt_length, batch_count));

  OgaGenerator* generator = nullptr;
  CheckResult(OgaCreateGenerator(model, params, &generator));
  OgaGeneratorPtr generator_owner{generator};

  // Generation loop: compute logits, then pick the next token, until done.
  while (!OgaGenerator_IsDone(generator)) {
    CheckResult(OgaGenerator_ComputeLogits(generator));
    CheckResult(OgaGenerator_GenerateNextToken_Top(generator));
  }

  // Verify outputs match expected outputs
  for (int batch_index = 0; batch_index < batch_count; ++batch_index) {
    // First call only queries the length, second call fills the buffer.
    size_t token_count;
    CheckResult(OgaGenerator_GetSequence(generator, batch_index, nullptr, &token_count));
    std::vector<int32_t> sequence(token_count);
    CheckResult(OgaGenerator_GetSequence(generator, batch_index, sequence.data(), &token_count));

    const int32_t* expected_begin = &expected_output[batch_index * max_length];
    const int32_t* expected_end = expected_begin + max_length;
    const bool matches = std::equal(expected_begin, expected_end, sequence.begin(), sequence.end());
    if (!matches)
      throw std::runtime_error("Test Results Mismatch");
  }

  std::cout << " - complete\r\n";
}
2 changes: 2 additions & 0 deletions src/tests/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

extern std::unique_ptr<OrtEnv> g_ort_env;

void Test_GreedySearch_Gpt_Fp32_C_API();
void Test_GreedySearch_Gpt_Fp32();
void Test_BeamSearch_Gpt_Fp32();

Expand All @@ -25,6 +26,7 @@ int main() {
std::cout << "done" << std::endl;

try {
Test_GreedySearch_Gpt_Fp32_C_API();
Test_GreedySearch_Gpt_Fp32();
Test_BeamSearch_Gpt_Fp32();

Expand Down

0 comments on commit 9f0765b

Please sign in to comment.