Skip to content

Commit

Permalink
Remove GenerateNextToken* special case functions, as the set_search_o…
Browse files Browse the repository at this point in the history
…ptions methods should be used instead.
  • Loading branch information
RyanUnderhill committed Mar 21, 2024
1 parent 3a9ecf7 commit a8720b5
Show file tree
Hide file tree
Showing 11 changed files with 65 additions and 192 deletions.
2 changes: 1 addition & 1 deletion src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ struct Config {
int num_return_sequences{1};
float repetition_penalty{1.0f}; // 1.0 means no penalty.
int top_k{}; // Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in the generate method of the model.
float top_p{1.0f}; // If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
float top_p{}; // If set to float >0 and <1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
float temperature{1.0f};
bool early_stopping{true}; // Whether to stop the beam search when at least num_beams sentences are finished per batch or not.
int no_repeat_ngram_size{};
Expand Down
20 changes: 0 additions & 20 deletions src/csharp/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,6 @@ public void GenerateNextToken()
Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken(_generatorHandle));
}

public void GenerateNextTokenTop()
{
Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken_Top(_generatorHandle));
}

public void GenerateNextTokenTopK(int k, float temperature)
{
Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken_TopK(_generatorHandle, k, temperature));
}

public void GenerateNextTokenTopP(float p, float temperature)
{
Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken_TopP(_generatorHandle, p, temperature));
}

public void GenerateNextTokenTopKTopP(int k, float p, float temperature)
{
Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken_TopK_TopP(_generatorHandle, k, p, temperature));
}

public ReadOnlySpan<int> GetSequence(ulong index)
{
ulong sequenceLength = NativeMethods.OgaGenerator_GetSequenceLength(_generatorHandle, (UIntPtr)index).ToUInt64();
Expand Down
23 changes: 0 additions & 23 deletions src/csharp/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,29 +80,6 @@ internal class NativeLib
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken(IntPtr /* OgaGenerator* */ generator);

// This function is used to generate the next token in the sequence using the greedy search algorithm.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken_Top(IntPtr /* OgaGenerator* */ generator);

// This function is used to generate the next token in the sequence using the greedy search algorithm.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken_TopK(IntPtr /* OgaGenerator* */ generator,
int /* int32_t */ k,
float /* single_t */ t);

// This function is used to generate the next token in the sequence using the greedy search algorithm.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken_TopP(IntPtr /* OgaGenerator* */ generator,
float /* single_t */ p,
float /* single_t */ t);

// This function is used to generate the next token in the sequence using the greedy search algorithm.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken_TopK_TopP(IntPtr /* OgaGenerator* */ generator,
int /* int32_t */ k,
float /* single_t */ p,
float /* single_t */ t);

// This function returns the length of the sequence at the given index.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern UIntPtr /* size_t */ OgaGenerator_GetSequenceLength(IntPtr /* const OgaGenerator* */ generator,
Expand Down
37 changes: 14 additions & 23 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_

