From e41fb2c1430693157c3cea70d3bdcc207ae84ef3 Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Thu, 11 Jul 2024 09:11:06 -0700 Subject: [PATCH] Fix Beam Search for CPU and CUDA... Also include it in our API (#669) ### Summary Beam Search was hanging and not outputting correct results. Furthermore, it did not fit into our API design. This PR addresses the correctness and API problems with Beam Search. We plan to improve both CPU and CUDA performance and memory efficiency in the near future. ### Issues Addressed Here is a quick summary of the issues addressed by this PR; these apply to both the CPU and CUDA implementations: - No log-softmax normalization was performed before adding beam scores. This caused faulty outputs which did not match the ORT implementation. - The `is_done` flag was not set or checked properly when an EOS token was generated or `max_sequence_length` was reached. This caused hanging, infinite looping, and memory buffer overflows. This sometimes gave the impression of bad performance, while in reality it was a correctness issue. - `Finalize` was not called automatically. If a user didn't call it manually, this could cause a floating point exception or other fault. - There was no easy way to get output from Beam Search. `Finalize` was clunky and unintuitive as it didn't fit with our API. - Our testing file was not up to date with our latest APIs. ### API Changes Given the issues with `Finalize`, this PR introduces an update to the way Beam Search fits into our API. The user no longer has to manually call `Finalize` in order to access the Beam Search results. These are returned automatically by the `Generate()` function and can be accessed using batch beam indexing. 
--------- Co-authored-by: Baiju Meswani --- src/beam_search_scorer.cpp | 59 ++++-------------------- src/beam_search_scorer.h | 20 ++++---- src/beam_search_scorer_cuda.cpp | 17 +++++-- src/beam_search_scorer_cuda.cu | 81 ++++++++++++++------------------- src/beam_search_scorer_cuda.cuh | 20 ++++---- src/beam_search_scorer_cuda.h | 5 +- src/cuda_sampling.cu | 3 +- src/cuda_sampling.cuh | 3 ++ src/generators.cpp | 3 +- src/search.cpp | 67 +++++++++++++++++---------- src/search.h | 17 ++++--- src/search_cuda.cpp | 42 +++++++++++++---- src/search_cuda.h | 8 +++- src/smartptrs.h | 6 +++ src/softmax.h | 27 ++++++++++- test/model_tests.cpp | 39 ++++------------ 16 files changed, 213 insertions(+), 204 deletions(-) diff --git a/src/beam_search_scorer.cpp b/src/beam_search_scorer.cpp index a9056856f..ec9760377 100644 --- a/src/beam_search_scorer.cpp +++ b/src/beam_search_scorer.cpp @@ -16,7 +16,7 @@ void BeamHypotheses::Init(float length_penalty, std::span beams done_ = false; } -void BeamHypotheses::Add(std::span hypothesis, float sum_logprobs) { +void BeamHypotheses::Add(cpu_span hypothesis, float sum_logprobs) { auto length = hypothesis.size(); float const score = sum_logprobs / std::pow(static_cast(length), length_penalty_); @@ -43,28 +43,6 @@ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) con return beams_.back().score < current_score; } -void BeamHypotheses::Output( - size_t top_k, - size_t max_length, - std::span sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) - std::span sequences_scores) const // buffer of shape (num_return_sequences) or empty -{ - // Copy the top_k beams into the sequences - assert(top_k <= beams_used_); - for (int index = 0; index < top_k; index++) { - auto& item = beams_[index]; - std::span const target = sequences.subspan(index * max_length, max_length); - - // Note that word_ids might be less than max_length. 
- // Since the sequences has been filled with pad token ID, so padding is not needed here. - copy(item.hypothesis, target); - - if (!sequences_scores.empty()) { - sequences_scores[index] = item.score; - } - } -} - BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters) : batch_size_{parameters.batch_size}, num_beams_{parameters.search.num_beams}, @@ -110,7 +88,7 @@ void BeamSearchScorer::Process(Sequences& sequences, // It contains word ID of whole sequence generated so far. // It is different from subgraph input_ids, which only need one word when past state is not empty. - const int sequence_length = sequences.GetSequenceLength(); + size_t sequence_length = static_cast(sequences.GetSequenceLength()); assert(next_scores.size() == next_tokens.size()); assert(next_scores.size() == next_indices.size()); @@ -146,11 +124,12 @@ void BeamSearchScorer::Process(Sequences& sequences, } // Clone the sequence and append to buffer. - std::span const src = sequences.GetSequence(batch_beam_idx); - auto clone = hypothesis_buffer_.subspan(static_cast(hypothesis_buffer_used_), sequence_length); + cpu_span const src{sequences.GetSequence(batch_beam_idx)}; + auto clone_span = hypothesis_buffer_.subspan(static_cast(hypothesis_buffer_used_), sequence_length); + cpu_span clone{clone_span.data(), sequence_length}; copy(src, clone); - hypothesis_buffer_used_ += sequence_length; + hypothesis_buffer_used_ += static_cast(sequence_length); beam_hyp.Add(clone, next_score); } else { // Add next predicted token since it is not eos_token. 
@@ -177,7 +156,7 @@ void BeamSearchScorer::Process(Sequences& sequences, if (!early_stopping_) { std::span const topk_scores = next_scores.subspan(batch * num_beams_, top_k); const auto best_sum_logprobs = std::max_element(topk_scores.begin(), topk_scores.end()); - if (beam_hyp.CanImprove(*best_sum_logprobs, sequence_length)) { + if (beam_hyp.CanImprove(*best_sum_logprobs, static_cast(sequence_length))) { continue; } } @@ -188,12 +167,7 @@ void BeamSearchScorer::Process(Sequences& sequences, } void BeamSearchScorer::Finalize(Sequences& sequences, - size_t num_return_sequences, - cpu_span output, - cpu_span sequence_scores) { - // output is Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length). - // sequence_scores is the optional Score of each sequence, with shape (batch_size * num_return_sequences). - + size_t num_return_sequences) { // Finalize all open beam hypotheses and add to generated hypotheses. for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) { BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; @@ -208,23 +182,6 @@ void BeamSearchScorer::Finalize(Sequences& sequences, beam_hyp.Add(final_tokens, final_score); } } - - // Fill output sequences with pad token ID so that we do not need append it later. - std::fill_n(output.data(), output.size(), pad_token_id_); - - // Select the best hypotheses according to number of sequences to return. 
- for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) { - BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; - - auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_, - num_return_sequences * max_length_); - std::span sequence_scores_buffer; - if (!sequence_scores.empty()) { - sequence_scores_buffer = sequence_scores.subspan(batch_index * num_return_sequences, num_return_sequences); - } - - beam_hyp.Output(num_return_sequences, max_length_, batch_output, sequence_scores_buffer); - } } } // namespace Generators diff --git a/src/beam_search_scorer.h b/src/beam_search_scorer.h index 22652437d..b01fed6d0 100644 --- a/src/beam_search_scorer.h +++ b/src/beam_search_scorer.h @@ -1,11 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - +#include "sequences.h" +#pragma once // The implementation is based on huggingface transformers generation_beam_search.py namespace Generators { struct HypothesisScore { - std::span hypothesis; + cpu_span hypothesis; float score; }; @@ -14,16 +15,14 @@ struct BeamHypotheses { void Init(float length_penalty, std::span beams); // Add a new hypothesis - void Add(std::span hypothesis, float sum_logprobs); + void Add(cpu_span hypothesis, float sum_logprobs); // Return true if this beats the worst score in the hypothesis bool CanImprove(float best_sum_logprobs, int current_length) const; - // Output results - void Output(size_t top_k, // number of sequences to return - size_t max_length, // max sequence length - std::span sequences, // buffer with pad token, shape (num_return_sequences, max_length) - std::span sequences_scores) const; // buffer for sequence scores, with shape (num_return_sequences) + RoamingArray GetHypothesis(size_t index) const { return beams_[index].hypothesis; } + + // TODO(aciddelgado): Methods to get all hypotheses and scores std::span beams_; // Beam width sized array of hypotheses, sorted by highest scoring int 
beams_used_; // Number of elements used in beams_ @@ -40,15 +39,14 @@ struct BeamSearchScorer { std::span next_indices); void Finalize(Sequences& sequences, - size_t num_return_sequences, - cpu_span output_sequences, - cpu_span output_sequence_scores); + size_t num_return_sequences); bool IsDone() const { return not_done_count_ == 0; } cpu_span GetNextScores() { return next_beam_scores_; } cpu_span GetNextTokens() { return next_beam_tokens_; } cpu_span GetNextIndicesCPU() { return next_beam_indices_; } + BeamHypotheses GetBeamHypotheses(size_t batch_id) { return beam_hyps_[batch_id]; } private: int batch_size_; diff --git a/src/beam_search_scorer_cuda.cpp b/src/beam_search_scorer_cuda.cpp index f43eac1d0..4c48ed82a 100644 --- a/src/beam_search_scorer_cuda.cpp +++ b/src/beam_search_scorer_cuda.cpp @@ -76,11 +76,18 @@ bool BeamSearchScorer_Cuda::IsDoneLater() const { } void BeamSearchScorer_Cuda::Finalize(Sequences_Cuda& sequences, - size_t num_return_sequences, - std::span output, // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length) - std::span sequence_scores) { // Score of each sequence, with shape (batch_size * num_return_sequences). 
- assert(!output.empty()); - cuda::LaunchBeamSearchScorer_Finalize(state_cpu_->batch_size_, *state_gpu_, sequences.GetSequences(), sequences.GetSequenceLength(), beam_hyps_, next_beam_scores_, output, sequence_scores, stream_); + size_t num_return_sequences) { + cuda::LaunchBeamSearchScorer_Finalize(state_cpu_->batch_size_, *state_gpu_, sequences.GetSequences(), sequences.GetSequenceLength(), beam_hyps_, next_beam_scores_, stream_); +} + +RoamingArray BeamSearchScorer_Cuda::GetBeamHypothesis(size_t batch_id, size_t beam_id) const { + cuda_host_unique_ptr hypothesis_ptr = CudaMallocHostArray(1); + cuda_host_unique_ptr hypothesis_length = CudaMallocHostArray(1); + cuda_host_unique_ptr hypothesis_score = CudaMallocHostArray(1); + cuda::LaunchBeamSearchScorer_GetHypothesisPtr(batch_id, beam_id, beam_hyps_, hypothesis_ptr.get(), hypothesis_length.get(), hypothesis_score.get(), stream_); + CudaCheck() == cudaStreamSynchronize(stream_); + std::span hypothesis_span(*hypothesis_ptr.get(), *hypothesis_length.get()); + return gpu_span{hypothesis_span.data(), hypothesis_span.size()}; } } // namespace Generators diff --git a/src/beam_search_scorer_cuda.cu b/src/beam_search_scorer_cuda.cu index bd491bbec..5673c66a2 100644 --- a/src/beam_search_scorer_cuda.cu +++ b/src/beam_search_scorer_cuda.cu @@ -71,30 +71,6 @@ __device__ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_ return beams_[beams_count_ - 1].score < current_score; } -__device__ void BeamHypotheses::Output( - int top_k, - int max_length, - int pad_token_id, - int32_t* sequences, // buffer of shape (num_return_sequences, max_length) - float* sequences_scores) // buffer of shape (num_return_sequences) or empty -{ - // Copy the top_k beams into the sequences - for (int index = 0; index < top_k; index++) { - auto& item = beams_[index]; - int32_t* target = sequences + index * max_length; - - // Note that word_ids might be less than max_length. 
- for (int i = 0; i < item.hypothesis_length; i++) - target[i] = item.hypothesis[i]; - // Pad remaining values with pad token id - for (int i = item.hypothesis_length; i < max_length; i++) - target[i] = pad_token_id; - - if (sequences_scores) - sequences_scores[index] = item.score; - } -} - __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu, BeamScorerState& state, const int32_t* sequences_buffer, @@ -132,7 +108,7 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu, continue; } - // Clone the sequence and append to buffer. + // Clone the sequence and append to buffer. // TODO(aciddelgado): why do we need to clone the sequence here? const int32_t* src = sequences_buffer + batch_beam_idx * state.max_length_; auto clone = hypothesis_buffer_ + atomicAdd(&state.hypothesis_buffer_used_, sequence_length); @@ -156,7 +132,7 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu, if (beam_hyp.beams_used_ == state.num_beams_) { if (state.early_stopping_ || !beam_hyp.CanImprove(*std::max_element(next_scores + batch_start, next_scores + batch_start + top_k), sequence_length)) { beam_hyp.done_ = true; - if (atomicAdd(&state.not_done_count_, -1) == 0) + if (atomicAdd(&state.not_done_count_, -1) == 1) state_cpu.not_done_count_ = 0; // Update the CPU side } } @@ -238,8 +214,7 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp block_size.x = batch_beam_size; block_size.y = sequence_length; } else { - if (sequence_length <= max_threads) // Sequence length fits into thread block, but batch_beam_size does not, so chunk it - { + if (sequence_length <= max_threads) { // Sequence length fits into thread block, but batch_beam_size does not, so chunk it block_size.x = max_threads / sequence_length; block_size.y = sequence_length; @@ -269,9 +244,7 @@ __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, const int32_t* sequences_buffer, int sequence_length, BeamHypotheses* beam_hyps_, - const 
float* final_beam_scores, - int32_t* output, - float* sequence_scores) { + const float* final_beam_scores) { int batch_index = blockIdx.x * blockDim.x + threadIdx.x; if (batch_index >= state.batch_size_) return; @@ -286,18 +259,6 @@ __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, beam_hyp.Add(final_tokens, sequence_length, final_score); } } - - int num_return_sequences = 1; - - // Select the best hypotheses according to number of sequences to return. - auto batch_output = output + batch_index * num_return_sequences * state.max_length_; - - beam_hyp.Output( - num_return_sequences, - state.max_length_, - state.pad_token_id_, - batch_output, - sequence_scores ? sequence_scores + batch_index * num_return_sequences : nullptr); } void LaunchBeamSearchScorer_Finalize(int batch_size, @@ -306,16 +267,40 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, int sequence_length, std::span beam_hyps, std::span final_beam_scores, - std::span output, - std::span sequence_scores, cudaStream_t stream) { BeamSearchScorer_Finalize<<<1, batch_size, 0, stream>>>(state, sequences.data(), sequence_length, beam_hyps.data(), - final_beam_scores.data(), - output.data(), - sequence_scores.data()); + final_beam_scores.data()); +} + +__global__ void BeamSearchScorer_GetHypothesisPtr(size_t batch_id, + size_t beam_id, + BeamHypotheses* beam_hyps_data, + int32_t** hypothesis_ptr, + int* hypothesis_length, + float* hypothesis_score) { + auto& beam_hyp = beam_hyps_data[batch_id]; + auto& item = beam_hyp.beams_[beam_id]; + hypothesis_ptr[0] = const_cast(item.hypothesis); + hypothesis_length[0] = item.hypothesis_length; + hypothesis_score[0] = item.score; +} + +void LaunchBeamSearchScorer_GetHypothesisPtr(size_t batch_id, + size_t beam_id, + gpu_span beam_hyps, + int32_t** hypothesis_ptr, + int* hypothesis_length, + float* hypothesis_score, + cudaStream_t stream) { + BeamSearchScorer_GetHypothesisPtr<<<1, 1, 0, stream>>>(batch_id, + beam_id, + beam_hyps.data(), + hypothesis_ptr, 
+ hypothesis_length, + hypothesis_score); } __global__ void InitScoresKernel(float* beam_scores, diff --git a/src/beam_search_scorer_cuda.cuh b/src/beam_search_scorer_cuda.cuh index 66812b159..7bd864d80 100644 --- a/src/beam_search_scorer_cuda.cuh +++ b/src/beam_search_scorer_cuda.cuh @@ -1,3 +1,5 @@ +#include "smartptrs.h" + namespace Generators { namespace cuda { @@ -19,13 +21,6 @@ struct BeamHypotheses { // Return true if this beats the worst score in the hypothesis __device__ bool CanImprove(float best_sum_logprobs, int current_length) const; - - // Output results - __device__ void Output(int top_k, // number of sequences to return - int max_length, // max sequence length - int pad_token_id, // pad token - int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length) - float* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) }; struct BeamScorerState { @@ -71,10 +66,17 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, int sequence_length, std::span beam_hyps_, std::span final_beam_scores, - std::span output, - std::span sequence_scores, cudaStream_t stream); +// Since we need to index through a couple layers of GPU memory, we need to provide a way to get the pointers +void LaunchBeamSearchScorer_GetHypothesisPtr(size_t batch_id, + size_t beam_id, + gpu_span beam_hyps, + int32_t** hypothesis_ptr, + int* hypothesis_length, + float* hypothesis_score, + cudaStream_t stream); + void LaunchInitScoresKernel(float* beam_scores, int batch_size, int num_beams, diff --git a/src/beam_search_scorer_cuda.h b/src/beam_search_scorer_cuda.h index 06cb69d02..477e9b676 100644 --- a/src/beam_search_scorer_cuda.h +++ b/src/beam_search_scorer_cuda.h @@ -9,9 +9,7 @@ struct BeamSearchScorer_Cuda { std::span next_indices); void Finalize(Sequences_Cuda& sequences, - size_t num_return_sequences, - std::span output_sequences, - std::span output_sequence_scores); + size_t num_return_sequences); bool IsDone() const { return 
false; } // For CUDA we speculatively run the next step while we wait for the GPU to report status. We use 'IsDoneLater()' for this bool IsDoneLater() const; @@ -24,6 +22,7 @@ struct BeamSearchScorer_Cuda { return next_beam_indices_cpu_; } gpu_span GetNextIndicesGPU() { return next_beam_indices_; } + RoamingArray GetBeamHypothesis(size_t batch_id, size_t beam_id) const; private: mutable cuda_event_holder event_process_complete_; diff --git a/src/cuda_sampling.cu b/src/cuda_sampling.cu index bef166d9f..1df3e70ee 100644 --- a/src/cuda_sampling.cu +++ b/src/cuda_sampling.cu @@ -299,7 +299,7 @@ __global__ void SoftmaxBlockForward(outscalar_t* output, scalar_t* input, int cl template void DispatchBlockwiseSoftmaxForward(cudaStream_t* stream, float* output, const float* input, int softmax_elements, - int input_stride, int output_stride, int batch_count, float temperature=1.0) { + int input_stride, int output_stride, int batch_count, float temperature) { dim3 grid(batch_count); constexpr int ILP = sizeof(float4) / sizeof(float); dim3 block = SoftmaxGetBlockSize(ILP, softmax_elements); @@ -313,6 +313,7 @@ void DispatchBlockwiseSoftmaxForward(cudaStream_t* stream, float* output, const softmax_elements, input_stride, output_stride, temperature); } } +template void DispatchBlockwiseSoftmaxForward(cudaStream_t*, float*, const float*, int, int, int, int, float); // Populate Kernels and Launchers diff --git a/src/cuda_sampling.cuh b/src/cuda_sampling.cuh index cc8ab9867..6a206d5dd 100644 --- a/src/cuda_sampling.cuh +++ b/src/cuda_sampling.cuh @@ -23,5 +23,8 @@ struct SamplingData { void LaunchPopulateIndices(int* indices, int size, int batch_size, cudaStream_t stream); void GetSample(SamplingData* data, cudaStream_t stream, int32_t* d_next_token, float* d_scores, int vocab_size, int batch_size, int k, float p, float temperature); +template +void DispatchBlockwiseSoftmaxForward(cudaStream_t* stream, float* output, const float* input, int softmax_elements, int input_stride, int 
output_stride, int batch_count, float temperature=1.0); + } // namespace cuda } // namespace Generators \ No newline at end of file diff --git a/src/generators.cpp b/src/generators.cpp index dfc457316..fe005111a 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -205,8 +205,7 @@ TokenSequences Generate(const Model& model, const GeneratorParams& params) { } TokenSequences result; - - for (int i = 0; i < params.batch_size; i++) { + for (int i = 0; i < params.batch_size * params.search.num_return_sequences; i++) { auto sequence = generator->search_->GetSequence(i); auto sequence_cpu = sequence.GetCPU(); diff --git a/src/search.cpp b/src/search.cpp index 0a2c2e488..d7a9d3c69 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -62,10 +62,17 @@ int Search_Cpu::GetSequenceLength() const { } void BeamSearch_Cpu::SelectTop() { + // Normalize next token scores + for (int i = 0; i < params_->BatchBeamSize(); i++) { + std::span const scores = next_token_scores_.subspan(static_cast(i) * static_cast(params_->vocab_size), params_->vocab_size); + LogSoftMax(scores, 1.0); + } + auto beam_scores = beam_scorer_->GetNextScores(); + // Add beam score to next token scores. Corresponding python code is like: // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) - // TODO(tianleiwu): use thread pool to parallel + // TODO(aciddelgado): use thread pool to parallel int offset = 0; int batch_beam_index = 0; for (int i = 0; i < params_->batch_size; i++) { @@ -76,7 +83,6 @@ void BeamSearch_Cpu::SelectTop() { } } - // TODO: Write output scores? 
const size_t top_k = 2 * params_->search.num_beams; struct ScoreIndex { @@ -86,14 +92,15 @@ void BeamSearch_Cpu::SelectTop() { bool operator<(const ScoreIndex& s) const { return score < s.score; } }; - auto scores = std::make_unique(top_k * params_->batch_size); - auto indices = std::make_unique(top_k * params_->batch_size); - auto tokens = std::make_unique(top_k * params_->batch_size); + auto scores = std::make_unique(top_k * params_->batch_size); // Score of top_k tokens + auto indices = std::make_unique(top_k * params_->batch_size); // beam index of top_k tokens + auto tokens = std::make_unique(top_k * params_->batch_size); // token id of top_k tokens auto next_scores = std::span(scores.get(), top_k * params_->batch_size); auto next_indices = std::span(indices.get(), top_k * params_->batch_size); auto next_tokens = std::span(tokens.get(), top_k * params_->batch_size); + // TODO(aciddelgado): Optimize this top k with partial sort for (size_t batch_index = 0; batch_index < static_cast(params_->batch_size); batch_index++) { std::priority_queue> queue; auto token_scores_sub = next_token_scores_.subspan(batch_index * params_->search.num_beams * params_->vocab_size, static_cast(params_->search.num_beams) * params_->vocab_size); @@ -114,9 +121,9 @@ void BeamSearch_Cpu::SelectTop() { } #if 0 - DumpMemory("Next Scores", next_scores); - DumpMemory("Next Tokens", next_tokens); - DumpMemory("Next Indices", next_indices); + DumpSpan(std::cout, next_tokens); + DumpSpan(std::cout, next_indices_); + DumpSpan(std::cout, next_scores_); #endif beam_scorer_->Process(sequences_, next_scores, next_tokens, next_indices); @@ -140,19 +147,6 @@ void GreedySearch_Cpu::SelectTop() { AppendNextTokensToSequences(); } -void SoftMax(std::span scores, float temperature) { - float const max_score = *std::max_element(scores.begin(), scores.end()); - - // Subtract max score and scale by temperature - std::transform(scores.begin(), scores.end(), scores.begin(), [max_score, temperature](float score) 
{ return std::exp((score - max_score) / temperature); }); - - // Compute sum of exponentials - float const exp_sum = std::accumulate(scores.begin(), scores.end(), 0.0f); - - // Divide each score by the sum of exponentials - std::transform(scores.begin(), scores.end(), scores.begin(), [exp_sum](float score) { return score / exp_sum; }); -} - void GreedySearch_Cpu::SampleTopK(int k, float temperature) { for (size_t batch_id = 0; batch_id < params_->batch_size; batch_id++) { std::span const scores = next_token_scores_.subspan(batch_id * params_->vocab_size, params_->vocab_size); @@ -258,6 +252,15 @@ void GreedySearch_Cpu::AppendNextTokensToSequences() { } } +bool BeamSearch_Cpu::IsDone() const { + if (beam_scorer_->IsDone()) { + return true; + } else if (sequences_.GetSequenceLength() == params_->search.max_length) { + return true; + } + return false; +} + void BeamSearch_Cpu::AppendNextTokensToSequences() { sequences_.AppendNextTokenToSequences(beam_scorer_->GetNextIndicesCPU(), beam_scorer_->GetNextTokens()); @@ -268,8 +271,26 @@ void BeamSearch_Cpu::AppendNextTokensToSequences() { } } -void BeamSearch_Cpu::Finalize(size_t num_return_sequences, RoamingArray output, RoamingArray sequence_scores) { - beam_scorer_->Finalize(sequences_, num_return_sequences, output, sequence_scores); +void BeamSearch_Cpu::Finalize(size_t num_return_sequences) { + if (finalized_) + return; + beam_scorer_->Finalize(sequences_, num_return_sequences); + finalized_ = true; +} + +RoamingArray BeamSearch_Cpu::GetSequence(size_t index) { + size_t batch_id = index / params_->search.num_return_sequences; + size_t beam_id = index % params_->search.num_return_sequences; + Finalize(params_->search.num_return_sequences); + BeamHypotheses beam_hyp = beam_scorer_->GetBeamHypotheses(batch_id); + return beam_hyp.GetHypothesis(beam_id); +} + +// TODO(aciddelgado): my question is, should this return copy or reference? 
+RoamingArray BeamSearch_Cpu::GetSequence(size_t batch_id, size_t beam_id) { + Finalize(params_->search.num_return_sequences); + BeamHypotheses beam_hyp = beam_scorer_->GetBeamHypotheses(batch_id); + return beam_hyp.GetHypothesis(beam_id); } std::span Search_Cpu::GetScores(int batch_beam_index) const { diff --git a/src/search.h b/src/search.h index 2c307bee7..901cb437f 100644 --- a/src/search.h +++ b/src/search.h @@ -1,10 +1,10 @@ #include "sequences.h" #include +#include "beam_search_scorer.h" +#pragma once namespace Generators { -struct BeamSearchScorer; - struct Search { Search(const GeneratorParams& params) : params_{params.shared_from_this()} {} virtual ~Search() = default; @@ -18,9 +18,6 @@ struct Search { virtual void SetLogits(RoamingArray logits) = 0; virtual bool IsDone() const = 0; - // TODO: Beam Search only, this should be removed and made automatic - virtual void Finalize(size_t /*num_return_sequences*/, RoamingArray /*output*/, RoamingArray /*sequence_scores*/) { assert(false); } - virtual void SelectTop() = 0; virtual void SampleTopP(float /*p*/, float /*temperature*/) { assert(false); } virtual void SampleTopK(int /*k*/, float /*temperature*/) { assert(false); } @@ -92,13 +89,19 @@ struct BeamSearch_Cpu : Search_Cpu { RoamingArray GetNextTokens() override; RoamingArray GetNextIndices() override; + // In Beam Search there are batch_size * num_beams sequences. Index is batch_id * num_beams + beam_id... Easier to use the other version. 
+ RoamingArray GetSequence(size_t index) override; + RoamingArray GetSequence(size_t batch_id, size_t beam_id); - void SelectTop() override; + bool IsDone() const; - void Finalize(size_t num_return_sequences, RoamingArray output, RoamingArray sequence_scores) override; + void SelectTop() override; private: void AppendNextTokensToSequences(); + void Finalize(size_t num_return_sequences); + + bool finalized_{}; // To avoid calling Finalize multiple times std::unique_ptr beam_scorer_; }; diff --git a/src/search_cuda.cpp b/src/search_cuda.cpp index 43d7d3e17..160d102b1 100644 --- a/src/search_cuda.cpp +++ b/src/search_cuda.cpp @@ -51,6 +51,7 @@ BeamSearch_Cuda::BeamSearch_Cuda(const GeneratorParams& params) topk_next_tokens_ = CudaMallocArray(2 * batch_beam_size); topk_next_indices_ = CudaMallocArray(2 * batch_beam_size); topk_next_scores_ = CudaMallocArray(2 * batch_beam_size); + softmax_buffer_ = CudaMallocArray(batch_beam_size * params_->vocab_size); constexpr size_t max_parts_of_vocab = 128; size_t topk_buffer_size = batch_beam_size * (max_parts_of_vocab + 1) * params_->search.num_beams * 2 * 2; @@ -83,15 +84,21 @@ int Search_Cuda::GetSequenceLength() const { } void BeamSearch_Cuda::SelectTop() { + cuda::DispatchBlockwiseSoftmaxForward(const_cast(¶ms_->cuda_stream), softmax_buffer_.get(), next_token_scores_.data(), params_->vocab_size, + params_->vocab_size, params_->vocab_size, params_->BatchBeamSize()); + + // Copy next_token_scores to CPU + auto next_token_scores_cpu = CudaMallocHostArray(params_->BatchBeamSize() * params_->vocab_size); + cudaMemcpyAsync(next_token_scores_cpu.get(), softmax_buffer_.get(), params_->BatchBeamSize() * params_->vocab_size * sizeof(float), cudaMemcpyDeviceToHost, params_->cuda_stream); + CudaCheck() == cudaStreamSynchronize(params_->cuda_stream); + auto beam_scores = beam_scorer_->GetNextScores(); // Add beam score to next token scores. 
Corresponding python code is like: // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) - cuda::LaunchAddProbsKernel(next_token_scores_.data(), beam_scores.data(), + cuda::LaunchAddProbsKernel(softmax_buffer_.get(), beam_scores.data(), params_->batch_size, params_->search.num_beams, params_->vocab_size, params_->cuda_stream); - // TODO: Write output scores? - if (params_->search.num_beams <= 32) { constexpr size_t max_parts_of_vocab = 128; size_t candidate_count = params_->BatchBeamSize() * 2 * params_->search.num_beams; @@ -101,7 +108,7 @@ void BeamSearch_Cuda::SelectTop() { float* topk_scores_2nd_stage = reinterpret_cast(topk_tokens_1st_stage + candidate_count * max_parts_of_vocab); int32_t* topk_tokens_2nd_stage = reinterpret_cast(topk_scores_2nd_stage + candidate_count); - cuda::BeamSearchTopK(next_token_scores_.data(), + cuda::BeamSearchTopK(softmax_buffer_.get(), params_->batch_size, params_->search.num_beams, params_->vocab_size, @@ -125,14 +132,15 @@ void BeamSearch_Cuda::SelectTop() { std::span next_indices{topk_next_indices_.get(), size}; #if 0 - DumpCudaMemory("Next Scores", next_scores); - DumpCudaMemory("Next Tokens", next_tokens); - DumpCudaMemory("Next Indices", next_indices); + DumpCudaSpan(std::cout, next_scores); + DumpCudaSpan(std::cout, next_tokens); + DumpCudaSpan(std::cout, next_indices); #endif beam_scorer_->Process(sequences_, next_scores, next_tokens, next_indices); next_tokens_ = beam_scorer_->GetNextTokens(); + // TODO(aciddelgado): do we need to keep track of sequences both here and in beam hypotheses? 
AppendNextTokensToSequences(); } @@ -184,7 +192,6 @@ void GreedySearch_Cuda::AppendNextTokensToSequences() { } bool BeamSearch_Cuda::IsDone() const { - beam_scorer_->IsDone(); if (beam_scorer_->IsDoneLater()) return true; @@ -200,8 +207,23 @@ void BeamSearch_Cuda::AppendNextTokensToSequences() { sequences_.AfterDeviceAppendedNextToken(); } -void BeamSearch_Cuda::Finalize(size_t num_return_sequences, RoamingArray output, RoamingArray sequence_scores) { - beam_scorer_->Finalize(sequences_, num_return_sequences, output.GetGPU(), sequence_scores.GetGPU()); +void BeamSearch_Cuda::Finalize(size_t num_return_sequences) { + if (finalized_) + return; + beam_scorer_->Finalize(sequences_, num_return_sequences); + finalized_ = true; +} + +RoamingArray BeamSearch_Cuda::GetSequence(size_t index) { + Finalize(params_->search.num_return_sequences); + const size_t batch_id = index / params_->search.num_return_sequences; + const size_t beam_id = index % params_->search.num_return_sequences; + return beam_scorer_->GetBeamHypothesis(batch_id, beam_id); +} + +RoamingArray BeamSearch_Cuda::GetSequence(size_t batch_id, size_t beam_id) { + Finalize(params_->search.num_return_sequences); + return beam_scorer_->GetBeamHypothesis(batch_id, beam_id); } #if 0 diff --git a/src/search_cuda.h b/src/search_cuda.h index 05bf298bf..8a699b880 100644 --- a/src/search_cuda.h +++ b/src/search_cuda.h @@ -68,20 +68,26 @@ struct BeamSearch_Cuda : Search_Cuda { RoamingArray GetNextTokens() override; RoamingArray GetNextIndices() override; + // In Beam Search there are batch_size * num_beams sequences. Index is batch_id * num_beams + beam_id... Easier to use the other version. 
+ RoamingArray GetSequence(size_t index) override; + RoamingArray GetSequence(size_t batch_id, size_t beam_id); void SelectTop() override; - void Finalize(size_t num_return_sequences, RoamingArray output, RoamingArray sequence_scores) override; bool IsDone() const; private: void AppendNextTokensToSequences(); + void Finalize(size_t num_return_sequences); + + bool finalized_{}; // To avoid calling Finalize multiple times std::unique_ptr beam_scorer_; cuda_unique_ptr topk_next_tokens_; cuda_unique_ptr topk_next_indices_; cuda_unique_ptr topk_next_scores_; + cuda_unique_ptr softmax_buffer_; // temp buffer for topk computation, including: // 1st stage needs: diff --git a/src/smartptrs.h b/src/smartptrs.h index 5591cfde5..257ae5535 100644 --- a/src/smartptrs.h +++ b/src/smartptrs.h @@ -30,6 +30,12 @@ void copy(std::span source, std::span dest) { std::copy(source.begin(), source.end(), dest.begin()); } +template +void copy(cpu_span source, cpu_span dest) { + assert(source.size() == dest.size()); + std::copy(source.begin(), source.end(), dest.begin()); +} + template std::unique_ptr AllocateArray(size_t count, std::span* p_span = nullptr) { T* p = new T[count]; diff --git a/src/softmax.h b/src/softmax.h index 0b59c7c28..c4dbfea98 100644 --- a/src/softmax.h +++ b/src/softmax.h @@ -2,7 +2,30 @@ namespace Generators { -void softmax(std::span values); -void log_softmax(std::span values); +void SoftMax(std::span scores, float temperature) { + float const max_score = *std::max_element(scores.begin(), scores.end()); + + // Subtract max score and scale by temperature + std::transform(scores.begin(), scores.end(), scores.begin(), [max_score, temperature](float score) { return std::exp((score - max_score) / temperature); }); + + // Compute sum of exponentials + float const exp_sum = std::accumulate(scores.begin(), scores.end(), 0.0f); + + // Divide each score by the sum of exponentials + std::transform(scores.begin(), scores.end(), scores.begin(), [exp_sum](float score) { return 
score / exp_sum; }); +} + +void LogSoftMax(std::span scores, float temperature) { + float const max_score = *std::max_element(scores.begin(), scores.end()); + + // Subtract max score and scale by temperature + std::transform(scores.begin(), scores.end(), scores.begin(), [max_score, temperature](float score) { return (score - max_score) / temperature; }); + + // Compute sum of exponentials + float const exp_sum = std::accumulate(scores.begin(), scores.end(), 0.0f, [](float a, float b) { return a + std::exp(b); }); + + // Subtract log of sum of exponentials from each score + std::transform(scores.begin(), scores.end(), scores.begin(), [exp_sum](float score) { return score - std::log(exp_sum); }); +} } // namespace Generators diff --git a/test/model_tests.cpp b/test/model_tests.cpp index 73c6464e0..6766fb892 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -83,26 +83,14 @@ TEST(ModelTests, BeamSearchGptFp32) { params->search.length_penalty = 1.0f; params->search.num_beams = 4; - Generators::BeamSearch_Cpu search{*params}; - auto state = model->CreateState(search.sequence_lengths_, *params); - - while (!search.IsDone()) { - search.SetLogits(state->Run(search.GetSequenceLength(), search.GetNextTokens(), search.GetNextIndices())); - - // Scoring - search.ApplyMinLength(1); - search.ApplyRepetitionPenalty(1.0f); - - search.SelectTop(); - } + auto generator = Generators::CreateGenerator(*model, *params); + auto result = Generators::Generate(*model, *params); - std::vector output_sequence(static_cast(search.params_->batch_size) * search.params_->search.max_length); - search.Finalize(1, Generators::cpu_span{output_sequence}, {}); // Verify outputs match expected outputs - for (size_t i = 0; i < static_cast(search.params_->batch_size); i++) { - auto sequence = std::span(output_sequence.data() + search.params_->search.max_length * i, search.params_->search.max_length); - auto* expected_output_start = &expected_output[i * search.params_->search.max_length]; + for 
(int i = 0; i < params->batch_size; i++) { + auto sequence = std::span(result[i].data(), params->search.max_length); + auto* expected_output_start = &expected_output[static_cast(i) * params->search.max_length]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), params->search.max_length * sizeof(int32_t))); } } @@ -174,24 +162,13 @@ void Test_BeamSearch_Gpt_Cuda(const char* model_path, const char* model_label) { params->search.length_penalty = 1.0f; auto generator = Generators::CreateGenerator(*model, *params); + auto result = Generators::Generate(*model, *params); - while (!generator->IsDone()) { - generator->ComputeLogits(); - generator->GenerateNextToken(); - } - - size_t sequence_length = params->batch_size * params->search.max_length; - auto output_sequence_cuda = Generators::CudaMallocArray(sequence_length); - auto output_sequence_cpu = std::make_unique(sequence_length); - - generator->search_->Finalize(1, Generators::gpu_span(output_sequence_cuda.get(), sequence_length), {}); - cudaMemcpyAsync(output_sequence_cpu.get(), output_sequence_cuda.get(), sequence_length * sizeof(int32_t), cudaMemcpyDeviceToHost, params->cuda_stream); - cudaStreamSynchronize(params->cuda_stream); // Verify outputs match expected outputs for (int i = 0; i < params->batch_size; i++) { - auto sequence = std::span(output_sequence_cpu.get() + params->search.max_length * i, params->search.max_length); - auto* expected_output_start = &expected_output[i * params->search.max_length]; + auto sequence = std::span(result[i].data(), params->search.max_length); + auto* expected_output_start = &expected_output[static_cast(i) * params->search.max_length]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence.data(), params->search.max_length * sizeof(int32_t))); } }