Zhalei/fix seqoutput type (#18765)
After refactoring beam search, all scores became fp32, but they need to support fp16 according to the original spec.
zhanghuanrong authored Jan 22, 2024
1 parent 21034a2 commit 373ebac
Showing 10 changed files with 220 additions and 69 deletions.
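The change follows one pattern throughout: inspect the element type of the score output at run time, then dispatch to a templated helper that casts each fp32 score as it is written. Below is a minimal, self-contained sketch of that pattern; the Half type, CopyScores helper, and output_is_fp16 flag are illustrative stand-ins (ONNX Runtime uses MLFloat16 and checks the output tensor's data type), and the conversion truncates rather than rounds.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Illustrative stand-in for a 16-bit float type (ORT uses MLFloat16).
// The conversion truncates the mantissa and flushes subnormals to zero for brevity.
struct Half {
  uint16_t bits{0};
  Half() = default;
  explicit Half(float f) {
    uint32_t x;
    std::memcpy(&x, &f, sizeof(x));
    uint32_t sign = (x >> 16) & 0x8000u;
    uint32_t mant = x & 0x007FFFFFu;
    int32_t exp = static_cast<int32_t>((x >> 23) & 0xFFu) - 127 + 15;
    if (exp >= 31) {  // too large for fp16: keep NaN as NaN, everything else becomes +/-inf
      bool is_nan = ((x >> 23) & 0xFFu) == 0xFFu && mant != 0;
      bits = static_cast<uint16_t>(sign | 0x7C00u | (is_nan ? 0x0200u : 0u));
    } else if (exp <= 0) {  // underflow: flush to signed zero
      bits = static_cast<uint16_t>(sign);
    } else {
      bits = static_cast<uint16_t>(sign | (static_cast<uint32_t>(exp) << 10) | (mant >> 13));
    }
  }
};

// Templated copy that casts each fp32 score to the output element type,
// mirroring the OutputSequenceScores<T> / OutputScores logic in this commit.
template <typename T>
void CopyScores(const std::vector<float>& src, T* dst) {
  for (size_t i = 0; i < src.size(); ++i) {
    dst[i] = T(src[i]);
  }
}

int main() {
  std::vector<float> scores{-0.25f, -1.5f, -3.75f};
  bool output_is_fp16 = true;  // in ORT this comes from the output tensor's element type

  if (output_is_fp16) {
    std::vector<Half> out(scores.size());
    CopyScores<Half>(scores, out.data());
    std::cout << "fp16 out[0] bits = 0x" << std::hex << out[0].bits << "\n";
  } else {
    std::vector<float> out(scores.size());
    CopyScores<float>(scores, out.data());
    std::cout << "fp32 out[0] = " << out[0] << "\n";
  }
  return 0;
}

Dispatching once on the element type, rather than branching per element, keeps the inner copy loop identical to the original fp32-only code.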
@@ -397,12 +397,8 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
output_sequences_scores);

// Output per token scores
if (output_scores) {
gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
gsl::span<const float> source = beam_state.scores;
assert(target.size() == source.size());
ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
}
gsl::span<const float> per_token_scores = beam_state.scores;
this->beam_scorer_->OutputScores(per_token_scores, output_scores);

return status;
}
@@ -404,12 +404,8 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
output_sequences_scores);

// Output per token scores
if (output_scores) {
gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
gsl::span<const float> source = beam_state.scores;
assert(target.size() == source.size());
ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
}
gsl::span<const float> per_token_scores = beam_state.scores;
this->beam_scorer_->OutputScores(per_token_scores, output_scores);

return status;
}
@@ -500,12 +500,8 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
output_sequences_scores);

// Output per token scores
if (output_scores) {
gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
gsl::span<const float> source = beam_state.scores;
assert(target.size() == source.size());
ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
}
gsl::span<const float> per_token_scores = beam_state.scores;
this->beam_scorer_->OutputScores(per_token_scores, output_scores);

return status;
}
82 changes: 58 additions & 24 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc
@@ -50,11 +50,12 @@ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) con
return beams_.back().score < current_score;
}

template <typename T>
void BeamHypotheses::Output(
int top_k,
int max_length,
gsl::span<int32_t>& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length)
gsl::span<float>& sequences_scores) // buffer of shape (num_return_sequences) or empty
gsl::span<int32_t>& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length)
gsl::span<T>& sequences_scores) // buffer of shape (num_return_sequences) or empty
{
// Copy the top_k beams into the sequences
ORT_ENFORCE(top_k <= beams_used_);
@@ -67,7 +68,7 @@ void BeamHypotheses::Output(
gsl::copy(item.hypothesis, target);

if (!sequences_scores.empty())
sequences_scores[index] = item.score;
sequences_scores[index] = (T)item.score;
}
}

@@ -181,21 +182,21 @@ void BeamSearchScorer::Process(ISequences& sequences,
}
}

void BeamSearchScorer::Finalize(ISequences& sequences,
gsl::span<const float>& final_beam_scores,
Tensor* output_sequences,
Tensor* output_sequence_scores) {
ORT_ENFORCE(output_sequences != nullptr);

template <typename T>
void OutputSequenceScores(BeamSearchScorer* scorer,
ISequences& sequences,
gsl::span<const float>& final_beam_scores,
Tensor* output_sequences,
Tensor* output_sequence_scores) {
// Finalize all open beam hypotheses and add to generated hypotheses.
for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) {
BeamHypotheses& beam_hyp = beam_hyps_[batch_index];
for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) {
BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index];
if (beam_hyp.done_) {
continue;
}

for (size_t beam_index = 0; beam_index < num_beams_; beam_index++) {
size_t batch_beam_index = batch_index * num_beams_ + beam_index;
for (size_t beam_index = 0; beam_index < scorer->num_beams_; beam_index++) {
size_t batch_beam_index = batch_index * scorer->num_beams_ + beam_index;
float final_score = final_beam_scores[batch_beam_index];
auto final_tokens = sequences.GetSequence(narrow<int>(batch_beam_index));
beam_hyp.Add(final_tokens, final_score);
@@ -206,26 +207,59 @@ void BeamSearchScorer::Finalize(ISequences& sequences,
gsl::span<int32_t> output = output_sequences->MutableDataAsSpan<int32_t>();

// Fill output sequences with pad token ID so that we do not need append it later.
std::fill_n(output.data(), output.size(), pad_token_id_);
std::fill_n(output.data(), output.size(), scorer->pad_token_id_);

// Score of each sequence, with shape (batch_size * num_return_sequences).
gsl::span<float> sequence_scores;
gsl::span<T> sequence_scores;
if (output_sequence_scores) {
sequence_scores = output_sequence_scores->MutableDataAsSpan<float>();
sequence_scores = output_sequence_scores->MutableDataAsSpan<T>();
}

// Select the best hypotheses according to number of sequences to return.
for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) {
BeamHypotheses& beam_hyp = beam_hyps_[batch_index];
for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) {
BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index];

auto batch_output = output.subspan(batch_index * num_return_sequences_ * max_length_,
num_return_sequences_ * max_length_);
gsl::span<float> sequence_scores_buffer;
auto batch_output = output.subspan(batch_index * scorer->num_return_sequences_ * scorer->max_length_,
scorer->num_return_sequences_ * scorer->max_length_);
gsl::span<T> sequence_scores_buffer;
if (!sequence_scores.empty())
sequence_scores_buffer = sequence_scores.subspan(batch_index * num_return_sequences_, num_return_sequences_);
sequence_scores_buffer = sequence_scores.subspan(batch_index * scorer->num_return_sequences_, scorer->num_return_sequences_);

beam_hyp.template Output<T>(narrow<int>(scorer->num_return_sequences_), narrow<int>(scorer->max_length_), batch_output,
sequence_scores_buffer);
}
}

void BeamSearchScorer::Finalize(ISequences& sequences,
gsl::span<const float>& final_beam_scores,
Tensor* output_sequences,
Tensor* output_sequence_scores) {
ORT_ENFORCE(output_sequences != nullptr);

beam_hyp.Output(narrow<int>(num_return_sequences_), narrow<int>(max_length_), batch_output,
sequence_scores_buffer);
if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType<float>()) {
OutputSequenceScores<float>(this, sequences, final_beam_scores, output_sequences, output_sequence_scores);
} else {
ORT_ENFORCE(output_sequence_scores->IsDataType<MLFloat16>());
OutputSequenceScores<MLFloat16>(this, sequences, final_beam_scores, output_sequences, output_sequence_scores);
}
}

void BeamSearchScorer::OutputScores(gsl::span<const float>& final_scores, Tensor* output_scores) {
if (output_scores) {
if (output_scores->IsDataType<float>()) {
gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
ORT_ENFORCE(target.size() == final_scores.size());
std::copy_n(final_scores.data(), final_scores.size(), target.data());
} else {
ORT_ENFORCE(output_scores->IsDataType<MLFloat16>());
gsl::span<MLFloat16> target = output_scores->MutableDataAsSpan<MLFloat16>();
ORT_ENFORCE(target.size() == final_scores.size());
const float* src = final_scores.data();
MLFloat16* dst = target.data();
for (size_t i = 0; i < target.size(); i++) {
dst[i] = MLFloat16(src[i]);
}
}
}
}

12 changes: 7 additions & 5 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
@@ -35,10 +35,11 @@ struct BeamHypotheses {
bool CanImprove(float best_sum_logprobs, int current_length) const;

// Output results
void Output(int top_k, // number of sequences to return
int max_length, // max sequence length
gsl::span<int32_t>& sequences, // buffer with pad token, shape (num_return_sequences, max_length)
gsl::span<float>& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)
template <typename T>
void Output(int top_k, // number of sequences to return
int max_length, // max sequence length
gsl::span<int32_t>& sequences, // buffer with pad token, shape (num_return_sequences, max_length)
gsl::span<T>& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)

gsl::span<HypothesisScore> beams_; // Beam width sized array of hypotheses, sorted by highest scoring
int beams_used_; // Number of elements used in beams_
@@ -60,13 +61,14 @@ struct BeamSearchScorer : IBeamScorer {
Tensor* output_sequences,
Tensor* output_sequence_scores) override;

void OutputScores(gsl::span<const float>& final_scores, Tensor* output_scores) override;

bool IsDone() const override { return not_done_count_ == 0; }

gsl::span<float> GetNextScores() override { return next_beam_scores_; }
gsl::span<int32_t> GetNextTokens() override { return next_beam_tokens_; }
gsl::span<int32_t> GetNextIndicesCPU() override { return next_beam_indices_; }

private:
size_t batch_size_;
size_t num_beams_;
size_t max_length_;
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -120,6 +120,9 @@ struct IBeamScorer {
Tensor* output_sequences,
Tensor* output_sequence_scores) = 0;

virtual void OutputScores(gsl::span<const float>& final_scores,
Tensor* output_scores) = 0;

virtual bool IsDone() const = 0; // GPU version will return false here, as it asynchronously queues up the event
virtual bool IsDoneLater() const { return false; } // GPU version waits for the asynchronous result to complete here

63 changes: 59 additions & 4 deletions onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
@@ -307,12 +307,13 @@ __device__ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_
return beams_[beams_count_ - 1].score < current_score;
}

template <typename T>
__device__ void BeamHypotheses::Output(
int top_k,
int max_length,
int pad_token_id,
int32_t* sequences, // buffer of shape (num_return_sequences, max_length)
float* sequences_scores) // buffer of shape (num_return_sequences) or empty
T* sequences_scores) // buffer of shape (num_return_sequences) or empty
{
// Copy the top_k beams into the sequences
for (int index = 0; index < top_k; index++) {
Expand All @@ -327,7 +328,7 @@ __device__ void BeamHypotheses::Output(
target[i] = pad_token_id;

if (sequences_scores)
sequences_scores[index] = item.score;
sequences_scores[index] = (T)item.score;
}
}

Expand Down Expand Up @@ -501,13 +502,14 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp
next_beam_tokens.data());
}

template <typename T>
__global__ void BeamSearchScorer_Finalize(BeamScorerState& state,
const int32_t* sequences_buffer,
int sequence_length,
BeamHypotheses* beam_hyps_,
const float* final_beam_scores,
int32_t* output,
float* sequence_scores) {
T* sequence_scores) {
int batch_index = blockIdx.x * blockDim.x + threadIdx.x;
if (batch_index >= state.batch_size_)
return;
@@ -534,14 +536,15 @@ __global__ void BeamSearchScorer_Finalize(BeamScorerState& state,
sequence_scores ? sequence_scores + batch_index * state.num_return_sequences_ : nullptr);
}

template <typename T>
void LaunchBeamSearchScorer_Finalize(int batch_size,
BeamScorerState& state,
gsl::span<const int32_t> sequences,
int sequence_length,
gsl::span<BeamHypotheses> beam_hyps,
gsl::span<const float> final_beam_scores,
gsl::span<int32_t> output,
gsl::span<float> sequence_scores,
gsl::span<T> sequence_scores,
cudaStream_t stream) {
BeamSearchScorer_Finalize<<<1, batch_size, 0, stream>>>(state,
sequences.data(),
@@ -552,6 +555,58 @@ void LaunchBeamSearchScorer_Finalize(int batch_size,
sequence_scores.data());
}

template void LaunchBeamSearchScorer_Finalize<float>(
int batch_size,
BeamScorerState& state,
gsl::span<const int32_t> sequences,
int sequence_length,
gsl::span<BeamHypotheses> beam_hyps,
gsl::span<const float> final_beam_scores,
gsl::span<int32_t> output,
gsl::span<float> sequence_scores,
cudaStream_t stream);

template void LaunchBeamSearchScorer_Finalize<__half>(
int batch_size,
BeamScorerState& state,
gsl::span<const int32_t> sequences,
int sequence_length,
gsl::span<BeamHypotheses> beam_hyps,
gsl::span<const float> final_beam_scores,
gsl::span<int32_t> output,
gsl::span<__half> sequence_scores,
cudaStream_t stream);

template <typename T>
__global__ void FloatConvertAndCopyKernel(const float* src, T* dst, size_t total_elements) {
int64_t index = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
if (index < total_elements) {
dst[index] = (T)src[index];
}
}

template <typename T>
void LaunchBeamSearchScoreCopy(gsl::span<const float> final_scores,
gsl::span<T> output_scores,
cudaStream_t stream) {
ORT_ENFORCE(final_scores.size() == output_scores.size());
constexpr unsigned ThreadPerBlock = 256;
unsigned num_blocks = (unsigned)((final_scores.size() + (ThreadPerBlock - 1))/ ThreadPerBlock);

typedef typename ToCudaType<float>::MappedType CudaT;

FloatConvertAndCopyKernel<<<num_blocks, ThreadPerBlock, 0, stream>>>(
final_scores.data(), (CudaT*)output_scores.data(), final_scores.size());
}

template void LaunchBeamSearchScoreCopy(gsl::span<const float> final_scores,
gsl::span<float> output_scores,
cudaStream_t stream);

template void LaunchBeamSearchScoreCopy(gsl::span<const float> final_scores,
gsl::span<MLFloat16> output_scores,
cudaStream_t stream);

__global__ void AddProbsKernel(float* log_probs,
float* cum_log_probs,
const int vocab_size,
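On the CUDA side the same conversion runs as a small elementwise kernel (FloatConvertAndCopyKernel above) launched over ceil(n / 256) blocks of 256 threads. The standalone harness below shows that convert-and-copy pattern end to end; the ConvertAndCopy kernel and the test data are illustrative, not code from this repository, and error checking is omitted for brevity.

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Elementwise fp32 -> fp16 convert-and-copy, one thread per score.
__global__ void ConvertAndCopy(const float* src, __half* dst, size_t n) {
  size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < n) {
    dst[i] = __float2half(src[i]);
  }
}

int main() {
  const size_t n = 1000;
  std::vector<float> host_scores(n);
  for (size_t i = 0; i < n; ++i) host_scores[i] = -0.001f * static_cast<float>(i);

  float* d_src = nullptr;
  __half* d_dst = nullptr;
  cudaMalloc(&d_src, n * sizeof(float));
  cudaMalloc(&d_dst, n * sizeof(__half));
  cudaMemcpy(d_src, host_scores.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  // Same grid sizing as the launcher above: ceil(n / 256) blocks of 256 threads.
  constexpr unsigned kThreadsPerBlock = 256;
  unsigned num_blocks = static_cast<unsigned>((n + kThreadsPerBlock - 1) / kThreadsPerBlock);
  ConvertAndCopy<<<num_blocks, kThreadsPerBlock>>>(d_src, d_dst, n);

  // cudaMemcpy on the default stream blocks until the kernel has finished,
  // so no explicit cudaDeviceSynchronize is needed before reading the result.
  std::vector<__half> host_out(n);
  cudaMemcpy(host_out.data(), d_dst, n * sizeof(__half), cudaMemcpyDeviceToHost);
  printf("score[1] as fp16 ~= %f\n", __half2float(host_out[1]));

  cudaFree(d_src);
  cudaFree(d_dst);
  return 0;
}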
19 changes: 13 additions & 6 deletions onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h
@@ -65,11 +65,12 @@ struct BeamHypotheses {
__device__ bool CanImprove(float best_sum_logprobs, int current_length) const;

// Output results
__device__ void Output(int top_k, // number of sequences to return
int max_length, // max sequence length
int pad_token_id, // pad token
int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length)
float* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)
template <typename T>
__device__ void Output(int top_k, // number of sequences to return
int max_length, // max sequence length
int pad_token_id, // pad token
int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length)
T* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)
};

struct BeamScorerState {
@@ -110,16 +111,22 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp
gsl::span<int32_t> next_beam_indices,
cudaStream_t stream);

template <typename T>
void LaunchBeamSearchScorer_Finalize(int batch_size,
BeamScorerState& state,
gsl::span<const int32_t> sequences,
int sequence_length,
gsl::span<BeamHypotheses> beam_hyps_,
gsl::span<const float> final_beam_scores,
gsl::span<int32_t> output,
gsl::span<float> sequence_scores,
gsl::span<T> sequence_scores,
cudaStream_t stream);

template <typename T>
void LaunchBeamSearchScoreCopy(gsl::span<const float> final_scores,
gsl::span<T> output_scores,
cudaStream_t stream);

void LaunchNextTokenKernel(const int64_t* next_token_indices,
int32_t* next_indices,
int32_t* next_tokens,