diff --git a/src/config.h b/src/config.h
index 6cc634658..2621edc21 100644
--- a/src/config.h
+++ b/src/config.h
@@ -79,7 +79,7 @@ struct Config {
   int num_return_sequences{1};
   float repetition_penalty{1.0f};  // 1.0 means no penalty.
   int top_k{};  // Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in the generate method of the model.
-  float top_p{1.0f};  // If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+  float top_p{};  // If set to a float > 0 and < 1, only the most probable tokens whose probabilities add up to top_p or higher are kept for generation.
   float temperature{1.0f};
   bool early_stopping{true};  // Whether to stop the beam search when at least num_beams sentences are finished per batch or not.
   int no_repeat_ngram_size{};
diff --git a/src/csharp/Generator.cs b/src/csharp/Generator.cs
index 1dc81883b..64c1c5623 100644
--- a/src/csharp/Generator.cs
+++ b/src/csharp/Generator.cs
@@ -2,7 +2,6 @@
 // Licensed under the MIT License.

 using System;
-using System.Runtime.InteropServices;

 namespace Microsoft.ML.OnnxRuntimeGenAI
 {
@@ -26,9 +25,9 @@ public void ComputeLogits()
             Result.VerifySuccess(NativeMethods.OgaGenerator_ComputeLogits(_generatorHandle));
         }

-        public void GenerateNextTokenTop()
+        public void GenerateNextToken()
         {
-            Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken_Top(_generatorHandle));
+            Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken(_generatorHandle));
         }

         public ReadOnlySpan<int> GetSequence(ulong index)
diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs
index 4b41102d7..552c9046a 100644
--- a/src/csharp/NativeMethods.cs
+++ b/src/csharp/NativeMethods.cs
@@ -78,7 +78,7 @@ internal class NativeLib
-        // This function is used to generate the next token in the sequence using the greedy search algorithm.
+        // This function is used to generate the next token in the sequence using the search options configured on the generator.
         [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
-        public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken_Top(IntPtr /* OgaGenerator* */ generator);
+        public static extern IntPtr /* OgaResult* */ OgaGenerator_GenerateNextToken(IntPtr /* OgaGenerator* */ generator);

         // This function returns the length of the sequence at the given index.
         [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
diff --git a/src/generators.cpp b/src/generators.cpp
index 83021c652..cee8b1b02 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -84,7 +84,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_

 void Generator::ComputeLogits() {
   if (computed_logits_)
-    throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken* first");
+    throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first");

   search_->SetLogits(state_->Run(search_->GetSequenceLength(), search_->GetNextTokens(), search_->GetNextIndices()));
   computed_logits_ = true;
@@ -101,46 +101,37 @@ bool Generator::IsDone() const {
   return search_->IsDone();
 }

-void Generator::GenerateNextToken_TopK_TopP(int top_k, float top_p, float temperature) {
+void Generator::GenerateNextToken() {
   if (!computed_logits_)
-    throw std::runtime_error("Must call ComputeLogits before GenerateNextToken*");
+    throw std::runtime_error("Must call ComputeLogits before GenerateNextToken");
   computed_logits_ = false;

-  if (top_k == 1) {
+  auto& search = search_->params_->search;
+  if (!search.do_sample || search.top_k == 1) {
     search_->SelectTop();
     return;
   }

-  // The user explicitly called TopK_TopP on a beam search
-  if (search_->params_->search.num_beams != 1)
+  // Sampling cannot be combined with a beam search
+  if (search.num_beams != 1)
     throw std::runtime_error("TopK and TopP cannot be used with a beam search");

   // Sanity checks
-  if (top_p < 0.0f || top_p > 1.0f)
+  if (search.top_p < 0.0f || search.top_p > 1.0f)
     throw std::runtime_error("top_p must be between 0.0 and 1.0");
-  if (top_k < 0)
+  if (search.top_k < 0)
     throw std::runtime_error("top_k must be 0 or greater");

-  if (top_p > 0.0f && top_k > 1) {
-    search_->SampleTopKTopP(top_k, top_p, temperature);
-  } else if (top_k > 1) {
-    search_->SampleTopK(top_k, temperature);
+  if (search.top_p > 0.0f && search.top_p < 1.0f && search.top_k > 1) {
+    search_->SampleTopKTopP(search.top_k, search.top_p, search.temperature);
+  } else if (search.top_k > 1) {
+    search_->SampleTopK(search.top_k, search.temperature);
   } else {
-    assert(top_k == 0);
-    if (top_p == 0.0f)
-      throw std::runtime_error("top_k and top_p cannot both be zero");
-    search_->SampleTopP(top_p, temperature);
+    assert(search.top_k == 0);
+    search_->SampleTopP(search.top_p, search.temperature);
   }
 }

-void Generator::GenerateNextToken() {
-  auto& search = search_->params_->search;
-  if (search.do_sample)
-    GenerateNextToken_TopK_TopP(search.top_k, search.top_p, search.temperature);
-  else
-    GenerateNextToken_Top();
-}
-
 RoamingArray<int32_t> Generator::GetSequence(int index) const {
   return search_->GetSequence(index);
 }
diff --git a/src/generators.h b/src/generators.h
index dd8edcf54..6f69ccb18 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -100,10 +100,6 @@ struct Generator {
   bool IsDone() const;
   void ComputeLogits();
-  void GenerateNextToken_TopK_TopP(int top_k, float top_p, float temperature);
-  void GenerateNextToken_TopP(float p, float temperature) { GenerateNextToken_TopK_TopP(0, p, temperature); }
-  void GenerateNextToken_TopK(int k, float temperature) { GenerateNextToken_TopK_TopP(k, 0.0f, temperature); }
-  void GenerateNextToken_Top() { GenerateNextToken_TopK_TopP(1, 0.0f, 0.0f); }
   void GenerateNextToken();

   RoamingArray<int32_t> GetSequence(int index) const;
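With the per-strategy entry points removed, the strategy is selected entirely by the search options: do_sample=false (or top_k==1) short-circuits to SelectTop (greedy); otherwise top_k>1 combined with 0<top_p<1 dispatches to SampleTopKTopP, top_k>1 alone to SampleTopK, and top_k==0 to SampleTopP. A minimal sketch of the resulting C++ call pattern, modeled on the tests changed below; this is not code from the diff, and how the Generators::Model and the input wiring are obtained is elided:

  // Sketch only: assumes a Generators::Model `model` loaded elsewhere, plus the
  // batch/input_ids setup the tests below perform on `params`.
  auto params = Generators::CreateGeneratorParams();
  params->search.max_length = 40;
  params->search.do_sample = true;    // false would fall back to greedy SelectTop()
  params->search.top_k = 50;          // top_k > 1 and 0 < top_p < 1 -> SampleTopKTopP
  params->search.top_p = 0.6f;
  params->search.temperature = 0.6f;

  auto generator = Generators::CreateGenerator(*model, *params);
  while (!generator->IsDone()) {
    generator->ComputeLogits();
    generator->GenerateNextToken();   // one call; dispatch happens internally
  }
  auto sequence = generator->GetSequence(0);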
diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp
index 34e022578..1beb2a43b 100644
--- a/src/ort_genai_c.cpp
+++ b/src/ort_genai_c.cpp
@@ -149,9 +149,9 @@ OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator) {
   OGA_CATCH
 }

-OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_Top(OgaGenerator* generator) {
+OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) {
   OGA_TRY
-  reinterpret_cast<Generators::Generator*>(generator)->GenerateNextToken_Top();
+  reinterpret_cast<Generators::Generator*>(generator)->GenerateNextToken();
   return nullptr;
   OGA_CATCH
 }
diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h
index 255bfbafb..fbd394f10 100644
--- a/src/ort_genai_c.h
+++ b/src/ort_genai_c.h
@@ -172,16 +172,7 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator);
  * \return OgaResult containing the error message if the computation of the logits failed.
  */
 OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator);
-
-/*
- * \brief Generates the next token based on the computed logits using the greedy search.
- * \param[in] generator The generator to generate the next token for.
- * \return OgaResult containing the error message if the generation of the next token failed.
- */
-OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_Top(OgaGenerator* generator);
-
-OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopK(OgaGenerator* generator, int k, float t);
-OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken_TopP(OgaGenerator* generator, float p, float t);
+
+/*
+ * \brief Generates the next token based on the computed logits and the generator's configured search options.
+ * \param[in] generator The generator to generate the next token for.
+ * \return OgaResult containing the error message if the generation of the next token failed.
+ */
+OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator);

 /*
  * \brief Returns the number of tokens in the sequence at the given index.
diff --git a/src/python/python.cpp b/src/python/python.cpp
index 1c8db803d..584beb97c 100644
--- a/src/python/python.cpp
+++ b/src/python/python.cpp
@@ -137,22 +137,6 @@ struct PyGenerator {
     generator_->ComputeLogits();
   }

-  void GenerateNextToken_TopK_TopP(int top_k, float top_p, float temperature) {
-    generator_->GenerateNextToken_TopK_TopP(top_k, top_p, temperature);
-  }
-
-  void GenerateNextToken_TopP(float p, float temperature) {
-    generator_->GenerateNextToken_TopP(p, temperature);
-  }
-
-  void GenerateNextToken_TopK(int k, float temperature) {
-    generator_->GenerateNextToken_TopK(k, temperature);
-  }
-
-  void GenerateNextToken_Top() {
-    generator_->GenerateNextToken_Top();
-  }
-
   void GenerateNextToken() {
     generator_->GenerateNextToken();
   }
@@ -235,10 +219,6 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
       .def("is_done", &PyGenerator::IsDone)
       .def("compute_logits", &PyGenerator::ComputeLogits)
       .def("generate_next_token", &PyGenerator::GenerateNextToken)
-      .def("generate_next_token_top", &PyGenerator::GenerateNextToken_Top)
-      .def("generate_next_token_top_p", &PyGenerator::GenerateNextToken_TopP)
-      .def("generate_next_token_top_k", &PyGenerator::GenerateNextToken_TopK)
-      .def("generate_next_token_top_k_top_p", &PyGenerator::GenerateNextToken_TopK_TopP)
       .def("get_next_tokens", &PyGenerator::GetNextTokens)
       .def("get_sequence", &PyGenerator::GetSequence);
diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp
index ab5bfc169..3a04a8180 100644
--- a/test/c_api_tests.cpp
+++ b/test/c_api_tests.cpp
@@ -187,6 +187,7 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {
   CheckResult(OgaCreateGeneratorParams(model, &params));
   OgaGeneratorParamsPtr params_ptr{params};
   CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", max_length));
+  CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", false));
   CheckResult(OgaGeneratorParamsSetInputIDs(params, input_ids.data(), input_ids.size(), sequence_length, batch_size));

   OgaGenerator* generator;
@@ -195,7 +196,7 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {
   while (!OgaGenerator_IsDone(generator)) {
     CheckResult(OgaGenerator_ComputeLogits(generator));
-    CheckResult(OgaGenerator_GenerateNextToken_Top(generator));
+    CheckResult(OgaGenerator_GenerateNextToken(generator));
   }

   // Verify outputs match expected outputs
@@ -221,3 +222,153 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {
     EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), max_length * sizeof(int32_t)));
   }
 }
+
+#if TEST_PHI2
+TEST(CAPITests, TopKCAPI) {
+  int top_k = 50;
+  float temp = 0.6f;
+
+  OgaModel* model;
+  CheckResult(OgaCreateModel(MODEL_PATH "phi-2", &model));
+  OgaModelPtr model_ptr{model};
+
+  OgaTokenizer* tokenizer;
+  CheckResult(OgaCreateTokenizer(model, &tokenizer));
+  OgaTokenizerPtr tokenizer_ptr{tokenizer};
+
+  OgaSequences* input_sequences;
+  CheckResult(OgaCreateSequences(&input_sequences));
+  OgaSequencesPtr sequences_ptr{input_sequences};
+
+  const char* input_strings[] = {
+      "This is a test.",
+      "Rats are awesome pets!",
+      "The quick brown fox jumps over the lazy dog.",
+  };
+
+  for (auto& string : input_strings)
+    CheckResult(OgaTokenizerEncode(tokenizer, string, input_sequences));
+
+  OgaGeneratorParams* params;
+  CheckResult(OgaCreateGeneratorParams(model, &params));
+  OgaGeneratorParamsPtr params_ptr{params};
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40));
+  CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_k", top_k));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", temp));
+  CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));
+
+  OgaSequences* output_sequences;
+  CheckResult(OgaGenerate(model, params, &output_sequences));
+  OgaSequencesPtr output_sequences_ptr{output_sequences};
+
+  // Decode the batch
+  for (size_t i = 0; i < OgaSequencesCount(output_sequences); i++) {
+    std::span<const int32_t> sequence{OgaSequencesGetSequenceData(output_sequences, i), OgaSequencesGetSequenceCount(output_sequences, i)};
+
+    const char* out_string;
+    CheckResult(OgaTokenizerDecode(tokenizer, sequence.data(), sequence.size(), &out_string));
+    std::cout << "Decoded string:" << out_string << std::endl;
+    OgaDestroyString(out_string);
+  }
+}
+
+TEST(CAPITests, TopPCAPI) {
+  float top_p = 0.6f;
+  float temp = 0.6f;
+
+  OgaModel* model;
+  CheckResult(OgaCreateModel(MODEL_PATH "phi-2", &model));
+  OgaModelPtr model_ptr{model};
+
+  OgaTokenizer* tokenizer;
+  CheckResult(OgaCreateTokenizer(model, &tokenizer));
+  OgaTokenizerPtr tokenizer_ptr{tokenizer};
+
+  OgaSequences* input_sequences;
+  CheckResult(OgaCreateSequences(&input_sequences));
+  OgaSequencesPtr sequences_ptr{input_sequences};
+
+  const char* input_strings[] = {
+      "This is a test.",
+      "Rats are awesome pets!",
+      "The quick brown fox jumps over the lazy dog.",
+  };
+
+  for (auto& string : input_strings)
+    CheckResult(OgaTokenizerEncode(tokenizer, string, input_sequences));
+
+  OgaGeneratorParams* params;
+  CheckResult(OgaCreateGeneratorParams(model, &params));
+  OgaGeneratorParamsPtr params_ptr{params};
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40));
+  CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_p", top_p));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", temp));
+  CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));
+  OgaSequences* output_sequences;
+  CheckResult(OgaGenerate(model, params, &output_sequences));
+  OgaSequencesPtr output_sequences_ptr{output_sequences};
+
+  // Decode the batch
+  for (size_t i = 0; i < OgaSequencesCount(output_sequences); i++) {
+    std::span<const int32_t> sequence{OgaSequencesGetSequenceData(output_sequences, i), OgaSequencesGetSequenceCount(output_sequences, i)};
+
+    const char* out_string;
+    CheckResult(OgaTokenizerDecode(tokenizer, sequence.data(), sequence.size(), &out_string));
+    std::cout << "Decoded string:" << out_string << std::endl;
+    OgaDestroyString(out_string);
+  }
+}
+
+TEST(CAPITests, TopKTopPCAPI) {
+  float top_p = 0.6f;
+  int top_k = 50;
+  float temp = 0.6f;
+
+  OgaModel* model;
+  CheckResult(OgaCreateModel(MODEL_PATH "phi-2", &model));
+  OgaModelPtr model_ptr{model};
+
+  OgaTokenizer* tokenizer;
+  CheckResult(OgaCreateTokenizer(model, &tokenizer));
+  OgaTokenizerPtr tokenizer_ptr{tokenizer};
+
+  OgaSequences* input_sequences;
+  CheckResult(OgaCreateSequences(&input_sequences));
+  OgaSequencesPtr sequences_ptr{input_sequences};
+
+  const char* input_strings[] = {
+      "This is a test.",
+      "Rats are awesome pets!",
+      "The quick brown fox jumps over the lazy dog.",
+  };
+
+  for (auto& string : input_strings)
+    CheckResult(OgaTokenizerEncode(tokenizer, string, input_sequences));
+
+  OgaGeneratorParams* params;
+  CheckResult(OgaCreateGeneratorParams(model, &params));
+  OgaGeneratorParamsPtr params_ptr{params};
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40));
+  CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_k", top_k));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_p", top_p));
+  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", temp));
+  CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));
+  OgaSequences* output_sequences;
+  CheckResult(OgaGenerate(model, params, &output_sequences));
+  OgaSequencesPtr output_sequences_ptr{output_sequences};
+
+  // Decode the batch
+  for (size_t i = 0; i < OgaSequencesCount(output_sequences); i++) {
+    std::span<const int32_t> sequence{OgaSequencesGetSequenceData(output_sequences, i), OgaSequencesGetSequenceCount(output_sequences, i)};
+
+    const char* out_string;
+    CheckResult(OgaTokenizerDecode(tokenizer, sequence.data(), sequence.size(), &out_string));
+    std::cout << "Decoded string:" << out_string << std::endl;
+    OgaDestroyString(out_string);
+  }
+}
+
+#endif  // TEST_PHI2
diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
index cd3ba82d1..156f943b4 100644
--- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs
+++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -20,6 +21,20 @@ public OnnxRuntimeGenAITests(ITestOutputHelper o)
             this.output = o;
         }

+        private class IgnoreOnModelAbsenceFact : FactAttribute
+        {
+            public IgnoreOnModelAbsenceFact()
+            {
+                string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "cpu", "phi-2");
+                if (!Directory.Exists(modelPath))
+                {
+                    // Skip this test on machines where the model cannot be downloaded at runtime.
+                    Skip = "Skipping this test since the model does not exist.";
+                }
+            }
+        }
+
         [Fact(DisplayName = "TestGreedySearch")]
         public void TestGreedySearch()
         {
@@ -49,7 +64,7 @@ public void TestGreedySearch()
                 while (!generator.IsDone())
                 {
                     generator.ComputeLogits();
-                    generator.GenerateNextTokenTop();
+                    generator.GenerateNextToken();
                 }

                 for (ulong i = 0; i < batchSize; i++)
@@ -72,10 +87,138 @@ public void TestGreedySearch()
             }
         }

-        [Fact(DisplayName = "TestTokenizerBatchEncodeDecode")]
+        [IgnoreOnModelAbsenceFact(DisplayName = "TestTopKSearch")]
+        public void TestTopKSearch()
+        {
+            int topK = 100;
+            float temp = 0.6f;
+            ulong maxLength = 20;
+
+            string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "cpu", "phi-2");
+            using (var model = new Model(modelPath))
+            {
+                Assert.NotNull(model);
+                using (var tokenizer = new Tokenizer(model))
+                {
+                    Assert.NotNull(tokenizer);
+
+                    var strings = new string[] {
+                        "This is a test.",
+                        "Rats are awesome pets!",
+                        "The quick brown fox jumps over the lazy dog."
+                    };
+
+                    var sequences = tokenizer.EncodeBatch(strings);
+                    Assert.NotNull(sequences);
+                    Assert.Equal((ulong)strings.Length, sequences.NumSequences);
+
+                    using GeneratorParams generatorParams = new GeneratorParams(model);
+                    Assert.NotNull(generatorParams);
+
+                    generatorParams.SetInputSequences(sequences);
+                    generatorParams.SetSearchOption("max_length", maxLength);
+                    generatorParams.SetSearchOption("do_sample", true);
+                    generatorParams.SetSearchOption("top_k", topK);
+                    generatorParams.SetSearchOption("temperature", temp);
+                    var outputSequences = model.Generate(generatorParams);
+                    Assert.NotNull(outputSequences);
+
+                    var outputStrings = tokenizer.DecodeBatch(outputSequences);
+                    Assert.NotNull(outputStrings);
+                }
+            }
+        }
+
+        [IgnoreOnModelAbsenceFact(DisplayName = "TestTopPSearch")]
+        public void TestTopPSearch()
+        {
+            float topP = 0.6f;
+            float temp = 0.6f;
+            ulong maxLength = 20;
+
+            string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "cpu", "phi-2");
+            using (var model = new Model(modelPath))
+            {
+                Assert.NotNull(model);
+                using (var tokenizer = new Tokenizer(model))
+                {
+                    Assert.NotNull(tokenizer);
+
+                    var strings = new string[] {
+                        "This is a test.",
+                        "Rats are awesome pets!",
+                        "The quick brown fox jumps over the lazy dog."
+                    };
+
+                    var sequences = tokenizer.EncodeBatch(strings);
+                    Assert.NotNull(sequences);
+                    Assert.Equal((ulong)strings.Length, sequences.NumSequences);
+
+                    using GeneratorParams generatorParams = new GeneratorParams(model);
+                    Assert.NotNull(generatorParams);
+
+                    generatorParams.SetInputSequences(sequences);
+                    generatorParams.SetSearchOption("max_length", maxLength);
+                    generatorParams.SetSearchOption("do_sample", true);
+                    generatorParams.SetSearchOption("top_p", topP);
+                    generatorParams.SetSearchOption("temperature", temp);
+                    var outputSequences = model.Generate(generatorParams);
+                    Assert.NotNull(outputSequences);
+
+                    var outputStrings = tokenizer.DecodeBatch(outputSequences);
+                    Assert.NotNull(outputStrings);
+                }
+            }
+        }
+
+        [IgnoreOnModelAbsenceFact(DisplayName = "TestTopKTopPSearch")]
+        public void TestTopKTopPSearch()
+        {
+            int topK = 100;
+            float topP = 0.6f;
+            float temp = 0.6f;
+            ulong maxLength = 20;
+
+            string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "cpu", "phi-2");
+            using (var model = new Model(modelPath))
+            {
+                Assert.NotNull(model);
+                using (var tokenizer = new Tokenizer(model))
+                {
+                    Assert.NotNull(tokenizer);
+
+                    var strings = new string[] {
+                        "This is a test.",
+                        "Rats are awesome pets!",
+                        "The quick brown fox jumps over the lazy dog."
+                    };
+
+                    var sequences = tokenizer.EncodeBatch(strings);
+                    Assert.NotNull(sequences);
+                    Assert.Equal((ulong)strings.Length, sequences.NumSequences);
+
+                    using GeneratorParams generatorParams = new GeneratorParams(model);
+                    Assert.NotNull(generatorParams);
+
+                    generatorParams.SetInputSequences(sequences);
+                    generatorParams.SetSearchOption("max_length", maxLength);
+                    generatorParams.SetSearchOption("do_sample", true);
+                    generatorParams.SetSearchOption("top_k", topK);
+                    generatorParams.SetSearchOption("top_p", topP);
+                    generatorParams.SetSearchOption("temperature", temp);
+                    var outputSequences = model.Generate(generatorParams);
+                    Assert.NotNull(outputSequences);
+
+                    var outputStrings = tokenizer.DecodeBatch(outputSequences);
+                    Assert.NotNull(outputStrings);
+                }
+            }
+        }
+
+        [IgnoreOnModelAbsenceFact(DisplayName = "TestTokenizerBatchEncodeDecode")]
         public void TestTokenizerBatchEncodeDecode()
         {
-            string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "hf-internal-testing", "tiny-random-gpt2-fp32");
+            string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "cpu", "phi-2");
             using (var model = new Model(modelPath))
             {
                 Assert.NotNull(model);
@@ -101,10 +244,10 @@ public void TestTokenizerBatchEncodeDecode()
             }
         }

-        [Fact(DisplayName = "TestTokenizerBatchEncodeSingleDecode")]
+        [IgnoreOnModelAbsenceFact(DisplayName = "TestTokenizerBatchEncodeSingleDecode")]
         public void TestTokenizerBatchEncodeSingleDecode()
         {
-            string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "hf-internal-testing", "tiny-random-gpt2-fp32");
+            string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "cpu", "phi-2");
             using (var model = new Model(modelPath))
             {
                 Assert.NotNull(model);
@@ -132,10 +275,10 @@ public void TestTokenizerBatchEncodeSingleDecode()
             }
         }

-        [Fact(DisplayName = "TestTokenizerBatchEncodeStreamDecode")]
+        [IgnoreOnModelAbsenceFact(DisplayName = "TestTokenizerBatchEncodeStreamDecode")]
         public void TestTokenizerBatchEncodeStreamDecode()
         {
-            string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "hf-internal-testing", "tiny-random-gpt2-fp32");
+            string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "cpu", "phi-2");
"cpu", "phi-2"); using (var model = new Model(modelPath)) { Assert.NotNull(model); @@ -168,10 +311,10 @@ public void TestTokenizerBatchEncodeStreamDecode() } } - [Fact(DisplayName = "TestTokenizerSingleEncodeDecode")] + [IgnoreOnModelAbsebceFact(DisplayName = "TestTokenizerSingleEncodeDecode")] public void TestTokenizerSingleEncodeDecode() { - string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "hf-internal-testing", "tiny-random-gpt2-fp32"); + string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "cpu", "phi-2"); using (var model = new Model(modelPath)) { Assert.NotNull(model); @@ -192,10 +335,10 @@ public void TestTokenizerSingleEncodeDecode() } } - [Fact(Skip = "Phi-2 is not available in the CI pipeline")] + [IgnoreOnModelAbsebceFact(DisplayName = "TestPhi2")] public void TestPhi2() { - string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "phi-2"); + string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "test_models", "cpu", "phi-2"); using (var model = new Model(modelPath)) { Assert.NotNull(model); diff --git a/test/model_tests.cpp b/test/model_tests.cpp index a2b3a7832..79c1d2c64 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -44,7 +44,7 @@ TEST(ModelTests, GreedySearchGptFp32) { while (!generator->IsDone()) { generator->ComputeLogits(); - generator->GenerateNextToken_Top(); + generator->GenerateNextToken(); } // Verify outputs match expected outputs @@ -128,7 +128,7 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) while (!generator->IsDone()) { generator->ComputeLogits(); - generator->GenerateNextToken_Top(); + generator->GenerateNextToken(); } // Verify outputs match expected outputs @@ -226,7 +226,7 @@ Print all primes between 1 and n auto generator = Generators::CreateGenerator(*model, *params); while (!generator->IsDone()) { generator->ComputeLogits(); - generator->GenerateNextToken_Top(); + generator->GenerateNextToken(); } auto result = generator->GetSequence(0); diff --git a/test/sampling_tests.cpp b/test/sampling_tests.cpp index 531270f78..239c71ab6 100644 --- a/test/sampling_tests.cpp +++ b/test/sampling_tests.cpp @@ -27,6 +27,8 @@ TEST(SamplingTests, BatchedSamplingTopPCpu) { int batch_size = 4; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample=true; + params->search.top_p=0.25f; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -37,7 +39,7 @@ TEST(SamplingTests, BatchedSamplingTopPCpu) { generator->search_->SetLogits(logits_span); generator->computed_logits_ = true; // Verify outputs match expected outputs - generator->GenerateNextToken_TopP(0.25f, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); EXPECT_TRUE(0 == std::memcmp(output_span.data(), next_tokens.data(), expected_output.size() * sizeof(int32_t))); } @@ -53,6 +55,8 @@ TEST(SamplingTests, BatchedSamplingTopKCpu) { int batch_size = 4; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = 2; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -64,8 +68,7 @@ TEST(SamplingTests, BatchedSamplingTopKCpu) { generator->computed_logits_ = true; // Verify outputs match expected outputs - int k = 2; - generator->GenerateNextToken_TopK(k, 1.0); + generator->GenerateNextToken(); auto 
next_tokens = generator->search_->GetNextTokens().GetCPU(); for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; @@ -85,6 +88,9 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCpu) { int batch_size = 4; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = 2; + params->search.top_p = 0.25f; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -95,9 +101,7 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCpu) { generator->search_->SetLogits(Generators::cpu_span(logits_copy)); generator->computed_logits_ = true; // Verify outputs match expected outputs - float p = 0.25f; - int k = 2; - generator->GenerateNextToken_TopK_TopP(k, p, 1.0); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; @@ -130,6 +134,8 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) { std::vector input_ids{0, 1, 2, 3, 4}; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_p = 0.95f; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -147,7 +153,7 @@ TEST(SamplingTests, RandomizedSamplingTopPCpu) { auto logits_copy = logits_cpu; generator->search_->SetLogits(Generators::cpu_span(logits_copy)); generator->computed_logits_ = true; - generator->GenerateNextToken_TopP(0.95f, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { @@ -166,6 +172,8 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) { std::vector input_ids{0, 1, 2, 3, 4}; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = k; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -183,7 +191,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCpu) { auto logits_copy=logits_cpu; generator->search_->SetLogits(Generators::cpu_span(logits_copy)); generator->computed_logits_ = true; - generator->GenerateNextToken_TopK(k, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { @@ -203,6 +211,9 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) { std::vector input_ids{0, 1, 2, 3, 4}; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = k; + params->search.top_p = p; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -220,7 +231,7 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCpu) { auto logits_copy = logits_cpu; generator->search_->SetLogits(Generators::cpu_span(logits_copy)); generator->computed_logits_ = true; - generator->GenerateNextToken_TopK_TopP(k, p, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); // Verify outputs match expected outputs for (int b = 0; b < batch_size; b++) { @@ -248,6 +259,8 @@ TEST(SamplingTests, BatchedSamplingTopPCuda) { int batch_size = 4; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + 
params->search.top_p = 0.25f; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -259,7 +272,7 @@ TEST(SamplingTests, BatchedSamplingTopPCuda) { generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), logits_cpu.size())); generator->computed_logits_ = true; // Verify outputs match expected outputs - generator->GenerateNextToken_TopP(0.25f, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); EXPECT_TRUE(0 == std::memcmp(output_span.data(), next_tokens.data(), expected_output.size() * sizeof(int32_t))); } @@ -276,6 +289,8 @@ TEST(SamplingTests, BatchedSamplingTopKCuda) { int batch_size = 4; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = 2; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -287,8 +302,7 @@ TEST(SamplingTests, BatchedSamplingTopKCuda) { generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), logits_cpu.size())); generator->computed_logits_ = true; // Verify outputs match expected outputs - int k = 2; - generator->GenerateNextToken_TopK(k, 1.0); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; @@ -309,6 +323,9 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCuda) { int batch_size = 4; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = 2; + params->search.top_p = 0.25f; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -320,9 +337,7 @@ TEST(SamplingTests, BatchedSamplingTopPAndKCuda) { generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), logits_cpu.size())); generator->computed_logits_ = true; // Verify outputs match expected outputs - float p = 0.25f; - int k = 2; - generator->GenerateNextToken_TopK_TopP(k, p, 1.0); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); for (int b = 0; b < batch_size; b++) { auto next_token = next_tokens[b]; @@ -338,6 +353,8 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) { std::vector input_ids{0, 1, 2, 3, 4}; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_p = 0.95f; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -358,7 +375,7 @@ TEST(SamplingTests, RandomizedSamplingTopPCuda) { cudaMemcpyAsync(cpu_logits, logits_gpu.get(), vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost, params->cuda_stream); generator->search_->SetLogits(Generators::gpu_span(logits_gpu.get(), vocab_size * batch_size)); generator->computed_logits_ = true; - generator->GenerateNextToken_TopP(0.95f, 1.0f); + generator->GenerateNextToken(); auto next_tokens = generator->search_->GetNextTokens().GetCPU(); cudaStreamSynchronize(params->cuda_stream); // Verify outputs match expected outputs @@ -378,6 +395,8 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) { std::vector input_ids{0, 1, 2, 3, 4}; auto params = Generators::CreateGeneratorParams(); params->search.max_length = 10; + params->search.do_sample = true; + params->search.top_k = k; params->batch_size = batch_size; params->sequence_length = 1; params->vocab_size = vocab_size; @@ -398,7 
@@ -398,7 +417,7 @@ TEST(SamplingTests, RandomizedSamplingTopKCuda) {
   auto generator = Generators::CreateGenerator(*model, *params);
   generator->search_->SetLogits(Generators::gpu_span<float>(logits_gpu.get(), vocab_size * batch_size));
   generator->computed_logits_ = true;
-  generator->GenerateNextToken_TopK(k, 1.0f);
+  generator->GenerateNextToken();
   auto next_tokens = generator->search_->GetNextTokens().GetCPU();
   cudaStreamSynchronize(params->cuda_stream);
   // Verify outputs match expected outputs
@@ -419,6 +438,9 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) {
   std::vector<int32_t> input_ids{0, 1, 2, 3, 4};
   auto params = Generators::CreateGeneratorParams();
   params->search.max_length = 10;
+  params->search.do_sample = true;
+  params->search.top_k = k;
+  params->search.top_p = p;
   params->batch_size = batch_size;
   params->sequence_length = 1;
   params->vocab_size = vocab_size;
@@ -439,7 +461,7 @@ TEST(SamplingTests, RandomizedSamplingTopPAndKCuda) {
   cudaMemcpyAsync(cpu_logits, logits_gpu.get(), vocab_size * batch_size * sizeof(float), cudaMemcpyDeviceToHost, params->cuda_stream);
   generator->search_->SetLogits(Generators::gpu_span<float>(logits_gpu.get(), vocab_size * batch_size));
   generator->computed_logits_ = true;
-  generator->GenerateNextToken_TopK_TopP(k, p, 1.0f);
+  generator->GenerateNextToken();
   auto next_tokens = generator->search_->GetNextTokens().GetCPU();
   cudaStreamSynchronize(params->cuda_stream);
   // Verify outputs match expected outputs
@@ -478,7 +500,7 @@ TEST(SamplingTests, RandomizedSamplingSelectTopCuda) {
   auto generator = Generators::CreateGenerator(*model, *params);
   generator->search_->SetLogits(Generators::gpu_span<float>(logits_gpu.get(), vocab_size * batch_size));
   generator->computed_logits_ = true;
-  generator->GenerateNextToken_Top();
+  generator->GenerateNextToken();
   auto next_tokens = generator->search_->GetNextTokens().GetCPU();
   cudaStreamSynchronize(params->cuda_stream);
   // Verify outputs match expected outputs
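The same consolidation applies at the C-API surface: OgaGenerator_GenerateNextToken_Top/_TopK/_TopP are replaced by a single OgaGenerator_GenerateNextToken, with the strategy chosen via OgaGeneratorParamsSetSearchBool/SetSearchNumber. A hedged sketch of the resulting token-by-token loop, mirroring the updated tests; CheckResult is the tests' helper, and the create/destroy calls marked below are assumed from the existing C API rather than shown in this diff:

  // Sketch: C-API decode loop after this change, given an OgaModel* `model`
  // and OgaSequences* `input_sequences` prepared as in the tests above.
  OgaGeneratorParams* params;
  CheckResult(OgaCreateGeneratorParams(model, &params));
  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "max_length", 40));
  CheckResult(OgaGeneratorParamsSetSearchBool(params, "do_sample", true));
  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_k", 50));
  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "top_p", 0.6));
  CheckResult(OgaGeneratorParamsSetSearchNumber(params, "temperature", 0.6));
  CheckResult(OgaGeneratorParamsSetInputSequences(params, input_sequences));

  OgaGenerator* generator;
  CheckResult(OgaCreateGenerator(model, params, &generator));  // assumed API, not in this diff
  while (!OgaGenerator_IsDone(generator)) {
    CheckResult(OgaGenerator_ComputeLogits(generator));
    CheckResult(OgaGenerator_GenerateNextToken(generator));  // replaces *_Top/_TopK/_TopP
  }
  OgaDestroyGenerator(generator);       // assumed API, not in this diff
  OgaDestroyGeneratorParams(params);    // assumed API, not in this diff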