Skip to content

Commit

Permalink
add pre C++20 APIs that don't use span
Browse files Browse the repository at this point in the history
  • Loading branch information
edgchen1 committed Apr 2, 2024
1 parent b78085b commit 05510c0
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 12 deletions.
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ if(USE_CUDA AND CMAKE_CUDA_COMPILER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
add_compile_definitions(_DEBUG=1)
endif()

if(MSVC)
add_compile_options(/Zc:__cplusplus) # set updated value for __cplusplus macro instead of 199711L
endif()

message(STATUS "Adding source files")

file(GLOB generator_srcs CONFIGURE_DEPENDS
Expand Down
10 changes: 7 additions & 3 deletions benchmark/c/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, cons
params->SetInputSequences(*base_prompt_sequences);

auto output_sequences = model.Generate(*params);
return std::string{tokenizer.Decode(output_sequences->Get(0))};
const auto output_sequence_length = output_sequences->SequenceCount(0);
const auto* output_sequence_data = output_sequences->SequenceData(0);
return std::string{tokenizer.Decode(output_sequence_data, output_sequence_length)};
}

void RunBenchmark(const benchmark::Options& opts) {
Expand All @@ -148,7 +150,7 @@ void RunBenchmark(const benchmark::Options& opts) {
tokenizer->Encode(prompt.c_str(), *prompt_sequences);
}

const size_t num_prompt_tokens = prompt_sequences->Get(0).size();
const size_t num_prompt_tokens = prompt_sequences->SequenceCount(0);
const size_t num_tokens = num_prompt_tokens + opts.num_tokens_to_generate;

auto make_generator_params = [&] {
Expand All @@ -169,7 +171,9 @@ void RunBenchmark(const benchmark::Options& opts) {
if (opts.verbose && i == 0) {
// show prompt and output on first iteration
std::cout << "Prompt:\n\t" << prompt << "\n";
auto output = tokenizer->Decode(output_sequences->Get(0));
const auto output_sequence_length = output_sequences->SequenceCount(0);
const auto* output_sequence_data = output_sequences->SequenceData(0);
const auto output = tokenizer->Decode(output_sequence_data, output_sequence_length);
std::cout << "Output:\n\t" << output << "\n";
}
}
Expand Down
37 changes: 34 additions & 3 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
#pragma once

#include <memory>
#include "span.h" // TODO Use the std span header if we can assume C++20.
#include <stdexcept>
#include <type_traits>

#if __cplusplus >= 202002L
#include <span>
#endif

#include "ort_genai_c.h"

// GenAI C++ API
Expand Down Expand Up @@ -92,9 +95,19 @@ struct OgaSequences : OgaAbstract {
return OgaSequencesCount(this);
}

size_t SequenceCount(size_t index) const {
return OgaSequencesGetSequenceCount(this, index);
}

const int32_t* SequenceData(size_t index) const {
return OgaSequencesGetSequenceData(this, index);
}

#if __cplusplus >= 202002L
std::span<const int32_t> Get(size_t index) const {
return {OgaSequencesGetSequenceData(this, index), OgaSequencesGetSequenceCount(this, index)};
return {SequenceData(index), SequenceCount(index)};
}
#endif

static void operator delete(void* p) { OgaDestroySequences(reinterpret_cast<OgaSequences*>(p)); }
};
Expand All @@ -110,11 +123,19 @@ struct OgaTokenizer : OgaAbstract {
OgaCheckResult(OgaTokenizerEncode(this, str, &sequences));
}

OgaString Decode(const int32_t* tokens_data, size_t tokens_length) const {
const char* p;
OgaCheckResult(OgaTokenizerDecode(this, tokens_data, tokens_length, &p));
return p;
}

#if __cplusplus >= 202002L
OgaString Decode(std::span<const int32_t> tokens) const {
const char* p;
OgaCheckResult(OgaTokenizerDecode(this, tokens.data(), tokens.size(), &p));
return p;
}
#endif

static void operator delete(void* p) { OgaDestroyTokenizer(reinterpret_cast<OgaTokenizer*>(p)); }
};
Expand Down Expand Up @@ -190,9 +211,19 @@ struct OgaGenerator : OgaAbstract {
OgaCheckResult(OgaGenerator_GenerateNextToken(this));
}

size_t GetSequenceLength(size_t index) const {
return OgaGenerator_GetSequenceLength(this, index);
}

const int32_t* GetSequenceData(size_t index) const {
return OgaGenerator_GetSequence(this, index);
}

#if __cplusplus >= 202002L
std::span<const int32_t> GetSequence(size_t index) const {
return {OgaGenerator_GetSequence(this, index), OgaGenerator_GetSequenceLength(this, index)};
return {GetSequenceData(index), GetSequenceLength(index)};
}
#endif

static void operator delete(void* p) { OgaDestroyGenerator(reinterpret_cast<OgaGenerator*>(p)); }
};
18 changes: 12 additions & 6 deletions test/c_api_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,21 +120,27 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {

// Verify outputs match expected outputs
for (int i = 0; i < batch_size; i++) {
auto sequence = generator->GetSequence(i);
const auto sequence_length = generator->GetSequenceLength(i);
const auto* sequence_data = generator->GetSequenceData(i);

auto* expected_output_start = &expected_output[i * max_length];
EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), max_length * sizeof(int32_t)));
ASSERT_LE(sequence_length, max_length);

const auto* expected_output_start = &expected_output[i * max_length];
EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t)));
}

// Test high level API
auto sequences = model->Generate(*params);

// Verify outputs match expected outputs
for (int i = 0; i < batch_size; i++) {
auto sequence = sequences->Get(i);
const auto sequence_length = sequences->SequenceCount(i);
const auto* sequence_data = sequences->SequenceData(i);

ASSERT_LE(sequence_length, max_length);

auto* expected_output_start = &expected_output[i * max_length];
EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), max_length * sizeof(int32_t)));
const auto* expected_output_start = &expected_output[i * max_length];
EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t)));
}
}

Expand Down

0 comments on commit 05510c0

Please sign in to comment.