Merge Search & State into new "Generator" type (#39)
* Add a new 'Generator' object that merges the State & Search types into one, for simplicity.
* Bring back the simple 'Generate' method for Python.
* Rename SearchParams to GeneratorParams.
RyanUnderhill authored Jan 29, 2024
1 parent 59127f9 commit 06c294a
Showing 26 changed files with 278 additions and 232 deletions.
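At the call-site level, the change collapses the old two-object dance (create a Search from SearchParams, create a State from the Model, shuttle logits between them) into a single stepped Generator. A condensed before/after sketch in C++, distilled from the diffs below (not verbatim from any one file; sampling values illustrative):

// Before: caller wires a Search and a State together and moves logits by hand
auto search = params.CreateSearch();                                   // params: SearchParams
auto state = model.CreateState(search->GetSequenceLengths(), params);
// ...per step: search->SetLogits(state->Run(...)), then a Sample*/SelectTop call

// After: one Generator owns both and enforces the stepping protocol
auto generator = Generators::CreateGenerator(model, params);           // params: GeneratorParams
while (!generator->IsDone()) {
  generator->ComputeLogits();
  generator->AppendNextToken_TopP(0.7f, 0.6f);
}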
27 changes: 12 additions & 15 deletions examples/llama.py
@@ -10,48 +10,45 @@
# model=og.Model("../../test_models/llama2-7b-fp32-cpu", device_type)
#model=og.Llama_Model("../../test_models/llama2-7b-fp16-gpu/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx", device_type)
#model=og.Llama_Model("../../test_models/llama2-7b-int4-gpu/rank_0_Llama-2-7b-hf_decoder_merged_model_int4.onnx", device_type)
-model=og.Model("../../test_models/llama2-7b-chat-int4-gpu", device_type)
+model=og.Model("../test_models/llama2-7b-chat-int4-gpu", device_type)
print("Model loaded")
-# tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
-tokenizer=model.CreateTokenizer()
+tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
+# tokenizer=model.CreateTokenizer()
print("Tokenizer created")

# Keep asking for input prompts in a loop
while True:
    text = input("Input:")
-    # input_tokens = tokenizer.encode(text, return_tensors='np')
-    input_tokens = tokenizer.encode(text)
+    input_tokens = tokenizer.encode(text, return_tensors='np')
+    # input_tokens = tokenizer.encode(text)

    params=og.SearchParams(model)
    params.max_length = 128
    params.input_ids = input_tokens

-    search=params.CreateSearch()
-    state=model.CreateState(search.GetSequenceLengths(), params)
+    generator=og.Generator(model, params)

    print("Output:")

    print(text, end='', flush=True)
-    while not search.IsDone():
-        search.SetLogits(state.Run(search.GetSequenceLength(), search.GetNextTokens()))
+    while not generator.IsDone():
+        generator.ComputeLogits()

        # search.Apply_MinLength(1)
        # search.Apply_RepetitionPenalty(1.0)

-        search.SampleTopP(0.7, 0.6)
+        generator.AppendNextToken_TopP(0.7, 0.6)

-        print(tokenizer.decode([search.GetNextTokens().GetArray()[0]]), ' ', end='', flush=True)
-        '''
+        # print(tokenizer.decode([generator.GetNextTokens().GetArray()[0]]), ' ', end='', flush=True)
        # Print each token as we compute it, we have to do some work to get newlines & spaces to appear properly:
-        word=tokenizer.convert_ids_to_tokens([search.GetNextTokens().GetArray()[0]])[0]
+        word=tokenizer.convert_ids_to_tokens([generator.GetNextTokens().GetArray()[0]])[0]
        if word=='<0x0A>':
            word = '\n'
        if word.startswith('▁'):
            word = ' ' + word[1:]
        print(word, end='', flush=True)
-        '''

    # Print sequence all at once vs as it's decoded:
-    print(tokenizer.decode(search.GetSequence(0).GetArray()))
+    print(tokenizer.decode(generator.GetSequence(0).GetArray()))
    print()
    print()
2 changes: 1 addition & 1 deletion src/beam_search_scorer.cpp
@@ -65,7 +65,7 @@ void BeamHypotheses::Output(
}
}

-BeamSearchScorer::BeamSearchScorer(const SearchParams& parameters)
+BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters)
: batch_size_{parameters.batch_size},
num_beams_{parameters.num_beams},
max_length_{parameters.max_length},
2 changes: 1 addition & 1 deletion src/beam_search_scorer.h
@@ -32,7 +32,7 @@ struct BeamHypotheses {
};

struct BeamSearchScorer {
-  BeamSearchScorer(const SearchParams& parameters);
+  BeamSearchScorer(const GeneratorParams& parameters);

void Process(Sequences& sequences,
std::span<const float> next_scores,
3 changes: 2 additions & 1 deletion src/beam_search_scorer_cuda.cpp
@@ -1,11 +1,12 @@
#include "generators.h"
#include "search.h"
#include "search_cuda.h"
#include "beam_search_scorer_cuda.cuh"
#include "beam_search_scorer_cuda.h"

namespace Generators {

-BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const SearchParams& parameters)
+BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters)
: stream_{parameters.cuda_stream} {
state_cpu_ = CudaMallocHostArray<cuda::BeamScorerState>(1);
state_cpu_->batch_size_ = static_cast<size_t>(parameters.batch_size);
2 changes: 1 addition & 1 deletion src/beam_search_scorer_cuda.h
@@ -1,7 +1,7 @@
namespace Generators {

struct BeamSearchScorer_Cuda {
-  BeamSearchScorer_Cuda(const SearchParams& parameters);
+  BeamSearchScorer_Cuda(const GeneratorParams& parameters);

void Process(Sequences_Cuda& sequences,
std::span<const float> next_scores,
93 changes: 84 additions & 9 deletions src/generators.cpp
@@ -38,7 +38,7 @@ float Float16ToFloat32(uint16_t v) {
return std::ldexp((sign != 0 ? -1.0f : 1.0f) * (1.0f + static_cast<float>(fraction) / 1024.0f), exponent - 15);
}

-SearchParams::SearchParams(const Model& model)
+GeneratorParams::GeneratorParams(const Model& model)
: pad_token_id{model.config_->pad_token_id},
eos_token_id{model.config_->eos_token_id},
vocab_size{model.config_->model.vocab_size},
@@ -66,19 +66,94 @@ ProviderOptions GetDefaultProviderOptions([[maybe_unused]] DeviceType device_typ
return options;
}

-std::unique_ptr<Search> SearchParams::CreateSearch() const {
+std::unique_ptr<Generator> CreateGenerator(Model& model, const GeneratorParams& search_params) {
+  return std::make_unique<Generator>(model, search_params);
+}
+
+std::unique_ptr<Search> CreateSearch(const GeneratorParams& params) {
#if USE_CUDA
-  if (device_type == DeviceType::CUDA) {
-    if (num_beams > 1)
-      return std::make_unique<BeamSearch_Cuda>(*this);
-    return std::make_unique<GreedySearch_Cuda>(*this);
+  if (params.device_type == DeviceType::CUDA) {
+    if (params.num_beams > 1)
+      return std::make_unique<BeamSearch_Cuda>(params);
+    return std::make_unique<GreedySearch_Cuda>(params);
}
#endif

-  if (num_beams > 1) {
-    return std::make_unique<BeamSearch_Cpu>(*this);
+  if (params.num_beams > 1) {
+    return std::make_unique<BeamSearch_Cpu>(params);
  }
-  return std::make_unique<GreedySearch_Cpu>(*this);
+  return std::make_unique<GreedySearch_Cpu>(params);
}

+Generator::Generator(Model& model, const GeneratorParams& search_params) : model_{model} {
+  search_ = CreateSearch(search_params);
+  state_ = model.CreateState(search_->GetSequenceLengths(), search_params);
+}
+
+void Generator::ComputeLogits() {
+  if (computed_logits_)
+    throw std::runtime_error("ComputeLogits called again without calling AppendNextToken* first");
+
+  search_->SetLogits(state_->Run(search_->GetSequenceLength(), search_->GetNextTokens(), search_->GetNextIndices()));
+  computed_logits_ = true;
+}
+
+bool Generator::IsDone() const {
+  if (computed_logits_)
+    throw std::runtime_error("IsDone() can't be called in the middle of processing logits");
+
+  return search_->IsDone();
+}
+
+void Generator::AppendNextToken_TopK_TopP(int top_k, float top_p, float temperature) {
+  if (search_->params_.num_beams != 1)
+    throw std::runtime_error("TopK and TopP cannot be used with a beam search");
+
+  if (!computed_logits_)
+    throw std::runtime_error("Must call ComputeLogits before AppendNextToken*");
+  computed_logits_ = false;
+
+  // TODO: Do TopK if top_k >1 then do TopP on the results
+  if (top_p < 1.0f) {
+    search_->SampleTopP(top_p, temperature);
+  } else if (top_k > 1) {
+    search_->SampleTopK(top_k, temperature);
+  } else {
+    search_->SelectTop();
+  }
+}
+
+void Generator::AppendNextToken() {
+  if (search_->params_.num_beams > 1) {
+    if (!computed_logits_)
+      throw std::runtime_error("Must call ComputeLogits before AppendNextToken*");
+    computed_logits_ = false;
+    search_->SelectTop();
+    return;
+  }
+
+  auto& config = *model_.config_;
+  AppendNextToken_TopK_TopP(config.top_k, config.top_p, config.temperature);
+}
+
+RoamingArray<int32_t> Generator::GetSequence(int index) {
+  return search_->GetSequence(index);
+}
+
+std::vector<int32_t> Generate(Model& model, const GeneratorParams& params) {
+  auto generator = CreateGenerator(model, params);
+
+  while (!generator->IsDone()) {
+    generator->ComputeLogits();
+    generator->AppendNextToken();
+  }
+
+  auto results = generator->search_->GetSequence(0);
+  auto results_cpu = results.GetCPU();
+
+  std::vector<int32_t> v;
+  v.assign(results_cpu.begin(), results_cpu.end());
+  return v;
+}

} // namespace Generators
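The computed_logits_ flag above enforces a strict alternation: each ComputeLogits() must be followed by exactly one AppendNextToken* call before the next IsDone()/ComputeLogits(). Generate() is the canonical loop built on that protocol. A hypothetical caller sketch (prompt_tokens and the params field values are illustrative assumptions, not from this commit):

Generators::GeneratorParams params(*model);
params.max_length = 128;
params.input_ids = prompt_tokens;  // hypothetical span of prompt token ids

// One-shot: runs ComputeLogits()/AppendNextToken() until IsDone(), returns the sequence
std::vector<int32_t> sequence = Generators::Generate(*model, params);

// Or step manually to control sampling per token:
auto generator = Generators::CreateGenerator(*model, params);
while (!generator->IsDone()) {
  generator->ComputeLogits();                    // throws if called twice without an append
  generator->AppendNextToken_TopP(0.7f, 0.6f);   // clears computed_logits_ for the next step
}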
55 changes: 28 additions & 27 deletions src/generators.h
@@ -28,6 +28,8 @@

namespace Generators {
struct Model;
+struct State;
+struct Search;

// If we don't include cuda_runtime.h, we define this to avoid lots of extra #ifdefs
#ifndef USE_CUDA
@@ -52,33 +54,9 @@ using ProviderOptions = std::variant<

ProviderOptions GetDefaultProviderOptions(DeviceType device_type);

-struct Search {
-  virtual ~Search() = default;
-
-  virtual RoamingArray<int32_t> GetNextTokens() = 0;
-  virtual RoamingArray<int32_t> GetNextIndices() {
-    throw std::runtime_error("GetNextIndices() can only be called for beam search, num_beams must be >1");
-  }
-  virtual RoamingArray<int32_t> GetSequenceLengths() = 0;
-  virtual int GetSequenceLength() const = 0;
-  virtual RoamingArray<int32_t> GetSequence(int index) = 0;
-
-  virtual void SetLogits(RoamingArray<float> logits) = 0;
-  virtual bool IsDone() const = 0;
-
-  // TODO: Beam Search only, this should be removed and made automatic
-  virtual void Finalize(size_t /*num_return_sequences*/, RoamingArray<int32_t> /*output*/, RoamingArray<float> /*sequence_scores*/) { assert(false); }
-
-  virtual void SelectTop() = 0;
-  virtual void SampleTopP(float /*p*/, float /*temperature*/) { assert(false); }
-  virtual void SampleTopK(int /*k*/, float /*temperature*/) { assert(false); }
-};

-struct SearchParams {
-  SearchParams() = default; // This constructor is only used if doing a custom model handler vs built-in
-  SearchParams(const Model& model);
-
-  std::unique_ptr<Search> CreateSearch() const;
+struct GeneratorParams {
+  GeneratorParams() = default; // This constructor is only used if doing a custom model handler vs built-in
+  GeneratorParams(const Model& model);

// Values copied from config
int pad_token_id{};
@@ -124,6 +102,29 @@ struct SearchParams {
std::variant<Whisper> inputs;
};

+struct Generator {
+  Generator(Model& model, const GeneratorParams& search_params);
+
+  bool IsDone() const;
+  void ComputeLogits();
+  void AppendNextToken_TopK_TopP(int top_k, float top_p, float temperature);
+  void AppendNextToken_TopP(float p, float temperature) { AppendNextToken_TopK_TopP(0, p, temperature); }
+  void AppendNextToken_TopK(int k, float temperature) { AppendNextToken_TopK_TopP(k, 1.0f, temperature); }
+  void AppendNextToken_Top() { AppendNextToken_TopK_TopP(1, 1.0f, 0.0f); }
+  void AppendNextToken();
+
+  RoamingArray<int32_t> GetSequence(int index);
+
+  Model& model_;
+  std::unique_ptr<State> state_;
+  std::unique_ptr<Search> search_;
+  bool computed_logits_{}; // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio
+};

std::unique_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, const ProviderOptions* provider_options = nullptr);
+std::unique_ptr<Generator> CreateGenerator(Model& model, const GeneratorParams& search_params);
+std::vector<int32_t> Generate(Model& model, const GeneratorParams& params); // Uses CreateGenerator and a simple loop to return the entire sequence

float Float16ToFloat32(uint16_t v); // v is an IEEE 754-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
void top_k_indices(std::span<int32_t> top_k, std::span<const float> inputs);

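All three AppendNextToken_* wrappers declared above funnel into AppendNextToken_TopK_TopP, whose dispatch (per generators.cpp above) tries top-p first, then top-k, then greedy SelectTop. Illustrative calls:

generator->AppendNextToken_TopP(0.9f, 1.0f);  // top_p < 1.0f -> SampleTopP
generator->AppendNextToken_TopK(50, 1.0f);    // top_p == 1.0f, top_k > 1 -> SampleTopK
generator->AppendNextToken_Top();             // top_k == 1 -> SelectTop (greedy)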
4 changes: 2 additions & 2 deletions src/models/gpt.cpp
@@ -10,11 +10,11 @@ Gpt_Model::Gpt_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, const Prov
InitDeviceAllocator(*session_decoder_);
}

-std::unique_ptr<State> Gpt_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const SearchParams& params) {
+std::unique_ptr<State> Gpt_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) {
return std::make_unique<Gpt_State>(*this, sequence_lengths, params);
}

-Gpt_State::Gpt_State(Gpt_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const SearchParams& search_params)
+Gpt_State::Gpt_State(Gpt_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const GeneratorParams& search_params)
: State{search_params},
model_{model},
position_ids_{model, *this, sequence_lengths_unk} {
4 changes: 2 additions & 2 deletions src/models/gpt.h
@@ -10,13 +10,13 @@ namespace Generators {
struct Gpt_Model : Model {
Gpt_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, const ProviderOptions* provider_options);

-  std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const SearchParams& params) override;
+  std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) override;

std::unique_ptr<OrtSession> session_decoder_;
};

struct Gpt_State : State {
-  Gpt_State(Gpt_Model& model, RoamingArray<int32_t> sequence_lengths, const SearchParams& search_params);
+  Gpt_State(Gpt_Model& model, RoamingArray<int32_t> sequence_lengths, const GeneratorParams& search_params);
RoamingArray<float> Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) override;

private:
4 changes: 2 additions & 2 deletions src/models/llama.cpp
@@ -10,11 +10,11 @@ Llama_Model::Llama_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, const
InitDeviceAllocator(*session_decoder_);
}

-std::unique_ptr<State> Llama_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const SearchParams& params) {
+std::unique_ptr<State> Llama_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) {
return std::make_unique<Llama_State>(*this, sequence_lengths, params);
}

-Llama_State::Llama_State(Llama_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const SearchParams& search_params)
+Llama_State::Llama_State(Llama_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const GeneratorParams& search_params)
: State{search_params},
model_{model},
position_ids_{model, *this, sequence_lengths_unk} {
4 changes: 2 additions & 2 deletions src/models/llama.h
@@ -10,13 +10,13 @@ namespace Generators {
struct Llama_Model : Model {
Llama_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, const ProviderOptions* provider_options);

-  std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const SearchParams& params) override;
+  std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) override;

std::unique_ptr<OrtSession> session_decoder_;
};

struct Llama_State : State {
-  Llama_State(Llama_Model& model, RoamingArray<int32_t> sequence_lengths, const SearchParams& params);
+  Llama_State(Llama_Model& model, RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params);
RoamingArray<float> Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) override;

private:
4 changes: 2 additions & 2 deletions src/models/mistral.cpp
@@ -10,11 +10,11 @@ Mistral_Model::Mistral_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, co
InitDeviceAllocator(*session_decoder_);
}

-std::unique_ptr<State> Mistral_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const SearchParams& params) {
+std::unique_ptr<State> Mistral_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) {
return std::make_unique<Mistral_State>(*this, sequence_lengths, params);
}

-Mistral_State::Mistral_State(Mistral_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const SearchParams& search_params)
+Mistral_State::Mistral_State(Mistral_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const GeneratorParams& search_params)
: State{search_params},
model_{model},
position_ids_{model, *this, sequence_lengths_unk} {
4 changes: 2 additions & 2 deletions src/models/mistral.h
@@ -10,13 +10,13 @@ namespace Generators {
struct Mistral_Model : Model {
Mistral_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, const ProviderOptions* provider_options);

-  std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const SearchParams& params) override;
+  std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) override;

std::unique_ptr<OrtSession> session_decoder_;
};

struct Mistral_State : State {
-  Mistral_State(Mistral_Model& model, RoamingArray<int32_t> sequence_lengths, const SearchParams& params);
+  Mistral_State(Mistral_Model& model, RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params);
RoamingArray<float> Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) override;

private:
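gpt, llama, and mistral all take the same mechanical rename: CreateState and the State constructor now accept GeneratorParams. A hypothetical extra backend would follow the identical shape (sketch only; the MyModel_* names are not part of this commit):

struct MyModel_Model : Model {
  std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) override;
  std::unique_ptr<OrtSession> session_decoder_;
};

struct MyModel_State : State {
  MyModel_State(MyModel_Model& model, RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params);
  RoamingArray<float> Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) override;
};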