void Generator::ComputeLogits() {
if (computed_logits_)
throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken* first");
throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first");

search_->SetLogits(state_->Run(search_->GetSequenceLength(), search_->GetNextTokens(), search_->GetNextIndices()));
computed_logits_ = true;
Expand All @@ -101,46 +101,37 @@ bool Generator::IsDone() const {
return search_->IsDone();
}

void Generator::GenerateNextToken_TopK_TopP(int top_k, float top_p, float temperature) {
void Generator::GenerateNextToken() {
if (!computed_logits_)
throw std::runtime_error("Must call ComputeLogits before GenerateNextToken*");
throw std::runtime_error("Must call ComputeLogits before GenerateNextToken");
computed_logits_ = false;

if (top_k == 1) {
auto& search = search_->params_->search;
if (!search.do_sample || search.top_k == 1) {
search_->SelectTop();
return;
}

// The user explicitly called TopK_TopP on a beam search
if (search_->params_->search.num_beams != 1)
if (search.num_beams != 1)
throw std::runtime_error("TopK and TopP cannot be used with a beam search");

// Sanity checks
if (top_p < 0.0f || top_p > 1.0f)
if (search.top_p < 0.0f || search.top_p > 1.0f)
throw std::runtime_error("top_p must be between 0.0 and 1.0");
if (top_k < 0)
if (search.top_k < 0)
throw std::runtime_error("top_k must be 0 or greater");

if (top_p > 0.0f && top_p < 1.0f && top_k > 1) {
search_->SampleTopKTopP(top_k, top_p, temperature);
} else if (top_k > 1) {
search_->SampleTopK(top_k, temperature);
if (search.top_p > 0.0f && search.top_p < 1.0f && search.top_k > 1) {
search_->SampleTopKTopP(search.top_k, search.top_p, search.temperature);
} else if (search.top_k > 1) {
search_->SampleTopK(search.top_k, search.temperature);
} else {
assert(top_k == 0);
if (top_p == 0.0f)
throw std::runtime_error("top_k and top_p cannot both be zero");
search_->SampleTopP(top_p, temperature);
assert(search.top_k == 0);
search_->SampleTopP(search.top_p, search.temperature);
}
}

void Generator::GenerateNextToken() {
auto& search = search_->params_->search;
if (search.do_sample)
GenerateNextToken_TopK_TopP(search.top_k, search.top_p, search.temperature);
else
GenerateNextToken_Top();
}

RoamingArray<int32_t> Generator::GetSequence(int index) const {
return search_->GetSequence(index);
}
Expand Down
4 changes: 0 additions & 4 deletions src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ struct Generator {

bool IsDone() const;
void ComputeLogits();
void GenerateNextToken_TopK_TopP(int top_k, float top_p, float temperature);
void GenerateNextToken_TopP(float p, float temperature) { GenerateNextToken_TopK_TopP(0, p, temperature); }
void GenerateNextToken_TopK(int k, float temperature) { GenerateNextToken_TopK_TopP(k, 0.0f, temperature); }
void GenerateNextToken_Top() { GenerateNextToken_TopK_TopP(1, 0.0f, 0.0f); }
void GenerateNextToken();

RoamingArray<int32_t> GetSequence(int index) const;
Expand Down
28 changes: 0 additions & 28 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,34 +156,6 @@ OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator)
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_Top(OgaGenerator* generator) {
OGA_TRY
reinterpret_cast<Generators::Generator*>(generator)->GenerateNextToken_Top();
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopK(OgaGenerator* generator, int k, float t) {
OGA_TRY
reinterpret_cast<Generators::Generator*>(generator)->GenerateNextToken_TopK(k, t);
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopP(OgaGenerator* generator, float p, float t) {
OGA_TRY
reinterpret_cast<Generators::Generator*>(generator)->GenerateNextToken_TopP(p, t);
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopK_TopP(OgaGenerator* generator, int k, float p, float t) {
OGA_TRY
reinterpret_cast<Generators::Generator*>(generator)->GenerateNextToken_TopK_TopP(k, p, t);
return nullptr;
OGA_CATCH
}

size_t OGA_API_CALL OgaGenerator_GetSequenceLength(const OgaGenerator* oga_generator, size_t index) {
auto& generator = *reinterpret_cast<const Generators::Generator*>(oga_generator);
return generator.GetSequence(static_cast<int>(index)).GetCPU().size();
Expand Down
17 changes: 0 additions & 17 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,23 +172,6 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator);
* \return OgaResult containing the error message if the computation of the logits failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator);

/*
* \brief Generates the next token based on the computed logits using the greedy search.
* \param[in] generator The generator to generate the next token for.
* \return OgaResult containing the error message if the generation of the next token failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_Top(OgaGenerator* generator);

/* Top-K sampling: most probable words from the model's output probability distribution for the next word
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopK(OgaGenerator* generator, int k, float t);

/*Top-P sampling selects words from the smallest set of words whose cumulative probability exceeds a predefined threshold (p)
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopP(OgaGenerator* generator, float p, float t);

OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopK_TopP(OgaGenerator* generator, int k, float p, float t);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator);

/*
Expand Down
20 changes: 0 additions & 20 deletions src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,6 @@ struct PyGenerator {
generator_->ComputeLogits();
}

void GenerateNextToken_TopK_TopP(int top_k, float top_p, float temperature) {
generator_->GenerateNextToken_TopK_TopP(top_k, top_p, temperature);
}

void GenerateNextToken_TopP(float p, float temperature) {
generator_->GenerateNextToken_TopP(p, temperature);
}

void GenerateNextToken_TopK(int k, float temperature) {
generator_->GenerateNextToken_TopK(k, temperature);
}

void GenerateNextToken_Top() {
generator_->GenerateNextToken_Top();
}

void GenerateNextToken() {
generator_->GenerateNextToken();
}
Expand Down Expand Up @@ -235,10 +219,6 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
.def("is_done", &PyGenerator::IsDone)
.def("compute_logits", &PyGenerator::ComputeLogits)
.def("generate_next_token", &PyGenerator::GenerateNextToken)
.def("generate_next_token_top", &PyGenerator::GenerateNextToken_Top)
.def("generate_next_token_top_p", &PyGenerator::GenerateNextToken_TopP)
.def("generate_next_token_top_k", &PyGenerator::GenerateNextToken_TopK)
.def("generate_next_token_top_k_top_p", &PyGenerator::GenerateNextToken_TopK_TopP)
.def("get_next_tokens", &PyGenerator::GetNextTokens)
.def("get_sequence", &PyGenerator::GetSequence);

Expand Down
40 changes: 6 additions & 34 deletions test/c_api_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {
CheckResult(OgaCreateGeneratorParams(model, &params));
OgaGeneratorParamsPtr params_ptr{params};
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", max_length));
CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", false));
CheckResult(OgaGeneratorParamsSetInputIDs(params, input_ids.data(), input_ids.size(), sequence_length, batch_size));

OgaGenerator* generator;
Expand All @@ -195,7 +196,7 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {

while (!OgaGenerator_IsDone(generator)) {
CheckResult(OgaGenerator_ComputeLogits(generator));
CheckResult(OgaGenerator_GenerateNextToken_Top(generator));
CheckResult(OgaGenerator_GenerateNextToken(generator));
}

// Verify outputs match expected outputs
Expand Down Expand Up @@ -252,20 +253,11 @@ TEST(CAPITests, TopKCAPI) {
CheckResult(OgaCreateGeneratorParams(model, &params));
OgaGeneratorParamsPtr params_ptr{params};
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40));
CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));

OgaGenerator* generator;
CheckResult(OgaCreateGenerator(model, params, &generator));
OgaGeneratorPtr generator_ptr{generator};

while (!OgaGenerator_IsDone(generator)) {
CheckResult(OgaGenerator_ComputeLogits(generator));
CheckResult(OgaGenerator_GenerateNextToken_TopK(generator, top_k, temp));
}

CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_k", top_k));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", temp));
CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));

OgaSequences* output_sequences;
CheckResult(OgaGenerate(model, params, &output_sequences));
OgaSequencesPtr output_sequences_ptr{output_sequences};
Expand Down Expand Up @@ -310,20 +302,10 @@ TEST(CAPITests, TopPCAPI) {
CheckResult(OgaCreateGeneratorParams(model, &params));
OgaGeneratorParamsPtr params_ptr{params};
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40));
CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));

OgaGenerator* generator;
CheckResult(OgaCreateGenerator(model, params, &generator));
OgaGeneratorPtr generator_ptr{generator};

while (!OgaGenerator_IsDone(generator)) {
CheckResult(OgaGenerator_ComputeLogits(generator));
CheckResult(OgaGenerator_GenerateNextToken_TopP(generator, top_p, temp));
}

CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_p", top_p));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", temp));
CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));
OgaSequences* output_sequences;
CheckResult(OgaGenerate(model, params, &output_sequences));
OgaSequencesPtr output_sequences_ptr{output_sequences};
Expand Down Expand Up @@ -369,21 +351,11 @@ TEST(CAPITests, TopKTopPCAPI) {
CheckResult(OgaCreateGeneratorParams(model, &params));
OgaGeneratorParamsPtr params_ptr{params};
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40));
CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));

OgaGenerator* generator;
CheckResult(OgaCreateGenerator(model, params, &generator));
OgaGeneratorPtr generator_ptr{generator};

while (!OgaGenerator_IsDone(generator)) {
CheckResult(OgaGenerator_ComputeLogits(generator));
CheckResult(OgaGenerator_GenerateNextToken_TopK_TopP(generator, top_k, top_p, temp));
}

CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_k", top_k));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_p", top_p));
CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", temp));
CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));
OgaSequences* output_sequences;
CheckResult(OgaGenerate(model, params, &output_sequences));
OgaSequencesPtr output_sequences_ptr{output_sequences};
Expand Down
6 changes: 3 additions & 3 deletions test/model_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ TEST(ModelTests, GreedySearchGptFp32) {

while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken_Top();
generator->GenerateNextToken();
}

// Verify outputs match expected outputs
Expand Down Expand Up @@ -128,7 +128,7 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label)

while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken_Top();
generator->GenerateNextToken();
}

// Verify outputs match expected outputs
Expand Down Expand Up @@ -226,7 +226,7 @@ Print all primes between 1 and n
auto generator = Generators::CreateGenerator(*model, *params);
while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken_Top();
generator->GenerateNextToken();
}

auto result = generator->GetSequence(0);
Expand Down
Loading

0 comments on commit a8720b5

Please sign in to comment.