// Patch summary (unified diff over four files; this preamble sits before the
// first "diff --git" header and is not part of any hunk).
//
// CMakeLists.txt
//   - Under MSVC, adds the /Zc:__cplusplus compile option so that the
//     __cplusplus macro carries its updated value instead of 199711L.
//
// src/ort_genai.h
//   - Drops the `#include "span.h"` line and adds an include that is only
//     compiled when `__cplusplus >= 202002L`.
//   - Adds pointer/length accessors that do not need std::span:
//       OgaSequences::SequenceCount(index) / SequenceData(index)
//       OgaTokenizer::Decode(const int32_t* tokens_data, size_t tokens_length)
//       OgaGenerator::GetSequenceLength(index) / GetSequenceData(index)
//   - Wraps the std::span-based OgaSequences::Get, OgaTokenizer::Decode and
//     OgaGenerator::GetSequence in `#if __cplusplus >= 202002L`, and rewrites
//     Get / GetSequence in terms of the new accessors.
//
// benchmark/c/main.cpp
//   - GeneratePrompt and RunBenchmark switch from `Get(0)` to the
//     SequenceCount(0) / SequenceData(0) pair and the pointer/length Decode.
//
// test/c_api_tests.cpp
//   - Both verification loops switch to the pointer/length accessors, add
//     `ASSERT_LE(sequence_length, max_length)`, and memcmp
//     `sequence_length * sizeof(int32_t)` bytes instead of
//     `max_length * sizeof(int32_t)`.
//
// NOTE(review): this copy of the patch appears to have lost its line breaks
//   and the contents of angle brackets (bare `#include`, `std::span` without
//   template arguments, `reinterpret_cast(p)`) -- presumably an extraction
//   artifact; confirm against the original patch before applying.
// NOTE(review): the header guard tests only `__cplusplus`, while the
//   /Zc:__cplusplus option is added only in this project's CMake build.
//   Presumably an MSVC consumer that includes ort_genai.h without that option
//   would not see the std::span overloads; consider also testing `_MSVC_LANG`
//   -- TODO confirm.
diff --git a/CMakeLists.txt b/CMakeLists.txt index c36c4fefb..43b2dc8f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,6 +44,10 @@ if(USE_CUDA AND CMAKE_CUDA_COMPILER AND CMAKE_BUILD_TYPE STREQUAL "Debug") add_compile_definitions(_DEBUG=1) endif() +if(MSVC) + add_compile_options(/Zc:__cplusplus) # set updated value for __cplusplus macro instead of 199711L +endif() + message(STATUS "Adding source files") file(GLOB generator_srcs CONFIGURE_DEPENDS diff --git a/benchmark/c/main.cpp b/benchmark/c/main.cpp index 89fc72f33..04b654a9c 100644 --- a/benchmark/c/main.cpp +++ b/benchmark/c/main.cpp @@ -123,7 +123,9 @@ std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, cons params->SetInputSequences(*base_prompt_sequences); auto output_sequences = model.Generate(*params); - return std::string{tokenizer.Decode(output_sequences->Get(0))}; + const auto output_sequence_length = output_sequences->SequenceCount(0); + const auto* output_sequence_data = output_sequences->SequenceData(0); + return std::string{tokenizer.Decode(output_sequence_data, output_sequence_length)}; } void RunBenchmark(const benchmark::Options& opts) { @@ -148,7 +150,7 @@ void RunBenchmark(const benchmark::Options& opts) { tokenizer->Encode(prompt.c_str(), *prompt_sequences); } - const size_t num_prompt_tokens = prompt_sequences->Get(0).size(); + const size_t num_prompt_tokens = prompt_sequences->SequenceCount(0); const size_t num_tokens = num_prompt_tokens + opts.num_tokens_to_generate; auto make_generator_params = [&] { @@ -169,7 +171,9 @@ void RunBenchmark(const benchmark::Options& opts) { if (opts.verbose && i == 0) { // show prompt and output on first iteration std::cout << "Prompt:\n\t" << prompt << "\n"; - auto output = tokenizer->Decode(output_sequences->Get(0)); + const auto output_sequence_length = output_sequences->SequenceCount(0); + const auto* output_sequence_data = output_sequences->SequenceData(0); + const auto output = tokenizer->Decode(output_sequence_data, 
output_sequence_length); std::cout << "Output:\n\t" << output << "\n"; } } diff --git a/src/ort_genai.h b/src/ort_genai.h index 4c0dddf0c..265a52e21 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -4,10 +4,13 @@ #pragma once #include -#include "span.h" // TODO Use the std span header if we can assume C++20. #include #include +#if __cplusplus >= 202002L +#include +#endif + #include "ort_genai_c.h" // GenAI C++ API @@ -92,9 +95,19 @@ struct OgaSequences : OgaAbstract { return OgaSequencesCount(this); } + size_t SequenceCount(size_t index) const { + return OgaSequencesGetSequenceCount(this, index); + } + + const int32_t* SequenceData(size_t index) const { + return OgaSequencesGetSequenceData(this, index); + } + +#if __cplusplus >= 202002L std::span Get(size_t index) const { - return {OgaSequencesGetSequenceData(this, index), OgaSequencesGetSequenceCount(this, index)}; + return {SequenceData(index), SequenceCount(index)}; } +#endif static void operator delete(void* p) { OgaDestroySequences(reinterpret_cast(p)); } }; @@ -110,11 +123,19 @@ struct OgaTokenizer : OgaAbstract { OgaCheckResult(OgaTokenizerEncode(this, str, &sequences)); } + OgaString Decode(const int32_t* tokens_data, size_t tokens_length) const { + const char* p; + OgaCheckResult(OgaTokenizerDecode(this, tokens_data, tokens_length, &p)); + return p; + } + +#if __cplusplus >= 202002L OgaString Decode(std::span tokens) const { const char* p; OgaCheckResult(OgaTokenizerDecode(this, tokens.data(), tokens.size(), &p)); return p; } +#endif static void operator delete(void* p) { OgaDestroyTokenizer(reinterpret_cast(p)); } }; @@ -190,9 +211,19 @@ struct OgaGenerator : OgaAbstract { OgaCheckResult(OgaGenerator_GenerateNextToken(this)); } + size_t GetSequenceLength(size_t index) const { + return OgaGenerator_GetSequenceLength(this, index); + } + + const int32_t* GetSequenceData(size_t index) const { + return OgaGenerator_GetSequence(this, index); + } + +#if __cplusplus >= 202002L std::span GetSequence(size_t 
index) const { - return {OgaGenerator_GetSequence(this, index), OgaGenerator_GetSequenceLength(this, index)}; + return {GetSequenceData(index), GetSequenceLength(index)}; } +#endif static void operator delete(void* p) { OgaDestroyGenerator(reinterpret_cast(p)); } }; diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index a1ec2b923..2413c9f1d 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -120,10 +120,13 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { // Verify outputs match expected outputs for (int i = 0; i < batch_size; i++) { - auto sequence = generator->GetSequence(i); + const auto sequence_length = generator->GetSequenceLength(i); + const auto* sequence_data = generator->GetSequenceData(i); - auto* expected_output_start = &expected_output[i * max_length]; - EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), max_length * sizeof(int32_t))); + ASSERT_LE(sequence_length, max_length); + + const auto* expected_output_start = &expected_output[i * max_length]; + EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); } // Test high level API @@ -131,10 +134,13 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { // Verify outputs match expected outputs for (int i = 0; i < batch_size; i++) { - auto sequence = sequences->Get(i); + const auto sequence_length = sequences->SequenceCount(i); + const auto* sequence_data = sequences->SequenceData(i); + + ASSERT_LE(sequence_length, max_length); - auto* expected_output_start = &expected_output[i * max_length]; - EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), max_length * sizeof(int32_t))); + const auto* expected_output_start = &expected_output[i * max_length]; + EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); } }