Merge Search & State into new "Generator" type #39

Merged Jan 29, 2024 (7 commits)
27 changes: 12 additions & 15 deletions examples/llama.py
@@ -10,48 +10,45 @@
# model=og.Model("../../test_models/llama2-7b-fp32-cpu", device_type)
#model=og.Llama_Model("../../test_models/llama2-7b-fp16-gpu/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx", device_type)
#model=og.Llama_Model("../../test_models/llama2-7b-int4-gpu/rank_0_Llama-2-7b-hf_decoder_merged_model_int4.onnx", device_type)
model=og.Model("../../test_models/llama2-7b-chat-int4-gpu", device_type)
model=og.Model("../test_models/llama2-7b-chat-int4-gpu", device_type)
print("Model loaded")
# tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
tokenizer=model.CreateTokenizer()
tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
# tokenizer=model.CreateTokenizer()
print("Tokenizer created")

# Keep asking for input prompts in an loop
while True:
text = input("Input:")
# input_tokens = tokenizer.encode(text, return_tensors='np')
input_tokens = tokenizer.encode(text)
input_tokens = tokenizer.encode(text, return_tensors='np')
# input_tokens = tokenizer.encode(text)

params=og.SearchParams(model)
Member: could you please also change SearchParams to GeneratorParams?
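A minimal sketch of the rename being requested here, assuming the Python binding ends up exposing the new type as og.GeneratorParams (that name is an assumption; this diff still shows og.SearchParams):

    # Hypothetical rename per the review comment; og.GeneratorParams is assumed, not shown in this diff
    params = og.GeneratorParams(model)
    generator = og.Generator(model, params)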

params.max_length = 128
params.input_ids = input_tokens

search=params.CreateSearch()
state=model.CreateState(search.GetSequenceLengths(), params)
generator=og.Generator(model, params)

print("Output:")

print(text, end='', flush=True)
while not search.IsDone():
search.SetLogits(state.Run(search.GetSequenceLength(), search.GetNextTokens()))
while not generator.IsDone():
generator.ComputeLogits()

# search.Apply_MinLength(1)
# search.Apply_RepetitionPenalty(1.0)
Comment on lines 37 to 38
Member: could you please update this also


search.SampleTopP(0.7, 0.6)
generator.AppendNextToken_TopP(0.7, 0.6)
@yufenglee (Member), Jan 26, 2024: Append doesn't look like a good name to me. The generator actually generates the next token. How about GenerateNextToken?

Member: And can we make TopP a parameter instead of part of the name?

Member (Author):
I named it that since you first compute the logits, then append the next token based on the logits.
I initially had AddNextToken but 'Add' was less correct than 'Append' since it's appending tokens, not adding them.
Did you have a better name?

Member: It is a process of generating the next token. Add or Append doesn't reflect the action, I think.
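A rough sketch of how the example's sampling step might read if the suggestions above were adopted, i.e. a single GenerateNextToken call with top_k, top_p, and temperature passed as parameters. These names follow the review discussion only and are not part of this diff:

    # Hypothetical API shape per the review discussion; method and argument names are not confirmed by this PR
    while not generator.IsDone():
        generator.ComputeLogits()
        generator.GenerateNextToken(top_k=0, top_p=0.7, temperature=0.6)  # top_k=0 would mean "disabled", as in AppendNextToken_TopP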


print(tokenizer.decode([search.GetNextTokens().GetArray()[0]]), ' ', end='', flush=True)
'''
# print(tokenizer.decode([generator.GetNextTokens().GetArray()[0]]), ' ', end='', flush=True)
# Print each token as we compute it, we have to do some work to get newlines & spaces to appear properly:
word=tokenizer.convert_ids_to_tokens([search.GetNextTokens().GetArray()[0]])[0]
word=tokenizer.convert_ids_to_tokens([generator.GetNextTokens().GetArray()[0]])[0]
yufenglee marked this conversation as resolved.
if word=='<0x0A>':
word = '\n'
if word.startswith('▁'):
word = ' ' + word[1:]
print(word, end='', flush=True)
'''

# Print sequence all at once vs as it's decoded:
print(tokenizer.decode(search.GetSequence(0).GetArray()))
print(tokenizer.decode(generator.GetSequence(0).GetArray()))
print()
print()
2 changes: 1 addition & 1 deletion src/beam_search_scorer.cpp
@@ -65,7 +65,7 @@ void BeamHypotheses::Output(
}
}

BeamSearchScorer::BeamSearchScorer(const SearchParams& parameters)
BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters)
: batch_size_{parameters.batch_size},
num_beams_{parameters.num_beams},
max_length_{parameters.max_length},
2 changes: 1 addition & 1 deletion src/beam_search_scorer.h
@@ -32,7 +32,7 @@ struct BeamHypotheses {
};

struct BeamSearchScorer {
BeamSearchScorer(const SearchParams& parameters);
BeamSearchScorer(const GeneratorParams& parameters);

void Process(Sequences& sequences,
std::span<const float> next_scores,
3 changes: 2 additions & 1 deletion src/beam_search_scorer_cuda.cpp
@@ -1,11 +1,12 @@
#include "generators.h"
#include "search.h"
#include "search_cuda.h"
#include "beam_search_scorer_cuda.cuh"
#include "beam_search_scorer_cuda.h"

namespace Generators {

BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const SearchParams& parameters)
BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters)
: stream_{parameters.cuda_stream} {
state_cpu_ = CudaMallocHostArray<cuda::BeamScorerState>(1);
state_cpu_->batch_size_ = static_cast<size_t>(parameters.batch_size);
2 changes: 1 addition & 1 deletion src/beam_search_scorer_cuda.h
@@ -1,7 +1,7 @@
namespace Generators {

struct BeamSearchScorer_Cuda {
BeamSearchScorer_Cuda(const SearchParams& parameters);
BeamSearchScorer_Cuda(const GeneratorParams& parameters);

void Process(Sequences_Cuda& sequences,
std::span<const float> next_scores,
93 changes: 84 additions & 9 deletions src/generators.cpp
@@ -38,7 +38,7 @@ float Float16ToFloat32(uint16_t v) {
return std::ldexp((sign != 0 ? -1.0f : 1.0f) * (1.0f + static_cast<float>(fraction) / 1024.0f), exponent - 15);
}

SearchParams::SearchParams(const Model& model)
GeneratorParams::GeneratorParams(const Model& model)
: pad_token_id{model.config_->pad_token_id},
eos_token_id{model.config_->eos_token_id},
vocab_size{model.config_->model.vocab_size},
@@ -66,19 +66,94 @@ ProviderOptions GetDefaultProviderOptions([[maybe_unused]] DeviceType device_typ
return options;
}

std::unique_ptr<Search> SearchParams::CreateSearch() const {
std::unique_ptr<Generator> CreateGenerator(Model& model, const GeneratorParams& search_params) {
return std::make_unique<Generator>(model, search_params);
}

std::unique_ptr<Search> CreateSearch(const GeneratorParams& params) {
#if USE_CUDA
if (device_type == DeviceType::CUDA) {
if (num_beams > 1)
return std::make_unique<BeamSearch_Cuda>(*this);
return std::make_unique<GreedySearch_Cuda>(*this);
if (params.device_type == DeviceType::CUDA) {
if (params.num_beams > 1)
return std::make_unique<BeamSearch_Cuda>(params);
return std::make_unique<GreedySearch_Cuda>(params);
}
#endif

if (num_beams > 1) {
return std::make_unique<BeamSearch_Cpu>(*this);
if (params.num_beams > 1) {
return std::make_unique<BeamSearch_Cpu>(params);
}
return std::make_unique<GreedySearch_Cpu>(params);
}

Generator::Generator(Model& model, const GeneratorParams& search_params) : model_{model} {
search_ = CreateSearch(search_params);
state_ = model.CreateState(search_->GetSequenceLengths(), search_params);
}

void Generator::ComputeLogits() {
if (computed_logits_)
throw std::runtime_error("ComputeLogits called again without calling AppendNextToken* first");

search_->SetLogits(state_->Run(search_->GetSequenceLength(), search_->GetNextTokens(), search_->GetNextIndices()));
computed_logits_ = true;
}

bool Generator::IsDone() const {
if (computed_logits_)
throw std::runtime_error("IsDone() can't be called in the middle of processing logits");

return search_->IsDone();
}

void Generator::AppendNextToken_TopK_TopP(int top_k, float top_p, float temperature) {
if (search_->params_.num_beams != 1)
throw std::runtime_error("TopK and TopP cannot be used with a beam search");

if (!computed_logits_)
throw std::runtime_error("Must call ComputeLogits before AppendNextToken*");
computed_logits_ = false;

// TODO: Do TopK if top_k >1 then do TopP on the results
if (top_p < 1.0f) {
search_->SampleTopP(top_p, temperature);
} else if (top_k > 1) {
search_->SampleTopK(top_k, temperature);
} else {
search_->SelectTop();
}
}

void Generator::AppendNextToken() {
if (search_->params_.num_beams > 1) {
if (!computed_logits_)
throw std::runtime_error("Must call ComputeLogits before AppendNextToken*");
computed_logits_ = false;
search_->SelectTop();
return;
}

auto& config = *model_.config_;
AppendNextToken_TopK_TopP(config.top_k, config.top_p, config.temperature);
}

RoamingArray<int32_t> Generator::GetSequence(int index) {
return search_->GetSequence(index);
}

std::vector<int32_t> Generate(Model& model, const GeneratorParams& params) {
auto generator = CreateGenerator(model, params);

while (!generator->IsDone()) {
generator->ComputeLogits();
generator->AppendNextToken();
}
return std::make_unique<GreedySearch_Cpu>(*this);

auto results = generator->search_->GetSequence(0);
auto results_cpu = results.GetCPU();

std::vector<int32_t> v;
v.assign(results_cpu.begin(), results_cpu.end());
return v;
}

} // namespace Generators
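The Generator type introduced above bundles the former Search and State objects and uses the computed_logits_ flag to enforce a strict one-to-one alternation between ComputeLogits() and the AppendNextToken* calls, throwing if either is called out of order. A short sketch of that contract from the caller's side, assuming the Python bindings expose the methods exactly as used in examples/llama.py and that the C++ std::runtime_error surfaces as a Python exception (an assumption):

    generator = og.Generator(model, params)
    while not generator.IsDone():
        generator.ComputeLogits()                 # must run once per step before sampling
        # generator.ComputeLogits()               # a second call here would throw:
        #                                         # "ComputeLogits called again without calling AppendNextToken* first"
        generator.AppendNextToken_TopP(0.7, 0.6)  # consumes the logits and appends one token
    print(tokenizer.decode(generator.GetSequence(0).GetArray()))

The C++ Generate() helper added in this diff wraps the same loop with AppendNextToken(), which falls back to the top_k, top_p, and temperature values from the model's config.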
55 changes: 28 additions & 27 deletions src/generators.h
@@ -28,6 +28,8 @@

namespace Generators {
struct Model;
struct State;
struct Search;

// If we don't include cuda_runtime.h, we define this to avoid lots of extra #ifdefs
#ifndef USE_CUDA
@@ -52,33 +54,9 @@ using ProviderOptions = std::variant<

ProviderOptions GetDefaultProviderOptions(DeviceType device_type);

struct Search {
virtual ~Search() = default;

virtual RoamingArray<int32_t> GetNextTokens() = 0;
virtual RoamingArray<int32_t> GetNextIndices() {
throw std::runtime_error("GetNextIndices() can only be called for beam search, num_beams must be >1");
}
virtual RoamingArray<int32_t> GetSequenceLengths() = 0;
virtual int GetSequenceLength() const = 0;
virtual RoamingArray<int32_t> GetSequence(int index) = 0;

virtual void SetLogits(RoamingArray<float> logits) = 0;
virtual bool IsDone() const = 0;

// TODO: Beam Search only, this should be removed and made automatic
virtual void Finalize(size_t /*num_return_sequences*/, RoamingArray<int32_t> /*output*/, RoamingArray<float> /*sequence_scores*/) { assert(false); }

virtual void SelectTop() = 0;
virtual void SampleTopP(float /*p*/, float /*temperature*/) { assert(false); }
virtual void SampleTopK(int /*k*/, float /*temperature*/) { assert(false); }
};

struct SearchParams {
SearchParams() = default; // This constructor is only used if doing a custom model handler vs built-in
SearchParams(const Model& model);

std::unique_ptr<Search> CreateSearch() const;
struct GeneratorParams {
GeneratorParams() = default; // This constructor is only used if doing a custom model handler vs built-in
GeneratorParams(const Model& model);

// Values copied from config
int pad_token_id{};
@@ -124,6 +102,29 @@ struct SearchParams {
std::variant<Whisper> inputs;
};

struct Generator {
Generator(Model& model, const GeneratorParams& search_params);

bool IsDone() const;
void ComputeLogits();
void AppendNextToken_TopK_TopP(int top_k, float top_p, float temperature);
void AppendNextToken_TopP(float p, float temperature) { AppendNextToken_TopK_TopP(0, p, temperature); }
void AppendNextToken_TopK(int k, float temperature) { AppendNextToken_TopK_TopP(k, 1.0f, temperature); }
void AppendNextToken_Top() { AppendNextToken_TopK_TopP(1, 1.0f, 0.0f); }
void AppendNextToken();

RoamingArray<int32_t> GetSequence(int index);

Model& model_;
std::unique_ptr<State> state_;
std::unique_ptr<Search> search_;
bool computed_logits_{}; // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio
};

std::unique_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, const ProviderOptions* provider_options = nullptr);
std::unique_ptr<Generator> CreateGenerator(Model& model, const GeneratorParams& search_params);
std::vector<int32_t> Generate(Model& model, const GeneratorParams& params); // Uses CreateGenerator and a simple loop to return the entire sequence

float Float16ToFloat32(uint16_t v); // v is a IEEE 752-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
void top_k_indices(std::span<int32_t> top_k, std::span<const float> inputs);

4 changes: 2 additions & 2 deletions src/models/gpt.cpp
@@ -10,11 +10,11 @@ Gpt_Model::Gpt_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, const Prov
InitDeviceAllocator(*session_decoder_);
}

std::unique_ptr<State> Gpt_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const SearchParams& params) {
std::unique_ptr<State> Gpt_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) {
return std::make_unique<Gpt_State>(*this, sequence_lengths, params);
}

Gpt_State::Gpt_State(Gpt_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const SearchParams& search_params)
Gpt_State::Gpt_State(Gpt_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const GeneratorParams& search_params)
: State{search_params},
model_{model},
position_ids_{model, *this, sequence_lengths_unk} {
4 changes: 2 additions & 2 deletions src/models/gpt.h
@@ -10,13 +10,13 @@ namespace Generators {
struct Gpt_Model : Model {
Gpt_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, const ProviderOptions* provider_options);

std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const SearchParams& params) override;
std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) override;

std::unique_ptr<OrtSession> session_decoder_;
};

struct Gpt_State : State {
Gpt_State(Gpt_Model& model, RoamingArray<int32_t> sequence_lengths, const SearchParams& search_params);
Gpt_State(Gpt_Model& model, RoamingArray<int32_t> sequence_lengths, const GeneratorParams& search_params);
RoamingArray<float> Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) override;

private:
4 changes: 2 additions & 2 deletions src/models/llama.cpp
@@ -10,11 +10,11 @@ Llama_Model::Llama_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, const
InitDeviceAllocator(*session_decoder_);
}

std::unique_ptr<State> Llama_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const SearchParams& params) {
std::unique_ptr<State> Llama_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) {
return std::make_unique<Llama_State>(*this, sequence_lengths, params);
}

Llama_State::Llama_State(Llama_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const SearchParams& search_params)
Llama_State::Llama_State(Llama_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const GeneratorParams& search_params)
: State{search_params},
model_{model},
position_ids_{model, *this, sequence_lengths_unk} {
4 changes: 2 additions & 2 deletions src/models/llama.h
@@ -10,13 +10,13 @@ namespace Generators {
struct Llama_Model : Model {
Llama_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, const ProviderOptions* provider_options);

std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const SearchParams& params) override;
std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) override;

std::unique_ptr<OrtSession> session_decoder_;
};

struct Llama_State : State {
Llama_State(Llama_Model& model, RoamingArray<int32_t> sequence_lengths, const SearchParams& params);
Llama_State(Llama_Model& model, RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params);
RoamingArray<float> Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) override;

private:
4 changes: 2 additions & 2 deletions src/models/mistral.cpp
@@ -10,11 +10,11 @@ Mistral_Model::Mistral_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, co
InitDeviceAllocator(*session_decoder_);
}

std::unique_ptr<State> Mistral_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const SearchParams& params) {
std::unique_ptr<State> Mistral_Model::CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) {
return std::make_unique<Mistral_State>(*this, sequence_lengths, params);
}

Mistral_State::Mistral_State(Mistral_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const SearchParams& search_params)
Mistral_State::Mistral_State(Mistral_Model& model, RoamingArray<int32_t> sequence_lengths_unk, const GeneratorParams& search_params)
: State{search_params},
model_{model},
position_ids_{model, *this, sequence_lengths_unk} {
4 changes: 2 additions & 2 deletions src/models/mistral.h
@@ -10,13 +10,13 @@ namespace Generators {
struct Mistral_Model : Model {
Mistral_Model(std::unique_ptr<Config> config, OrtEnv& ort_env, const ProviderOptions* provider_options);

std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const SearchParams& params) override;
std::unique_ptr<State> CreateState(RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params) override;

std::unique_ptr<OrtSession> session_decoder_;
};

struct Mistral_State : State {
Mistral_State(Mistral_Model& model, RoamingArray<int32_t> sequence_lengths, const SearchParams& params);
Mistral_State(Mistral_Model& model, RoamingArray<int32_t> sequence_lengths, const GeneratorParams& params);
RoamingArray<float> Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) override;

private: