
Commit 7d374b5

…into ryanunderhill/tokenizer
RyanUnderhill committed Jan 24, 2024
2 parents ea15429 + 3f66479 commit 7d374b5
Showing 9 changed files with 206 additions and 153 deletions.
12 changes: 9 additions & 3 deletions src/.clang-tidy
@@ -17,10 +17,12 @@
 # -readability-avoid-unconditional-preprocessor-if: for #if 0
 # -cppcoreguidelines-pro-type-vararg: printf
 # -misc-include-cleaner: suppress errors like: "no header providing std::span is directly".
-# -cppcoreguidelines-pro-bounds-pointer-arithmetic: sometimes we use raw arrays
+# -cppcoreguidelines-pro-bounds-pointer-arithmetic: sometimes we use raw arrays.
 # -cppcoreguidelines-avoid-non-const-global-variables: ort_env is global.
-# -bugprone-easily-swappable-parameters: hard to fix
-# -performance-avoid-endl: not a big deal
+# -bugprone-easily-swappable-parameters: hard to fix.
+# -performance-avoid-endl: not a big deal.
+# -cppcoreguidelines-slicing: some slicings are intentional.
+# -cppcoreguidelines-avoid-c-arrays,-modernize-avoid-c-arrays,-readability-function-cognitive-complexity: will fix them later
 Checks: >
   -*,
   cppcoreguidelines-*,
@@ -38,6 +40,9 @@ Checks: >
   -cppcoreguidelines-pro-bounds-pointer-arithmetic,
   -cppcoreguidelines-avoid-magic-numbers,
   -cppcoreguidelines-avoid-non-const-global-variables,
+  -cppcoreguidelines-avoid-c-arrays,
+  -cppcoreguidelines-slicing,
+  -modernize-avoid-c-arrays,
   -google-readability-todo,
   -google-runtime-references,
   -modernize-concat-nested-namespaces,
@@ -46,6 +51,7 @@ Checks: >
   -readability-uppercase-literal-suffix,
   -readability-avoid-unconditional-preprocessor-if,
   -readability-magic-numbers,
+  -readability-function-cognitive-complexity,
   -bugprone-easily-swappable-parameters
 WarningsAsErrors: ""
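For context on the new suppressions, here is a minimal hypothetical snippet (not from this repository; the names are illustrative only) showing the kind of code that cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays, and cppcoreguidelines-slicing would otherwise flag:

// Illustrative only: patterns covered by the checks suppressed above.
#include <cstdio>

struct Base {
  int id = 0;
};
struct Derived : Base {
  int extra = 1;
};

int main() {
  // cppcoreguidelines-avoid-c-arrays / modernize-avoid-c-arrays flag raw C arrays like this.
  const char* vocab[] = {"hello", "world"};

  // cppcoreguidelines-slicing flags copying a Derived into a Base, which drops Derived::extra.
  Base base = Derived{};

  std::printf("%s %d\n", vocab[0], base.id);
  return 0;
}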
67 changes: 39 additions & 28 deletions src/beam_search_scorer.cpp
@@ -5,6 +5,8 @@
 #include "search.h"
 #include "beam_search_scorer.h"
 
+#include <cmath>
+
 namespace Generators {
 
 void BeamHypotheses::Init(float length_penalty, std::span<HypothesisScore> beams) {
@@ -16,46 +18,50 @@ void BeamHypotheses::Init(float length_penalty, std::span<HypothesisScore> beams
 
 void BeamHypotheses::Add(std::span<const int32_t> hypothesis, float sum_logprobs) {
   auto length = hypothesis.size();
-  float score = sum_logprobs / pow(static_cast<float>(length), length_penalty_);
+  float const score = sum_logprobs / std::pow(static_cast<float>(length), length_penalty_);
 
   size_t index = beams_used_;
   // If the array is full, don't add unless it's better than the worst element
   if (index == beams_.size()) {
-    if (score <= beams_[--index].score)
+    if (score <= beams_[--index].score) {
       return;
-  } else
+    }
+  } else {
     beams_used_++;
+  }
 
   // Rotate existing elements over while the new element scores higher
-  for (; index > 0 && score > beams_[index - 1].score; index--)
+  for (; index > 0 && score > beams_[index - 1].score; index--) {
     beams_[index] = beams_[index - 1];
+  }
 
   beams_[index] = HypothesisScore{hypothesis, score};
 }
 
 bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) const {
-  float current_score = best_sum_logprobs / pow(static_cast<float>(current_length), length_penalty_);
+  float const current_score = best_sum_logprobs / std::pow(static_cast<float>(current_length), length_penalty_);
   return beams_.back().score < current_score;
 }
 
 void BeamHypotheses::Output(
     size_t top_k,
     size_t max_length,
-    std::span<int32_t> sequences,       // buffer filled with pad token ID, shape (num_return_sequences, max_length)
-    std::span<float> sequences_scores)  // buffer of shape (num_return_sequences) or empty
+    std::span<int32_t> sequences,             // buffer filled with pad token ID, shape (num_return_sequences, max_length)
+    std::span<float> sequences_scores) const  // buffer of shape (num_return_sequences) or empty
 {
   // Copy the top_k beams into the sequences
   assert(top_k <= beams_used_);
   for (int index = 0; index < top_k; index++) {
     auto& item = beams_[index];
-    std::span<int32_t> target = sequences.subspan(index * max_length, max_length);
+    std::span<int32_t> const target = sequences.subspan(index * max_length, max_length);
 
     // Note that word_ids might be less than max_length.
     // Since the sequences has been filled with pad token ID, so padding is not needed here.
     copy(item.hypothesis, target);
 
-    if (!sequences_scores.empty())
+    if (!sequences_scores.empty()) {
       sequences_scores[index] = item.score;
+    }
   }
 }
 
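As an aside, the length-penalty normalization used by BeamHypotheses::Add and CanImprove above is score = sum_logprobs / length^length_penalty. A small self-contained sketch with made-up numbers (not part of this commit):

// Standalone illustration of the hypothesis score computed above; values are hypothetical.
#include <cmath>
#include <cstdio>

int main() {
  const float sum_logprobs = -6.0f;   // cumulative log-probability of a hypothesis (made up)
  const float length_penalty = 1.0f;  // values > 1 favor longer sequences, < 1 favor shorter ones
  for (const int length : {4, 8, 16}) {
    const float score = sum_logprobs / std::pow(static_cast<float>(length), length_penalty);
    std::printf("length=%2d  score=%.3f\n", length, score);
  }
  return 0;
}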
@@ -67,27 +73,28 @@ BeamSearchScorer::BeamSearchScorer(const SearchParams& parameters)
       eos_token_id_{parameters.eos_token_id},
       early_stopping_{parameters.early_stopping},
       not_done_count_{parameters.batch_size} {
-  size_t batch_beam_size = batch_size_ * num_beams_;
+  size_t const batch_beam_size = static_cast<size_t>(batch_size_) * num_beams_;
 
   std::span<HypothesisScore> beams;
   hypothesis_scores_ptr_ = AllocateArray<HypothesisScore>(batch_beam_size, &beams);
   beam_hyps_ptr_ = AllocateArray<BeamHypotheses>(batch_size_, &beam_hyps_);
-  for (size_t i = 0; i < batch_size_; i++)
+  for (size_t i = 0; i < batch_size_; i++) {
     beam_hyps_[i].Init(parameters.length_penalty, beams.subspan(i * num_beams_, num_beams_));
+  }
 
   next_beam_scores_ptr_ = AllocateArray<float>(batch_beam_size, &next_beam_scores_);
   next_beam_tokens_ptr_ = AllocateArray<int32_t>(batch_beam_size, &next_beam_tokens_);
   next_beam_indices_ptr_ = AllocateArray<int32_t>(batch_beam_size, &next_beam_indices_);
 
   // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length.
-  size_t per_beam = (max_length_ * (max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2;
+  size_t const per_beam = (max_length_ * (max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2;
   hypothesis_buffer_ptr_ = AllocateArray<int32_t>(batch_beam_size * per_beam, &hypothesis_buffer_);
 
   memset(next_beam_scores_.data(), 0, next_beam_scores_.size_bytes());
 
   // Initialize score of first beam of each group with 0 and the rest with -1e9.
   // This ensures that the beams in the same group don't produce same tokens every time.
-  std::span<float> beam_scores = next_beam_scores_;
+  std::span<float> const beam_scores = next_beam_scores_;
   for (int i = 0; i < parameters.batch_size; i++) {
     for (int j = 1; j < parameters.num_beams; j++) {
       beam_scores[i * parameters.num_beams + j] = -1e9;
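For reference, the per_beam expression above is the closed form of the sum sequence_length + (sequence_length + 1) + ... + max_length, i.e. space for one intermediate sequence at every length. A quick check with assumed values (not from this commit):

// Sanity sketch of the per_beam sizing: closed form vs. direct summation.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t sequence_length = 4;  // assumed prompt length
  const size_t max_length = 10;      // assumed generation limit
  size_t summed = 0;
  for (size_t len = sequence_length; len <= max_length; ++len) {
    summed += len;
  }
  const size_t per_beam = (max_length * (max_length + 1) - (sequence_length - 1) * sequence_length) / 2;
  std::printf("summed=%zu closed_form=%zu\n", summed, per_beam);  // both print 49
  return 0;
}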
@@ -124,22 +131,22 @@ void BeamSearchScorer::Process(Sequences& sequences,
 
     // Next tokens for this sentence.
     size_t beam_idx = 0;
-    size_t top_k = 2 * num_beams_;
+    size_t const top_k = 2 * num_beams_;
     for (size_t j = 0; j < top_k; j++) {
-      int32_t next_token = next_tokens[batch * top_k + j];
-      float next_score = next_scores[batch * top_k + j];
-      int32_t next_index = next_indices[batch * top_k + j];
+      int32_t const next_token = next_tokens[batch * top_k + j];
+      float const next_score = next_scores[batch * top_k + j];
+      int32_t const next_index = next_indices[batch * top_k + j];
 
-      int batch_beam_idx = static_cast<int>(batch * num_beams_) + next_index;
+      int const batch_beam_idx = static_cast<int>(batch * num_beams_) + next_index;
       // Add to generated hypotheses if end of sentence.
       if ((eos_token_id_ >= 0) && (next_token == eos_token_id_)) {
-        bool is_beam_token_worse_than_top_num_beams = (j >= num_beams_);
+        bool const is_beam_token_worse_than_top_num_beams = (j >= num_beams_);
         if (is_beam_token_worse_than_top_num_beams) {
           continue;
         }
 
         // Clone the sequence and append to buffer.
-        std::span<const int32_t> src = sequences.GetSequence(batch_beam_idx);
+        std::span<const int32_t> const src = sequences.GetSequence(batch_beam_idx);
         auto clone = hypothesis_buffer_.subspan(static_cast<size_t>(hypothesis_buffer_used_), sequence_length);
 
         copy(src, clone);
@@ -154,22 +161,25 @@
       }
 
       // Once the beam for next step is full, don't add more tokens to it.
-      if (beam_idx == num_beams_)
+      if (beam_idx == num_beams_) {
         break;
+      }
     }
 
     assert(beam_idx == num_beams_);
     assert(static_cast<size_t>(hypothesis_buffer_used_) <= hypothesis_buffer_.size());
 
     // Check if we are done so that we can save a pad step if all(done)
-    if (static_cast<size_t>(beam_hyp.beams_used_) < num_beams_)
+    if (static_cast<size_t>(beam_hyp.beams_used_) < num_beams_) {
       continue;
+    }
 
     if (!early_stopping_) {
-      std::span<const float> topk_scores = next_scores.subspan(batch * num_beams_, top_k);
+      std::span<const float> const topk_scores = next_scores.subspan(batch * num_beams_, top_k);
       const auto best_sum_logprobs = std::max_element(topk_scores.begin(), topk_scores.end());
-      if (beam_hyp.CanImprove(*best_sum_logprobs, sequence_length))
+      if (beam_hyp.CanImprove(*best_sum_logprobs, sequence_length)) {
         continue;
+      }
     }
 
     beam_hyp.done_ = true;
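A minimal sketch (hypothetical token ids, not this class's API) of why Process scans top_k = 2 * num_beams candidates per batch entry: candidates that hit EOS are recorded as finished hypotheses instead of continuing, so up to num_beams extra candidates may be needed to keep num_beams beams alive for the next step.

// Hypothetical illustration of the 2 * num_beams candidate scan above.
#include <cstdio>
#include <vector>

int main() {
  const size_t num_beams = 2;
  const size_t top_k = 2 * num_beams;
  const int eos_token_id = 2;                           // assumed EOS id
  const std::vector<int> next_tokens = {2, 17, 2, 93};  // made-up top_k candidates, best first
  size_t beam_idx = 0;
  for (size_t j = 0; j < top_k; j++) {
    if (next_tokens[j] == eos_token_id) {
      continue;  // finished hypothesis: stored separately, does not occupy a next-step beam slot
    }
    beam_idx++;  // live beam carried into the next decoding step
    if (beam_idx == num_beams) {
      break;
    }
  }
  std::printf("live beams for next step: %zu of %zu\n", beam_idx, num_beams);
  return 0;
}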
@@ -192,8 +202,8 @@ void BeamSearchScorer::Finalize(Sequences& sequences,
     }
 
     for (int beam_index = 0; beam_index < num_beams_; beam_index++) {
-      int batch_beam_index = batch_index * num_beams_ + beam_index;
-      float final_score = next_beam_scores_[batch_beam_index];
+      int const batch_beam_index = batch_index * num_beams_ + beam_index;
+      float const final_score = next_beam_scores_[batch_beam_index];
       auto final_tokens = sequences.GetSequence(batch_beam_index);
       beam_hyp.Add(final_tokens, final_score);
     }
@@ -209,11 +219,12 @@
     auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_,
                                        num_return_sequences * max_length_);
     std::span<float> sequence_scores_buffer;
-    if (!sequence_scores.empty())
+    if (!sequence_scores.empty()) {
       sequence_scores_buffer = sequence_scores.subspan(batch_index * num_return_sequences, num_return_sequences);
+    }
 
     beam_hyp.Output(num_return_sequences, max_length_, batch_output, sequence_scores_buffer);
   }
 }
 
-} // namespace Generators
+} // namespace Generators
8 changes: 4 additions & 4 deletions src/beam_search_scorer.h
@@ -20,10 +20,10 @@ struct BeamHypotheses {
   bool CanImprove(float best_sum_logprobs, int current_length) const;
 
   // Output results
-  void Output(size_t top_k,                          // number of sequences to return
-              size_t max_length,                     // max sequence length
-              std::span<int32_t> sequences,          // buffer with pad token, shape (num_return_sequences, max_length)
-              std::span<float> sequences_scores);    // buffer for sequence scores, with shape (num_return_sequences)
+  void Output(size_t top_k,                                // number of sequences to return
+              size_t max_length,                           // max sequence length
+              std::span<int32_t> sequences,                // buffer with pad token, shape (num_return_sequences, max_length)
+              std::span<float> sequences_scores) const;    // buffer for sequence scores, with shape (num_return_sequences)
 
   std::span<HypothesisScore> beams_;  // Beam width sized array of hypotheses, sorted by highest scoring
   int beams_used_;                    // Number of elements used in beams_
