Skip to content

Commit

Permalink
Fix vision model graph capture not creating static buffers for embedding (#942)
Browse files Browse the repository at this point in the history

This change essentially reverses the assignment of the embeddings
memory. Instead of creating the embeddings tensor in the embedding model
and pointing the embeddings of the text model to it, we now create the
embeddings tensor inside the text model and point the embeddings of the
embedding model to it.

The reason to do that is that the text model can possibly be in "graph
capture mode", which means that it allocates static buffers that it uses
between iterations, and even between generators. If we allocate the
memory in the embedding model and point the text model to it, the memory
will become invalid when the generator is destroyed and the captured
graph will exhibit undefined behavior (mostly spitting out garbage
output). But by pointing the embeddings output of the embedding model
towards the static buffer created by the text model, we can be certain
that the memory will stay alive for the lifetime of the model.

This PR doesn't change the behavior of the non-graph capture mode since
it really doesn't matter in that scenario whether the tensor is created
by the embedding model or the text model, but it fixes graph capture
usage for vision models.
  • Loading branch information
PatriceVignola authored Oct 7, 2024
1 parent c8e931d commit 77a88c3
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 30 deletions.
38 changes: 21 additions & 17 deletions src/models/embeddings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ Embeddings::Embeddings(State& state, Embeddings::Mode mode, const std::string& n
name_{name} {
// Embeddings are only transient inputs and outputs.
// They are never the user provided/requested model inputs/outputs
// So only create the transient output and reuse that ortvalue for subsequent
// So only create the transient input and reuse that ortvalue for previous
// steps in the pipeline.
if (mode == Embeddings::Mode::Output) {
if (mode == Embeddings::Mode::Input) {
if (state_.GetCapturedGraphInfo()) {
sb_embeddings_ = state_.GetCapturedGraphInfo()->sb_embeddings_.get();
}
Expand All @@ -30,44 +30,48 @@ Embeddings::Embeddings(State& state, Embeddings::Mode mode, const std::string& n
}

void Embeddings::Add() {
if (mode_ == Embeddings::Mode::Input) {
// In case the embeddings are input to a model, they are added
// as a nullptr to reserve a slot in the inputs. The embedding
// input will be overwritten when TransferState is invoked.
index_ = state_.inputs_.size();
state_.inputs_.push_back(nullptr);
state_.input_names_.push_back(name_.c_str());
} else {
if (mode_ == Embeddings::Mode::Output) {
// In case the embeddings are output of a model, they are added
// as a nullptr to reserve a slot in the outputs. The embedding
// output will be overwritten by the input of the following model
// when ReuseEmbeddingsBuffer is invoked. For example, if we have
// a pipeline that looks like EmbeddingModel -> TextModel, we
// create the embedding tensor in the TextModel as an input and
// simply reuse it in the EmbeddingModel as an output.
index_ = state_.outputs_.size();
state_.outputs_.push_back(embeddings_.get());
state_.outputs_.push_back(nullptr);
state_.output_names_.push_back(name_.c_str());
} else {
index_ = state_.inputs_.size();
state_.inputs_.push_back(embeddings_.get());
state_.input_names_.push_back(name_.c_str());
}
}

void Embeddings::UpdateSequenceLength() {
if (shape_[1] != 1) {
shape_[1] = 1;

if (mode_ == Embeddings::Mode::Output) {
if (mode_ == Embeddings::Mode::Input) {
if (!sb_embeddings_) {
embeddings_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
} else {
embeddings_ = sb_embeddings_->CreateTensorOnStaticBuffer(shape_, type_);
}

state_.outputs_[index_] = embeddings_.get();
state_.inputs_[index_] = embeddings_.get();
}
}
}

void Embeddings::ReuseEmbeddingsBuffer(const Embeddings& other) {
if (mode_ == Embeddings::Mode::Output ||
other.mode_ == Embeddings::Mode::Input) {
if (mode_ == Embeddings::Mode::Input ||
other.mode_ == Embeddings::Mode::Output) {
throw std::runtime_error("Incorrect usage of the embeddings inputs and outputs.");
}

// Share the output embeddings OrtValue* from other with the input embedding for this.
state_.inputs_[index_] = other.state_.outputs_[other.index_];
// Share the input embeddings OrtValue* from other with the output embedding for this.
state_.outputs_[index_] = other.state_.inputs_[other.index_];
}

} // namespace Generators
16 changes: 8 additions & 8 deletions src/models/multi_modal_vision_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,9 @@ std::unique_ptr<State> MultiModalVisionModel::CreateState(RoamingArray<int32_t>
return std::make_unique<MultiModalPipelineState>(*this, sequence_lengths, params);
}

EmbeddingState::EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info, const int64_t num_image_tokens)
EmbeddingState::EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const int64_t num_image_tokens)
: State{params, model},
model_{model},
captured_graph_info_{captured_graph_info},
num_image_tokens_{num_image_tokens} {
input_ids_.Add();
image_features_.Add();
Expand All @@ -92,7 +91,6 @@ EmbeddingState::EmbeddingState(const MultiModalVisionModel& model, const Generat
void EmbeddingState::UpdateInputsAndOutputs(RoamingArray<int32_t> next_tokens) {
input_ids_.Update(next_tokens);
image_features_.Update();
inputs_embeds_.UpdateSequenceLength();
}

RoamingArray<float> EmbeddingState::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) {
Expand Down Expand Up @@ -134,10 +132,11 @@ RoamingArray<float> DecoderState::Run(int current_length, RoamingArray<int32_t>
return logits_.Get();
}

void DecoderState::UpdateInputsOutputs(int current_length, RoamingArray<int32_t> beam_indices) {
void DecoderState::UpdateInputsAndOutputs(int current_length, RoamingArray<int32_t> beam_indices) {
position_inputs_.Update(current_length);
kv_cache_.Update(beam_indices.GetCPU(), current_length);
logits_.Update();
inputs_embeds_.UpdateSequenceLength();
}

MultiModalPipelineState::MultiModalPipelineState(const MultiModalVisionModel& model,
Expand All @@ -147,7 +146,7 @@ MultiModalPipelineState::MultiModalPipelineState(const MultiModalVisionModel& mo
model_{model},
num_image_tokens_{GetNumImageTokens(params_->extra_inputs, model_.config_->model.vision.inputs.pixel_values, model_.config_->model.vision.inputs.image_sizes)},
captured_graph_info_{model.GetCapturedGraphPool()->ReserveCapturedGraph(model, params)} {
embedding_state_ = std::make_unique<EmbeddingState>(model, params, nullptr, num_image_tokens_);
embedding_state_ = std::make_unique<EmbeddingState>(model, params, num_image_tokens_);
vision_state_ = std::make_unique<VisionState>(model_, params, num_image_tokens_);
decoder_state_ = std::make_unique<DecoderState>(model_, sequence_lengths_unk, params, captured_graph_info_.get());
}
Expand All @@ -167,9 +166,9 @@ RoamingArray<float> MultiModalPipelineState::Run(int current_length, RoamingArra
vision_state_->Run(current_length, next_tokens, next_indices);
}
embedding_state_->image_features_.ReuseImageFeaturesBuffer(vision_state_->image_features_);
embedding_state_->inputs_embeds_.ReuseEmbeddingsBuffer(decoder_state_->inputs_embeds_);
embedding_state_->Run(current_length, next_tokens, next_indices);

decoder_state_->inputs_embeds_.ReuseEmbeddingsBuffer(embedding_state_->inputs_embeds_);
auto logits = decoder_state_->Run(current_length, next_tokens, next_indices);

is_prompt_ = false;
Expand All @@ -179,10 +178,11 @@ RoamingArray<float> MultiModalPipelineState::Run(int current_length, RoamingArra
}

embedding_state_->UpdateInputsAndOutputs(next_tokens);
decoder_state_->UpdateInputsOutputs(current_length, next_indices);
decoder_state_->UpdateInputsAndOutputs(current_length, next_indices);

embedding_state_->inputs_embeds_.ReuseEmbeddingsBuffer(decoder_state_->inputs_embeds_);
embedding_state_->Run(current_length, next_tokens, next_indices);
decoder_state_->inputs_embeds_.ReuseEmbeddingsBuffer(embedding_state_->inputs_embeds_);

return decoder_state_->Run(current_length, next_tokens, next_indices);
}

Expand Down
7 changes: 2 additions & 5 deletions src/models/multi_modal_vision_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,19 @@ struct MultiModalVisionModel : Model {
};

struct EmbeddingState : State {
EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info, const int64_t num_image_tokens);
EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const int64_t num_image_tokens);
EmbeddingState(const EmbeddingState&) = delete;
EmbeddingState& operator=(const EmbeddingState&) = delete;

RoamingArray<float> Run(int current_length, RoamingArray<int32_t> next_tokens,
RoamingArray<int32_t> next_indices = {}) override;

const CapturedGraphInfo* GetCapturedGraphInfo() const override { return captured_graph_info_; };

private:
friend struct MultiModalPipelineState;

void UpdateInputsAndOutputs(RoamingArray<int32_t> next_tokens);

const MultiModalVisionModel& model_;
const CapturedGraphInfo* captured_graph_info_;
int64_t num_image_tokens_;

InputIDs input_ids_{*this}; // Model input
Expand Down Expand Up @@ -86,7 +83,7 @@ struct DecoderState : State {
private:
friend struct MultiModalPipelineState;

void UpdateInputsOutputs(int current_length, RoamingArray<int32_t> beam_indices);
void UpdateInputsAndOutputs(int current_length, RoamingArray<int32_t> beam_indices);

const MultiModalVisionModel& model_;
const CapturedGraphInfo* captured_graph_info_;
Expand Down

0 comments on commit 77a88c3

Please sign in to comment.