Rename structs
baijumeswani committed Dec 18, 2024
1 parent ae436ef, commit 5a8fae9
Showing 8 changed files with 63 additions and 63 deletions.
src/models/decoder_only.h (2 additions, 2 deletions)
@@ -31,8 +31,8 @@ struct DecoderOnly_State : State {

  DefaultInputIDs input_ids_{*this};
  Logits logits_{*this};
-  KeyValueCacheDefault kv_cache_{*this};
-  PositionInputsDefault position_inputs_;
+  DefaultKeyValueCache kv_cache_{*this};
+  DefaultPositionInputs position_inputs_;
  ExtraInputs extra_inputs_{*this};
};

src/models/gpt.h (2 additions, 2 deletions)
@@ -29,8 +29,8 @@ struct Gpt_State : State {

  DefaultInputIDs input_ids_{*this};
  Logits logits_{*this};
-  KeyValueCacheDefault_Combined kv_cache_{*this};
-  PositionInputsDefault position_inputs_;
+  CombinedKeyValueCache kv_cache_{*this};
+  DefaultPositionInputs position_inputs_;
  ExtraInputs extra_inputs_{*this};
};
}  // namespace Generators

src/models/kv_cache.cpp (19 additions, 19 deletions)
@@ -31,7 +31,7 @@ bool IsCacheNeeded(const Model& model) {

}  // namespace

-KeyValueCacheDefault_Combined::KeyValueCacheDefault_Combined(State& state)
+CombinedKeyValueCache::CombinedKeyValueCache(State& state)
    : state_{state},
      layer_count_{model_.config_->model.decoder.num_hidden_layers},
      shape_{2, state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 0, model_.config_->model.decoder.head_size} {
@@ -54,7 +54,7 @@ KeyValueCacheDefault_Combined::KeyValueCacheDefault_Combined(State& state)
  }
}

-void KeyValueCacheDefault_Combined::Add() {
+void CombinedKeyValueCache::Add() {
  input_index_ = state_.inputs_.size();
  output_index_ = state_.outputs_.size();

@@ -66,7 +66,7 @@ void KeyValueCacheDefault_Combined::Add() {
  }
}

-void KeyValueCacheDefault_Combined::Update(DeviceSpan<int32_t> beam_indices, int total_length) {
+void CombinedKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int total_length) {
  assert(state_.params_->search.num_beams == 1 || !beam_indices.empty());  // We require beam_indices if we're a beam search

  if (!is_first_update_) {
@@ -89,7 +89,7 @@ void KeyValueCacheDefault_Combined::Update(DeviceSpan<int32_t> beam_indices, int total_length) {
  is_first_update_ = false;
}

-void KeyValueCacheDefault_Combined::RewindTo(size_t index) {
+void CombinedKeyValueCache::RewindTo(size_t index) {
  if (shape_[3] <= static_cast<int>(index)) {
    throw std::runtime_error("Requested length of rewind is greater than the current length.");
  }
@@ -108,7 +108,7 @@ void KeyValueCacheDefault_Combined::RewindTo(size_t index) {
}

template <typename T>
-void KeyValueCacheDefault_Combined::RewindPastTensorsTo(size_t index) {
+void CombinedKeyValueCache::RewindPastTensorsTo(size_t index) {
  assert(index > 0 && shape_[3] >= static_cast<int64_t>(index));
  std::array<int64_t, 5> new_shape = shape_;
  new_shape[3] = static_cast<int>(index);
@@ -143,7 +143,7 @@ void KeyValueCacheDefault_Combined::RewindPastTensorsTo(size_t index) {

// Copy present state to past state reordered by the beam_indices
template <typename ScoreType>
-void KeyValueCacheDefault_Combined::PickPastState(DeviceSpan<int32_t> beam_indices_device, int index) {
+void CombinedKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices_device, int index) {
  std::span<const int32_t> beam_indices = beam_indices_device.CopyDeviceToCpu();
  auto block_size_per_beam = shape_[2] * shape_[3] * shape_[4];
  auto past_key_size = shape_[1] * block_size_per_beam;
@@ -184,15 +184,15 @@ void KeyValueCacheDefault_Combined::PickPastState(DeviceSpan<int32_t> beam_indices_device, int index) {
  pasts_[index] = std::move(past);
}

-void KeyValueCacheDefault_Combined::PickPastState(DeviceSpan<int32_t> beam_indices, int index) {
+void CombinedKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices, int index) {
  if (type_ == Ort::TypeToTensorType<float>) {
    PickPastState<float>(beam_indices, index);
  } else {
    PickPastState<Ort::Float16_t>(beam_indices, index);
  }
}

-KeyValueCacheDefault::KeyValueCacheDefault(State& state)
+DefaultKeyValueCache::DefaultKeyValueCache(State& state)
    : state_{state},
      layer_count_{model_.config_->model.decoder.num_hidden_layers},
      past_present_share_buffer_{state_.params_->search.past_present_share_buffer && (state_.params_->search.num_beams == 1 || model_.config_->model.type == "whisper")},
@@ -257,7 +257,7 @@ KeyValueCacheDefault::KeyValueCacheDefault(State& state)
  }
}

-void KeyValueCacheDefault::AddEncoder() {
+void DefaultKeyValueCache::AddEncoder() {
  // We don't set the input_index_ & output_index_ because the encoder step only runs once, there's no update

  for (int i = 0; i < layer_count_ * 2; ++i) {
@@ -266,7 +266,7 @@ void KeyValueCacheDefault::AddEncoder() {
  }
}

-void KeyValueCacheDefault::Add() {
+void DefaultKeyValueCache::Add() {
  input_index_ = state_.inputs_.size();
  output_index_ = state_.outputs_.size();

@@ -285,7 +285,7 @@ void KeyValueCacheDefault::Add() {
  }
}

-void KeyValueCacheDefault::Update(DeviceSpan<int32_t> beam_indices, int total_length) {
+void DefaultKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, int total_length) {
  // If we're sharing past & present buffers there is nothing to do here, so early exit
  if (past_present_share_buffer_)
    return;
@@ -310,7 +310,7 @@ void KeyValueCacheDefault::Update(DeviceSpan<int32_t> beam_indices, int total_length) {
  is_first_update_ = false;
}

-void KeyValueCacheDefault::RewindTo(size_t index) {
+void DefaultKeyValueCache::RewindTo(size_t index) {
  if (past_present_share_buffer_) {
    return;
  } else if (shape_[2] <= static_cast<int>(index)) {
@@ -331,7 +331,7 @@ void KeyValueCacheDefault::RewindTo(size_t index) {
}

template <typename T>
-void KeyValueCacheDefault::RewindPastTensorsTo(size_t index) {
+void DefaultKeyValueCache::RewindPastTensorsTo(size_t index) {
  assert(index > 0 && shape_[2] >= static_cast<int64_t>(index) && !past_present_share_buffer_);
  std::array<int64_t, 4> new_shape = shape_;
  new_shape[2] = static_cast<int>(index);
@@ -366,7 +366,7 @@ void KeyValueCacheDefault::RewindPastTensorsTo(size_t index) {

// Copy present state to past state reordered by the beam_indices
template <typename ScoreType>
-void KeyValueCacheDefault::PickPastState(DeviceSpan<int32_t> beam_indices_device, int index) {
+void DefaultKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices_device, int index) {
  std::span<int32_t> beam_indices = beam_indices_device.Span();
  auto block_size_per_beam = shape_[1] * shape_[2] * shape_[3];
  auto element_count = shape_[0] * block_size_per_beam;
@@ -398,15 +398,15 @@ void KeyValueCacheDefault::PickPastState(DeviceSpan<int32_t> beam_indices_device, int index) {
  pasts_[index] = std::move(past_value);
}

-void KeyValueCacheDefault::PickPastState(DeviceSpan<int32_t> beam_indices, int index) {
+void DefaultKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices, int index) {
  if (type_ == Ort::TypeToTensorType<float>) {
    PickPastState<float>(beam_indices, index);
  } else {
    PickPastState<Ort::Float16_t>(beam_indices, index);
  }
}

-Cross_Cache::Cross_Cache(State& state)
+CrossCache::CrossCache(State& state)
    : state_{state},
      layer_count_{model_.config_->model.decoder.num_hidden_layers},
      shape_{state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 1500, model_.config_->model.decoder.head_size} {
@@ -429,14 +429,14 @@ Cross_Cache::Cross_Cache(State& state)
  }
}

-void Cross_Cache::AddOutputs() {
+void CrossCache::AddOutputs() {
  for (int i = 0; i < layer_count_ * 2; ++i) {
    state_.outputs_.push_back(values_[i].get());
    state_.output_names_.push_back(output_name_strings_[i].c_str());
  }
}

-void Cross_Cache::AddInputs() {
+void CrossCache::AddInputs() {
  for (int i = 0; i < layer_count_ * 2; ++i) {
    state_.inputs_.push_back(values_[i].get());
    state_.input_names_.push_back(input_name_strings_[i].c_str());
@@ -680,7 +680,7 @@ std::unique_ptr<KeyValueCache> CreateKeyValueCache(State& state) {
  if (state.model_.config_->model.decoder.sliding_window) {
    return std::make_unique<WindowedKeyValueCache>(state);
  } else {
-    return std::make_unique<KeyValueCacheDefault>(state);
+    return std::make_unique<DefaultKeyValueCache>(state);
  }
}

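For reference, a minimal sketch (not part of this commit) of how the renamed types are reached through the KeyValueCache interface and the CreateKeyValueCache factory shown in the hunk above. The include path, the Generators:: qualification, and the RunDecodeStep wrapper are assumptions for illustration; only CreateKeyValueCache, DefaultKeyValueCache, WindowedKeyValueCache, Add, Update, and RewindTo appear in the diff, and the call sequence is abbreviated.

#include <memory>

#include "models/kv_cache.h"  // assumed include path

// Hypothetical wrapper showing the call surface of the renamed cache types.
void RunDecodeStep(Generators::State& state,
                   Generators::DeviceSpan<int32_t> beam_indices,
                   int total_length) {
  // The factory returns WindowedKeyValueCache when sliding_window is configured,
  // otherwise DefaultKeyValueCache (formerly KeyValueCacheDefault).
  std::unique_ptr<Generators::KeyValueCache> cache = Generators::CreateKeyValueCache(state);
  cache->Add();                               // register past/present tensors as session inputs/outputs
  cache->Update(beam_indices, total_length);  // copy present values into the past cache after a run
  cache->RewindTo(1);                         // rewind the cached sequence to the requested length
}
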
src/models/kv_cache.h (8 additions, 8 deletions)
@@ -12,12 +12,12 @@ struct KeyValueCache {
  virtual void RewindTo(size_t index) = 0;
};

-struct KeyValueCacheDefault_Combined : KeyValueCache {
-  KeyValueCacheDefault_Combined(State& state);
+struct CombinedKeyValueCache : KeyValueCache {
+  CombinedKeyValueCache(State& state);

  void Add() override;  // Add to state inputs/outputs
  void AddEncoder() override {
-    throw std::runtime_error("KeyValueCacheDefault_Combined does not support AddEncoder.");
+    throw std::runtime_error("CombinedKeyValueCache does not support AddEncoder.");
  };
  void Update(DeviceSpan<int32_t> beam_indices, int total_length) override;
  void RewindTo(size_t index) override;
@@ -45,8 +45,8 @@ struct KeyValueCacheDefault_Combined : KeyValueCache {
  std::vector<std::string> input_name_strings_, output_name_strings_;
};

-struct KeyValueCacheDefault : KeyValueCache {
-  KeyValueCacheDefault(State& state);
+struct DefaultKeyValueCache : KeyValueCache {
+  DefaultKeyValueCache(State& state);

  void AddEncoder() override;  // If model has an initial encoder step, this is used
  // Register input_ids as ORT session input.
@@ -81,9 +81,9 @@ struct KeyValueCacheDefault : KeyValueCache {
  std::vector<StaticBuffer*> sb_kv_caches_;
};

-// Very similar to the KeyValueCacheDefault, but is only created once at the encoder step, then used without modification for every decoder step
-struct Cross_Cache {
-  Cross_Cache(State& state);
+// Very similar to the DefaultKeyValueCache, but is only created once at the encoder step, then used without modification for every decoder step
+struct CrossCache {
+  CrossCache(State& state);

  void AddOutputs();
  void AddInputs();
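
The comment on CrossCache above describes its lifecycle; a rough sketch of that sequencing, inferred only from that comment and the AddOutputs/AddInputs bodies in kv_cache.cpp. The state object and the session-run placeholders are hypothetical, not library API.

// Sketch only: CrossCache is created once, at the encoder step, then reused unchanged.
Generators::CrossCache cross_cache{state};  // allocates the cross-attention K/V tensors
cross_cache.AddOutputs();                   // encoder run writes them via the state's outputs
// ... run the encoder session (placeholder) ...
cross_cache.AddInputs();                    // every decoder step reads the same tensors as inputs
// ... run decoder steps; CrossCache exposes no Update() or RewindTo() to call ...
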
src/models/multi_modal_vision_model.h (2 additions, 2 deletions)
@@ -89,8 +89,8 @@ struct DecoderState : State {
  const CapturedGraphInfo* captured_graph_info_;
  Embeddings inputs_embeds_{*this, Embeddings::Mode::Input,  // Model input
                            model_.config_->model.decoder.inputs.embeddings};
-  PositionInputsDefault position_inputs_;  // Model input
-  KeyValueCacheDefault kv_cache_{*this};   // Model input
+  DefaultPositionInputs position_inputs_;  // Model input
+  DefaultKeyValueCache kv_cache_{*this};   // Model input
  Logits logits_{*this};  // Model output
};
