From a8720b5343c962e652ee3cc31668465df606d12a Mon Sep 17 00:00:00 2001 From: Ryan Hill Date: Wed, 20 Mar 2024 23:48:49 -0700 Subject: [PATCH] Remove GenerateNextToken* special case functions, as the set_search_options methods should be used instead. --- src/config.h | 2 +- src/csharp/Generator.cs | 20 ------------- src/csharp/NativeMethods.cs | 23 -------------- src/generators.cpp | 37 +++++++++-------------- src/generators.h | 4 --- src/ort_genai_c.cpp | 28 ----------------- src/ort_genai_c.h | 17 ----------- src/python/python.cpp | 20 ------------- test/c_api_tests.cpp | 40 ++++--------------------- test/model_tests.cpp | 6 ++-- test/sampling_tests.cpp | 60 +++++++++++++++++++++++++------------ 11 files changed, 65 insertions(+), 192 deletions(-) diff --git a/src/config.h b/src/config.h index 6cc634658..2621edc21 100644 --- a/src/config.h +++ b/src/config.h @@ -79,7 +79,7 @@ struct Config { int num_return_sequences{1}; float repetition_penalty{1.0f}; // 1.0 means no penalty. int top_k{}; // Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in the generate method of the model. - float top_p{1.0f}; // If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + float top_p{}; // If set to float >0 and <1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. float temperature{1.0f}; bool early_stopping{true}; // Whether to stop the beam search when at least num_beams sentences are finished per batch or not. int no_repeat_ngram_size{}; diff --git a/src/csharp/Generator.cs b/src/csharp/Generator.cs index 10c3d4e47..64c1c5623 100644 --- a/src/csharp/Generator.cs +++ b/src/csharp/Generator.cs @@ -30,26 +30,6 @@ public void GenerateNextToken() Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken(_generatorHandle)); } - public void GenerateNextTokenTop() - { - Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken_Top(_generatorHandle)); - } - - public void GenerateNextTokenTopK(int k, float temperature) - { - Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken_TopK(_generatorHandle, k, temperature)); - } - - public void GenerateNextTokenTopP(float p, float temperature) - { - Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken_TopP(_generatorHandle, p, temperature)); - } - - public void GenerateNextTokenTopKTopP(int k, float p, float temperature) - { - Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken_TopK_TopP(_generatorHandle, k, p, temperature)); - } - public ReadOnlySpan GetSequence(ulong index) { ulong sequenceLength = NativeMethods.OgaGenerator_GetSequenceLength(_generatorHandle, (UIntPtr)index).ToUInt64(); diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs index 039dfb4de..552c9046a 100644 --- a/src/csharp/NativeMethods.cs +++ b/src/csharp/NativeMethods.cs @@ -80,29 +80,6 @@ internal class NativeLib [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken(IntPtr /* OgaGenerator* */ generator); - // This function is used to generate the next token in the sequence using the greedy search algorithm. - [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] - public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken_Top(IntPtr /* OgaGenerator* */ generator); - - // This function is used to generate the next token in the sequence using the greedy search algorithm. - [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] - public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken_TopK(IntPtr /* OgaGenerator* */ generator, - int /* int32_t */ k, - float /* single_t */ t); - - // This function is used to generate the next token in the sequence using the greedy search algorithm. - [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] - public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken_TopP(IntPtr /* OgaGenerator* */ generator, - float /* single_t */ p, - float /* single_t */ t); - - // This function is used to generate the next token in the sequence using the greedy search algorithm. - [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] - public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken_TopK_TopP(IntPtr /* OgaGenerator* */ generator, - int /* int32_t */ k, - float /* single_t */ p, - float /* single_t */ t); - // This function returns the length of the sequence at the given index. [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern UIntPtr /* size_t */ OgaGenerator_GetSequenceLength(IntPtr /* const OgaGenerator* */ generator, diff --git a/src/generators.cpp b/src/generators.cpp index 6844a9aaf..cee8b1b02 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -84,7 +84,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ void Generator::ComputeLogits() { if (computed_logits_) - throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken* first"); + throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first"); search_->SetLogits(state_->Run(search_->GetSequenceLength(), search_->GetNextTokens(), search_->GetNextIndices())); computed_logits_ = true; @@ -101,46 +101,37 @@ bool Generator::IsDone() const { return search_->IsDone(); } -void Generator::GenerateNextToken_TopK_TopP(int top_k, float top_p, float temperature) { +void Generator::GenerateNextToken() { if (!computed_logits_) - throw std::runtime_error("Must call ComputeLogits before GenerateNextToken*"); + throw std::runtime_error("Must call ComputeLogits before GenerateNextToken"); computed_logits_ = false; - if (top_k == 1) { + auto& search = search_->params_->search; + if (!search.do_sample || search.top_k == 1) { search_->SelectTop(); return; } // The user explicitly called TopK_TopP on a beam search - if (search_->params_->search.num_beams != 1) + if (search.num_beams != 1) throw std::runtime_error("TopK and TopP cannot be used with a beam search"); // Sanity checks - if (top_p < 0.0f || top_p > 1.0f) + if (search.top_p < 0.0f || search.top_p > 1.0f) throw std::runtime_error("top_p must be between 0.0 and 1.0"); - if (top_k < 0) + if (search.top_k < 0) throw std::runtime_error("top_k must be 0 or greater"); - if (top_p > 0.0f && top_p < 1.0f && top_k > 1) { - search_->SampleTopKTopP(top_k, top_p, temperature); - } else if (top_k > 1) { - search_->SampleTopK(top_k, temperature); + if (search.top_p > 0.0f && search.top_p < 1.0f && search.top_k > 1) { + search_->SampleTopKTopP(search.top_k, search.top_p, search.temperature); + } else if (search.top_k > 1) { + search_->SampleTopK(search.top_k, search.temperature); } else { - assert(top_k == 0); - if (top_p == 0.0f) - throw std::runtime_error("top_k and top_p cannot both be zero"); - search_->SampleTopP(top_p, temperature); + assert(search.top_k == 0); + search_->SampleTopP(search.top_p, search.temperature); } } -void Generator::GenerateNextToken() { - auto& search = search_->params_->search; - if (search.do_sample) - GenerateNextToken_TopK_TopP(search.top_k, search.top_p, search.temperature); - else - GenerateNextToken_Top(); -} - RoamingArray Generator::GetSequence(int index) const { return search_->GetSequence(index); } diff --git a/src/generators.h b/src/generators.h index 1b42b45e9..3fb9f5201 100644 --- a/src/generators.h +++ b/src/generators.h @@ -100,10 +100,6 @@ struct Generator { bool IsDone() const; void ComputeLogits(); - void GenerateNextToken_TopK_TopP(int top_k, float top_p, float temperature); - void GenerateNextToken_TopP(float p, float temperature) { GenerateNextToken_TopK_TopP(0, p, temperature); } - void GenerateNextToken_TopK(int k, float temperature) { GenerateNextToken_TopK_TopP(k, 0.0f, temperature); } - void GenerateNextToken_Top() { GenerateNextToken_TopK_TopP(1, 0.0f, 0.0f); } void GenerateNextToken(); RoamingArray GetSequence(int index) const; diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index bbf84be51..1beb2a43b 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -156,34 +156,6 @@ OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) OGA_CATCH } -OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_Top(OgaGenerator* generator) { - OGA_TRY - reinterpret_cast(generator)->GenerateNextToken_Top(); - return nullptr; - OGA_CATCH -} - -OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopK(OgaGenerator* generator, int k, float t) { - OGA_TRY - reinterpret_cast(generator)->GenerateNextToken_TopK(k, t); - return nullptr; - OGA_CATCH -} - -OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopP(OgaGenerator* generator, float p, float t) { - OGA_TRY - reinterpret_cast(generator)->GenerateNextToken_TopP(p, t); - return nullptr; - OGA_CATCH -} - -OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopK_TopP(OgaGenerator* generator, int k, float p, float t) { - OGA_TRY - reinterpret_cast(generator)->GenerateNextToken_TopK_TopP(k, p, t); - return nullptr; - OGA_CATCH -} - size_t OGA_API_CALL OgaGenerator_GetSequenceLength(const OgaGenerator* oga_generator, size_t index) { auto& generator = *reinterpret_cast(oga_generator); return generator.GetSequence(static_cast(index)).GetCPU().size(); diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index e702082fc..fbd394f10 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -172,23 +172,6 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator); * \return OgaResult containing the error message if the computation of the logits failed. */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator); - -/* - * \brief Generates the next token based on the computed logits using the greedy search. - * \param[in] generator The generator to generate the next token for. - * \return OgaResult containing the error message if the generation of the next token failed. - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_Top(OgaGenerator* generator); - -/* Top-K sampling: most probable words from the model's output probability distribution for the next word - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopK(OgaGenerator* generator, int k, float t); - -/*Top-P sampling selects words from the smallest set of words whose cumulative probability exceeds a predefined threshold (p) - */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopP(OgaGenerator* generator, float p, float t); - -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopK_TopP(OgaGenerator* generator, int k, float p, float t); OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator); /* diff --git a/src/python/python.cpp b/src/python/python.cpp index 1c8db803d..584beb97c 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -137,22 +137,6 @@ struct PyGenerator { generator_->ComputeLogits(); } - void GenerateNextToken_TopK_TopP(int top_k, float top_p, float temperature) { - generator_->GenerateNextToken_TopK_TopP(top_k, top_p, temperature); - } - - void GenerateNextToken_TopP(float p, float temperature) { - generator_->GenerateNextToken_TopP(p, temperature); - } - - void GenerateNextToken_TopK(int k, float temperature) { - generator_->GenerateNextToken_TopK(k, temperature); - } - - void GenerateNextToken_Top() { - generator_->GenerateNextToken_Top(); - } - void GenerateNextToken() { generator_->GenerateNextToken(); } @@ -235,10 +219,6 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def("is_done", &PyGenerator::IsDone) .def("compute_logits", &PyGenerator::ComputeLogits) .def("generate_next_token", &PyGenerator::GenerateNextToken) - .def("generate_next_token_top", &PyGenerator::GenerateNextToken_Top) - .def("generate_next_token_top_p", &PyGenerator::GenerateNextToken_TopP) - .def("generate_next_token_top_k", &PyGenerator::GenerateNextToken_TopK) - .def("generate_next_token_top_k_top_p", &PyGenerator::GenerateNextToken_TopK_TopP) .def("get_next_tokens", &PyGenerator::GetNextTokens) .def("get_sequence", &PyGenerator::GetSequence); diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index 2ac6bfb71..3a04a8180 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -187,6 +187,7 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { CheckResult(OgaCreateGeneratorParams(model, ¶ms)); OgaGeneratorParamsPtr params_ptr{params}; CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", max_length)); + CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", false)); CheckResult(OgaGeneratorParamsSetInputIDs(params, input_ids.data(), input_ids.size(), sequence_length, batch_size)); OgaGenerator* generator; @@ -195,7 +196,7 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { while (!OgaGenerator_IsDone(generator)) { CheckResult(OgaGenerator_ComputeLogits(generator)); - CheckResult(OgaGenerator_GenerateNextToken_Top(generator)); + CheckResult(OgaGenerator_GenerateNextToken(generator)); } // Verify outputs match expected outputs @@ -252,20 +253,11 @@ TEST(CAPITests, TopKCAPI) { CheckResult(OgaCreateGeneratorParams(model, ¶ms)); OgaGeneratorParamsPtr params_ptr{params}; CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40)); - CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences)); - - OgaGenerator* generator; - CheckResult(OgaCreateGenerator(model, params, &generator)); - OgaGeneratorPtr generator_ptr{generator}; - - while (!OgaGenerator_IsDone(generator)) { - CheckResult(OgaGenerator_ComputeLogits(generator)); - CheckResult(OgaGenerator_GenerateNextToken_TopK(generator, top_k, temp)); - } - CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true)); CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_k", top_k)); CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", temp)); + CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences)); + OgaSequences* output_sequences; CheckResult(OgaGenerate(model, params, &output_sequences)); OgaSequencesPtr output_sequences_ptr{output_sequences}; @@ -310,20 +302,10 @@ TEST(CAPITests, TopPCAPI) { CheckResult(OgaCreateGeneratorParams(model, ¶ms)); OgaGeneratorParamsPtr params_ptr{params}; CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40)); - CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences)); - - OgaGenerator* generator; - CheckResult(OgaCreateGenerator(model, params, &generator)); - OgaGeneratorPtr generator_ptr{generator}; - - while (!OgaGenerator_IsDone(generator)) { - CheckResult(OgaGenerator_ComputeLogits(generator)); - CheckResult(OgaGenerator_GenerateNextToken_TopP(generator, top_p, temp)); - } - CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true)); CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_p", top_p)); CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", temp)); + CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences)); OgaSequences* output_sequences; CheckResult(OgaGenerate(model, params, &output_sequences)); OgaSequencesPtr output_sequences_ptr{output_sequences}; @@ -369,21 +351,11 @@ TEST(CAPITests, TopKTopPCAPI) { CheckResult(OgaCreateGeneratorParams(model, ¶ms)); OgaGeneratorParamsPtr params_ptr{params}; CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40)); - CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences)); - - OgaGenerator* generator; - CheckResult(OgaCreateGenerator(model, params, &generator)); - OgaGeneratorPtr generator_ptr{generator}; - - while (!OgaGenerator_IsDone(generator)) { - CheckResult(OgaGenerator_ComputeLogits(generator)); - CheckResult(OgaGenerator_GenerateNextToken_TopK_TopP(generator, top_k, top_p, temp)); - } - CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true)); CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_k", top_k)); CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_p", top_p)); CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", temp)); + CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences)); OgaSequences* output_sequences; CheckResult(OgaGenerate(model, params, &output_sequences)); OgaSequencesPtr output_sequences_ptr{output_sequences}; diff --git a/test/model_tests.cpp b/test/model_tests.cpp index a2b3a7832..79c1d2c64 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -44,7 +44,7 @@ TEST(ModelTests, GreedySearchGptFp32) { while (!generator->IsDone()) { generator->ComputeLogits(); - generator->GenerateNextToken_Top(); + generator->GenerateNextToken(); } // Verify outputs match expected outputs @@ -128,7 +128,7 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) while (!generator->IsDone()) { generator->ComputeLogits(); - generator->GenerateNextToken_Top(); + generator->GenerateNextToken(); } // Verify outputs match expected outputs @@ -226,7 +226,7 @@ Print all primes between 1 and n auto generator = Generators::CreateGenerator(*model, *params); while (!generator->IsDone()) { generator->ComputeLogits(); - generator->GenerateNextToken_Top(); + generator->GenerateNextToken(); } auto result = generator->GetSequence(0); diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp index 531270f78..239c71ab6 100644 --- a/test/sampling_tests.cpp +++ b/test/sampling_tests.cpp @@ -27,6 +27,8 @@ TEST(SamplingTests, BatchedSamplingTopPCpu) { int batch_size = 4; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample=true; + params->search.top_p=0.25f; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -37,7 +39,7 @@ TEST(SamplingTests, BatchedSamplingTopPCpu) { generator->search_->SetLogits(logits_span); generator->computed_logits_ = true; // Verify outputs match expected outputs - generator->GenerateNextToken_TopP(0.25f, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); EXPECT_TRUE(0 == std::memcmp(output_span.data(), next_tokens.data(), expected_output.size() * sizeof(int32_t))); } @@ -53,6 +55,8 @@ TEST(SamplingTests, BatchedSamplingTopKCpu) { int batch_size = 4; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = 2; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -64,8 +68,7 @@ TEST(SamplingTests, BatchedSamplingTopKCpu) { generator->computed_logits_ = true; // Verify outputs match expected outputs - int k = 2; - generator->GenerateNextToken_TopK(k, 1.0); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; @@ -85,6 +88,9 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCpu) { int batch_size = 4; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = 2; + params->search.top_p = 0.25f; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -95,9 +101,7 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCpu) { generator->search_->SetLogits(Generators::cpu_span(logits_copy)); generator->computed_logits_ = true; // Verify outputs match expected outputs - float p = 0.25f; - int k = 2; - generator->GenerateNextToken_TopK_TopP(k, p, 1.0); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; @@ -130,6 +134,8 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) { std::vector input_ids{0, 1, 2, 3, 4}; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_p = 0.95f; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -147,7 +153,7 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) { auto logits_copy = logits_cpu; generator->search_->SetLogits(Generators::cpu_span(logits_copy)); generator->computed_logits_ = true; - generator->GenerateNextToken_TopP(0.95f, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { @@ -166,6 +172,8 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) { std::vector input_ids{0, 1, 2, 3, 4}; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = k; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -183,7 +191,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) { auto logits_copy=logits_cpu; generator->search_->SetLogits(Generators::cpu_span(logits_copy)); generator->computed_logits_ = true; - generator->GenerateNextToken_TopK(k, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { @@ -203,6 +211,9 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) { std::vector input_ids{0, 1, 2, 3, 4}; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = k; + params->search.top_p = p; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -220,7 +231,7 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) { auto logits_copy = logits_cpu; generator->search_->SetLogits(Generators::cpu_span(logits_copy)); generator->computed_logits_ = true; - generator->GenerateNextToken_TopK_TopP(k, p, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { @@ -248,6 +259,8 @@ TEST(SamplingTests, BatchedSamplingTopPCuda) { int batch_size = 4; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_p = 0.25f; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -259,7 +272,7 @@ TEST(SamplingTests, BatchedSamplingTopPCuda) { generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), logits_cpu.size())); generator->computed_logits_ = true; // Verify outputs match expected outputs - generator->GenerateNextToken_TopP(0.25f, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); EXPECT_TRUE(0 == std::memcmp(output_span.data(), next_tokens.data(), expected_output.size() * sizeof(int32_t))); } @@ -276,6 +289,8 @@ TEST(SamplingTests, BatchedSamplingTopKCuda) { int batch_size = 4; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = 2; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -287,8 +302,7 @@ TEST(SamplingTests, BatchedSamplingTopKCuda) { generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), logits_cpu.size())); generator->computed_logits_ = true; // Verify outputs match expected outputs - int k = 2; - generator->GenerateNextToken_TopK(k, 1.0); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; @@ -309,6 +323,9 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCuda) { int batch_size = 4; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = 2; + params->search.top_p = 0.25f; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -320,9 +337,7 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCuda) { generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), logits_cpu.size())); generator->computed_logits_ = true; // Verify outputs match expected outputs - float p = 0.25f; - int k = 2; - generator->GenerateNextToken_TopK_TopP(k, p, 1.0); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; @@ -338,6 +353,8 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) { std::vector input_ids{0, 1, 2, 3, 4}; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_p = 0.95f; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -358,7 +375,7 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) { cudaMemcpyAsync(cpu_logits, logits_gpu.get(), vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost, params->cuda_stream); generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); generator->computed_logits_ = true; - generator->GenerateNextToken_TopP(0.95f, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); cudaStreamSynchronize(params->cuda_stream); // Verify outputs match expected outputs @@ -378,6 +395,8 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) { std::vector input_ids{0, 1, 2, 3, 4}; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = k; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -398,7 +417,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) { auto generator = Generators::CreateGenerator(*model, *params); generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); generator->computed_logits_ = true; - generator->GenerateNextToken_TopK(k, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); cudaStreamSynchronize(params->cuda_stream); // Verify outputs match expected outputs @@ -419,6 +438,9 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) { std::vector input_ids{0, 1, 2, 3, 4}; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = k; + params->search.top_p = p; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -439,7 +461,7 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) { cudaMemcpyAsync(cpu_logits, logits_gpu.get(), vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost, params->cuda_stream); generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); generator->computed_logits_ = true; - generator->GenerateNextToken_TopK_TopP(k, p, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); cudaStreamSynchronize(params->cuda_stream); // Verify outputs match expected outputs @@ -478,7 +500,7 @@ TEST(SamplingTests, RandomizedSamplingSelectTopCuda) { auto generator = Generators::CreateGenerator(*model, *params); generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); generator->computed_logits_ = true; - generator->GenerateNextToken_Top(); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); cudaStreamSynchronize(params->cuda_stream); // Verify outputs match expected outputs