Remove GenerateNextToken* special case functions (#221)
set_search_options already supports this functionality, so the extra
functions only confuse users by offering multiple ways to do the same
thing. set_search_options is also more flexible, as it supports all
future options without the need for extra APIs.
RyanUnderhill authored and jchen351 committed Mar 23, 2024
1 parent 118d352 commit 08947b7
Showing 12 changed files with 371 additions and 98 deletions.
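
For callers migrating off the removed functions, a minimal before/after sketch at the C API level follows. The function and option names are taken from the diffs and tests below; the option values are illustrative, CheckResult is the helper used in test/c_api_tests.cpp, and the model/params/generator setup is elided (see that test file for the full setup).

// Before (removed by this commit): the sampling algorithm was chosen by the call itself.
while (!OgaGenerator_IsDone(generator)) {
  CheckResult(OgaGenerator_ComputeLogits(generator));
  CheckResult(OgaGenerator_GenerateNextToken_TopK(generator, 50, 0.7f));  // top-k sampling, k = 50, temperature = 0.7
}

// After: configure the search options once on the generator params, then always call
// the single GenerateNextToken entry point.
CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_k", 50));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", 0.7));
// ... create the generator from params as before ...
while (!OgaGenerator_IsDone(generator)) {
  CheckResult(OgaGenerator_ComputeLogits(generator));
  CheckResult(OgaGenerator_GenerateNextToken(generator));
}

The same consolidation applies to the C# binding (GenerateNextToken) and the Python binding (generate_next_token) changed below.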
2 changes: 1 addition & 1 deletion src/config.h
@@ -79,7 +79,7 @@ struct Config {
   int num_return_sequences{1};
   float repetition_penalty{1.0f}; // 1.0 means no penalty.
   int top_k{}; // Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in the generate method of the model.
-  float top_p{1.0f}; // If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+  float top_p{}; // If set to float >0 and <1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
   float temperature{1.0f};
   bool early_stopping{true}; // Whether to stop the beam search when at least num_beams sentences are finished per batch or not.
   int no_repeat_ngram_size{};
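
To make the top_p comment above concrete, here is a tiny self-contained illustration of nucleus filtering with toy numbers (not library code): with top_p = 0.9, the smallest set of highest-probability tokens whose cumulative probability reaches 0.9 is kept, and sampling then happens only among those tokens.

#include <cstdio>
#include <vector>

int main() {
  // Toy next-token distribution, already normalized and sorted by probability.
  std::vector<float> probs{0.5f, 0.3f, 0.15f, 0.05f};
  float top_p = 0.9f;  // as it would be set via the "top_p" search option

  float cumulative = 0.0f;
  size_t kept = 0;
  for (; kept < probs.size(); ++kept) {
    cumulative += probs[kept];
    if (cumulative >= top_p) {
      ++kept;
      break;
    }
  }
  // Keeps the first 3 tokens (0.5 + 0.3 + 0.15 = 0.95 >= 0.9).
  std::printf("kept %zu of %zu tokens\n", kept, probs.size());
}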
5 changes: 2 additions & 3 deletions src/csharp/Generator.cs
@@ -2,7 +2,6 @@
 // Licensed under the MIT License.
 
 using System;
-using System.Runtime.InteropServices;
 
 namespace Microsoft.ML.OnnxRuntimeGenAI
 {
@@ -26,9 +25,9 @@ public void ComputeLogits()
             Result.VerifySuccess(NativeMethods.OgaGenerator_ComputeLogits(_generatorHandle));
         }
 
-        public void GenerateNextTokenTop()
+        public void GenerateNextToken()
         {
-            Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken_Top(_generatorHandle));
+            Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken(_generatorHandle));
         }
 
         public ReadOnlySpan<int> GetSequence(ulong index)
2 changes: 1 addition & 1 deletion src/csharp/NativeMethods.cs
@@ -78,7 +78,7 @@ internal class NativeLib
 
         // This function is used to generate the next token in the sequence using the greedy search algorithm.
         [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
-        public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken_Top(IntPtr /* OgaGenerator* */ generator);
+        public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken(IntPtr /* OgaGenerator* */ generator);
 
         // This function returns the length of the sequence at the given index.
         [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
37 changes: 14 additions & 23 deletions src/generators.cpp
@@ -84,7 +84,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_
 
 void Generator::ComputeLogits() {
   if (computed_logits_)
-    throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken* first");
+    throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first");
 
   search_->SetLogits(state_->Run(search_->GetSequenceLength(), search_->GetNextTokens(), search_->GetNextIndices()));
   computed_logits_ = true;
@@ -101,46 +101,37 @@ bool Generator::IsDone() const {
   return search_->IsDone();
 }
 
-void Generator::GenerateNextToken_TopK_TopP(int top_k, float top_p, float temperature) {
+void Generator::GenerateNextToken() {
   if (!computed_logits_)
-    throw std::runtime_error("Must call ComputeLogits before GenerateNextToken*");
+    throw std::runtime_error("Must call ComputeLogits before GenerateNextToken");
   computed_logits_ = false;
 
-  if (top_k == 1) {
+  auto& search = search_->params_->search;
+  if (!search.do_sample || search.top_k == 1) {
     search_->SelectTop();
     return;
   }
 
   // The user explicitly called TopK_TopP on a beam search
-  if (search_->params_->search.num_beams != 1)
+  if (search.num_beams != 1)
     throw std::runtime_error("TopK and TopP cannot be used with a beam search");
 
   // Sanity checks
-  if (top_p < 0.0f || top_p > 1.0f)
+  if (search.top_p < 0.0f || search.top_p > 1.0f)
     throw std::runtime_error("top_p must be between 0.0 and 1.0");
-  if (top_k < 0)
+  if (search.top_k < 0)
     throw std::runtime_error("top_k must be 0 or greater");
 
-  if (top_p > 0.0f && top_k > 1) {
-    search_->SampleTopKTopP(top_k, top_p, temperature);
-  } else if (top_k > 1) {
-    search_->SampleTopK(top_k, temperature);
+  if (search.top_p > 0.0f && search.top_p < 1.0f && search.top_k > 1) {
+    search_->SampleTopKTopP(search.top_k, search.top_p, search.temperature);
+  } else if (search.top_k > 1) {
+    search_->SampleTopK(search.top_k, search.temperature);
   } else {
-    assert(top_k == 0);
-    if (top_p == 0.0f)
-      throw std::runtime_error("top_k and top_p cannot both be zero");
-    search_->SampleTopP(top_p, temperature);
+    assert(search.top_k == 0);
+    search_->SampleTopP(search.top_p, search.temperature);
   }
 }
 
-void Generator::GenerateNextToken() {
-  auto& search = search_->params_->search;
-  if (search.do_sample)
-    GenerateNextToken_TopK_TopP(search.top_k, search.top_p, search.temperature);
-  else
-    GenerateNextToken_Top();
-}
-
 RoamingArray<int32_t> Generator::GetSequence(int index) const {
   return search_->GetSequence(index);
 }
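
As a reading aid, below is a small standalone sketch (not library code) of the dispatch rules the unified GenerateNextToken now applies to the search options. The field names mirror the ones referenced in the diff above, but the struct and function here are hypothetical, and the range validation (top_p in [0, 1], top_k >= 0) is omitted.

#include <cstdio>

// Hypothetical mirror of the search-option fields referenced in generators.cpp above.
struct SearchOptions {
  bool do_sample{};
  int num_beams{1};
  int top_k{};
  float top_p{};
  float temperature{1.0f};
};

// Returns the sampler GenerateNextToken would pick for these options.
const char* PickSampler(const SearchOptions& s) {
  if (!s.do_sample || s.top_k == 1) return "SelectTop (greedy)";
  if (s.num_beams != 1) return "error: sampling cannot be combined with beam search";
  if (s.top_p > 0.0f && s.top_p < 1.0f && s.top_k > 1) return "SampleTopKTopP";
  if (s.top_k > 1) return "SampleTopK";
  return "SampleTopP";  // top_k == 0
}

int main() {
  std::printf("%s\n", PickSampler({}));                          // greedy: do_sample defaults to false
  std::printf("%s\n", PickSampler({true, 1, 50, 0.0f, 0.7f}));   // SampleTopK
  std::printf("%s\n", PickSampler({true, 1, 50, 0.9f, 0.7f}));   // SampleTopKTopP
  std::printf("%s\n", PickSampler({true, 1, 0, 0.9f, 0.7f}));    // SampleTopP
}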
4 changes: 0 additions & 4 deletions src/generators.h
@@ -100,10 +100,6 @@ struct Generator {
 
   bool IsDone() const;
   void ComputeLogits();
-  void GenerateNextToken_TopK_TopP(int top_k, float top_p, float temperature);
-  void GenerateNextToken_TopP(float p, float temperature) { GenerateNextToken_TopK_TopP(0, p, temperature); }
-  void GenerateNextToken_TopK(int k, float temperature) { GenerateNextToken_TopK_TopP(k, 0.0f, temperature); }
-  void GenerateNextToken_Top() { GenerateNextToken_TopK_TopP(1, 0.0f, 0.0f); }
   void GenerateNextToken();
 
   RoamingArray<int32_t> GetSequence(int index) const;
4 changes: 2 additions & 2 deletions src/ort_genai_c.cpp
@@ -149,9 +149,9 @@ OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator) {
   OGA_CATCH
 }
 
-OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_Top(OgaGenerator* generator) {
+OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) {
   OGA_TRY
-  reinterpret_cast<Generators::Generator*>(generator)->GenerateNextToken_Top();
+  reinterpret_cast<Generators::Generator*>(generator)->GenerateNextToken();
   return nullptr;
   OGA_CATCH
 }
11 changes: 1 addition & 10 deletions src/ort_genai_c.h
@@ -172,16 +172,7 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator);
  * \return OgaResult containing the error message if the computation of the logits failed.
  */
 OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator);
-
-/*
- * \brief Generates the next token based on the computed logits using the greedy search.
- * \param[in] generator The generator to generate the next token for.
- * \return OgaResult containing the error message if the generation of the next token failed.
- */
-OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_Top(OgaGenerator* generator);
-
-OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopK(OgaGenerator* generator, int k, float t);
-OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopP(OgaGenerator* generator, float p, float t);
+OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator);
 
 /*
  * \brief Returns the number of tokens in the sequence at the given index.
20 changes: 0 additions & 20 deletions src/python/python.cpp
@@ -137,22 +137,6 @@ struct PyGenerator {
     generator_->ComputeLogits();
   }
 
-  void GenerateNextToken_TopK_TopP(int top_k, float top_p, float temperature) {
-    generator_->GenerateNextToken_TopK_TopP(top_k, top_p, temperature);
-  }
-
-  void GenerateNextToken_TopP(float p, float temperature) {
-    generator_->GenerateNextToken_TopP(p, temperature);
-  }
-
-  void GenerateNextToken_TopK(int k, float temperature) {
-    generator_->GenerateNextToken_TopK(k, temperature);
-  }
-
-  void GenerateNextToken_Top() {
-    generator_->GenerateNextToken_Top();
-  }
-
   void GenerateNextToken() {
     generator_->GenerateNextToken();
   }
@@ -235,10 +219,6 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
       .def("is_done", &PyGenerator::IsDone)
       .def("compute_logits", &PyGenerator::ComputeLogits)
      .def("generate_next_token", &PyGenerator::GenerateNextToken)
-      .def("generate_next_token_top", &PyGenerator::GenerateNextToken_Top)
-      .def("generate_next_token_top_p", &PyGenerator::GenerateNextToken_TopP)
-      .def("generate_next_token_top_k", &PyGenerator::GenerateNextToken_TopK)
-      .def("generate_next_token_top_k_top_p", &PyGenerator::GenerateNextToken_TopK_TopP)
       .def("get_next_tokens", &PyGenerator::GetNextTokens)
       .def("get_sequence", &PyGenerator::GetSequence);

153 changes: 152 additions & 1 deletion test/c_api_tests.cpp
@@ -187,6 +187,7 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {
   CheckResult(OgaCreateGeneratorParams(model, &params));
   OgaGeneratorParamsPtr params_ptr{params};
   CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", max_length));
+  CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", false));
   CheckResult(OgaGeneratorParamsSetInputIDs(params, input_ids.data(), input_ids.size(), sequence_length, batch_size));
 
   OgaGenerator* generator;
@@ -195,7 +196,7 @@
 
   while (!OgaGenerator_IsDone(generator)) {
     CheckResult(OgaGenerator_ComputeLogits(generator));
-    CheckResult(OgaGenerator_GenerateNextToken_Top(generator));
+    CheckResult(OgaGenerator_GenerateNextToken(generator));
   }
 
   // Verify outputs match expected outputs
@@ -221,3 +222,153 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {
     EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), max_length * sizeof(int32_t)));
   }
 }
+
+#if TEST_PHI2
+TEST(CAPITests, TopKCAPI) {
+  float top_k = 50;
+  float temp = 0.6f;
+
+  OgaModel* model;
+  CheckResult(OgaCreateModel(MODEL_PATH "phi-2", &model));
+  OgaModelPtr model_ptr{model};
+
+  OgaTokenizer* tokenizer;
+  CheckResult(OgaCreateTokenizer(model, &tokenizer));
+  OgaTokenizerPtr tokenizer_ptr{tokenizer};
+
+  OgaSequences* input_sequences;
+  CheckResult(OgaCreateSequences(&input_sequences));
+  OgaSequencesPtr sequences_ptr{input_sequences};
+
+  const char* input_strings[] = {
+      "This is a test.",
+      "Rats are awesome pets!",
+      "The quick brown fox jumps over the lazy dog.",
+  };
+
+  for (auto& string : input_strings)
+    CheckResult(OgaTokenizerEncode(tokenizer, string, input_sequences));
+
+  OgaGeneratorParams* params;
+  CheckResult(OgaCreateGeneratorParams(model, &params));
+  OgaGeneratorParamsPtr params_ptr{params};
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40));
+  CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_k", top_k));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", temp));
+  CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));
+
+  OgaSequences* output_sequences;
+  CheckResult(OgaGenerate(model, params, &output_sequences));
+  OgaSequencesPtr output_sequences_ptr{output_sequences};
+
+  // Decode The Batch
+  for (size_t i = 0; i < OgaSequencesCount(output_sequences); i++) {
+    std::span<const int32_t> sequence{OgaSequencesGetSequenceData(output_sequences, i), OgaSequencesGetSequenceCount(output_sequences, i)};
+
+    const char* out_string;
+    CheckResult(OgaTokenizerDecode(tokenizer, sequence.data(), sequence.size(), &out_string));
+    std::cout << "Decoded string:" << out_string << std::endl;
+    OgaDestroyString(out_string);
+  }
+}
+
+TEST(CAPITests, TopPCAPI) {
+  float top_p = 0.6f;
+  float temp = 0.6f;
+
+  OgaModel* model;
+  CheckResult(OgaCreateModel(MODEL_PATH "phi-2", &model));
+  OgaModelPtr model_ptr{model};
+
+  OgaTokenizer* tokenizer;
+  CheckResult(OgaCreateTokenizer(model, &tokenizer));
+  OgaTokenizerPtr tokenizer_ptr{tokenizer};
+
+  OgaSequences* input_sequences;
+  CheckResult(OgaCreateSequences(&input_sequences));
+  OgaSequencesPtr sequences_ptr{input_sequences};
+
+  const char* input_strings[] = {
+      "This is a test.",
+      "Rats are awesome pets!",
+      "The quick brown fox jumps over the lazy dog.",
+  };
+
+  for (auto& string : input_strings)
+    CheckResult(OgaTokenizerEncode(tokenizer, string, input_sequences));
+
+  OgaGeneratorParams* params;
+  CheckResult(OgaCreateGeneratorParams(model, &params));
+  OgaGeneratorParamsPtr params_ptr{params};
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40));
+  CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_p", top_p));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", temp));
+  CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));
+  OgaSequences* output_sequences;
+  CheckResult(OgaGenerate(model, params, &output_sequences));
+  OgaSequencesPtr output_sequences_ptr{output_sequences};
+
+  // Decode The Batch
+  for (size_t i = 0; i < OgaSequencesCount(output_sequences); i++) {
+    std::span<const int32_t> sequence{OgaSequencesGetSequenceData(output_sequences, i), OgaSequencesGetSequenceCount(output_sequences, i)};
+
+    const char* out_string;
+    CheckResult(OgaTokenizerDecode(tokenizer, sequence.data(), sequence.size(), &out_string));
+    std::cout << "Decoded string:" << out_string << std::endl;
+    OgaDestroyString(out_string);
+  }
+}
+
+TEST(CAPITests, TopKTopPCAPI) {
+  float top_p = 0.6f;
+  int top_k = 50;
+  float temp = 0.6f;
+
+  OgaModel* model;
+  CheckResult(OgaCreateModel(MODEL_PATH "phi-2", &model));
+  OgaModelPtr model_ptr{model};
+
+  OgaTokenizer* tokenizer;
+  CheckResult(OgaCreateTokenizer(model, &tokenizer));
+  OgaTokenizerPtr tokenizer_ptr{tokenizer};
+
+  OgaSequences* input_sequences;
+  CheckResult(OgaCreateSequences(&input_sequences));
+  OgaSequencesPtr sequences_ptr{input_sequences};
+
+  const char* input_strings[] = {
+      "This is a test.",
+      "Rats are awesome pets!",
+      "The quick brown fox jumps over the lazy dog.",
+  };
+
+  for (auto& string : input_strings)
+    CheckResult(OgaTokenizerEncode(tokenizer, string, input_sequences));
+
+  OgaGeneratorParams* params;
+  CheckResult(OgaCreateGeneratorParams(model, &params));
+  OgaGeneratorParamsPtr params_ptr{params};
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40));
+  CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_k", top_k));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_p", top_p));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", temp));
+  CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));
+  OgaSequences* output_sequences;
+  CheckResult(OgaGenerate(model, params, &output_sequences));
+  OgaSequencesPtr output_sequences_ptr{output_sequences};
+
+  // Decode The Batch
+  for (size_t i = 0; i < OgaSequencesCount(output_sequences); i++) {
+    std::span<const int32_t> sequence{OgaSequencesGetSequenceData(output_sequences, i), OgaSequencesGetSequenceCount(output_sequences, i)};
+
+    const char* out_string;
+    CheckResult(OgaTokenizerDecode(tokenizer, sequence.data(), sequence.size(), &out_string));
+    std::cout << "Decoded string:" << out_string << std::endl;
+    OgaDestroyString(out_string);
+  }
+}
+
+#endif // TEST_PHI2