diff --git a/src/models/embeddings.cpp b/src/models/embeddings.cpp index baaaed345..b82f95d98 100644 --- a/src/models/embeddings.cpp +++ b/src/models/embeddings.cpp @@ -18,9 +18,9 @@ Embeddings::Embeddings(State& state, Embeddings::Mode mode, const std::string& n name_{name} { // Embeddings are only transient inputs and outputs. // They are never the user provided/requested model inputs/outputs - // So only create the transient output and reuse that ortvalue for subsequent + // So only create the transient input and reuse that ortvalue for previous // steps in the pipeline. - if (mode == Embeddings::Mode::Output) { + if (mode == Embeddings::Mode::Input) { if (state_.GetCapturedGraphInfo()) { sb_embeddings_ = state_.GetCapturedGraphInfo()->sb_embeddings_.get(); } @@ -30,17 +30,21 @@ Embeddings::Embeddings(State& state, Embeddings::Mode mode, const std::string& n } void Embeddings::Add() { - if (mode_ == Embeddings::Mode::Input) { - // In case the embeddings are input to a model, they are added - // as a nullptr to reserve a slot in the inputs. The embedding - // input will be overwritten when TransferState is invoked. - index_ = state_.inputs_.size(); - state_.inputs_.push_back(nullptr); - state_.input_names_.push_back(name_.c_str()); - } else { + if (mode_ == Embeddings::Mode::Output) { + // In case the embeddings are output of a model, they are added + // as a nullptr to reserve a slot in the outputs. The embedding + // output will be overwritten by the input of the following model + // when ReuseEmbeddingsBuffer is invoked. For example, if we have + // a pipeline that looks like EmbeddingModel -> TextModel, we + // create the embedding tensor in the TextModel as an input and + // simply reuse it in the EmbeddingModel as an output. index_ = state_.outputs_.size(); - state_.outputs_.push_back(embeddings_.get()); + state_.outputs_.push_back(nullptr); state_.output_names_.push_back(name_.c_str()); + } else { + index_ = state_.inputs_.size(); + state_.inputs_.push_back(embeddings_.get()); + state_.input_names_.push_back(name_.c_str()); } } @@ -48,26 +52,26 @@ void Embeddings::UpdateSequenceLength() { if (shape_[1] != 1) { shape_[1] = 1; - if (mode_ == Embeddings::Mode::Output) { + if (mode_ == Embeddings::Mode::Input) { if (!sb_embeddings_) { embeddings_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); } else { embeddings_ = sb_embeddings_->CreateTensorOnStaticBuffer(shape_, type_); } - state_.outputs_[index_] = embeddings_.get(); + state_.inputs_[index_] = embeddings_.get(); } } } void Embeddings::ReuseEmbeddingsBuffer(const Embeddings& other) { - if (mode_ == Embeddings::Mode::Output || - other.mode_ == Embeddings::Mode::Input) { + if (mode_ == Embeddings::Mode::Input || + other.mode_ == Embeddings::Mode::Output) { throw std::runtime_error("Incorrect usage of the embeddings inputs and outputs."); } - // Share the output embeddings OrtValue* from other with the input embedding for this. - state_.inputs_[index_] = other.state_.outputs_[other.index_]; + // Share the input embeddings OrtValue* from other with the output embedding for this. + state_.outputs_[index_] = other.state_.inputs_[other.index_]; } } // namespace Generators diff --git a/src/models/multi_modal_vision_model.cpp b/src/models/multi_modal_vision_model.cpp index 7bb95634a..dbcf5d4ed 100644 --- a/src/models/multi_modal_vision_model.cpp +++ b/src/models/multi_modal_vision_model.cpp @@ -79,10 +79,9 @@ std::unique_ptr MultiModalVisionModel::CreateState(RoamingArray return std::make_unique(*this, sequence_lengths, params); } -EmbeddingState::EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info, const int64_t num_image_tokens) +EmbeddingState::EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const int64_t num_image_tokens) : State{params, model}, model_{model}, - captured_graph_info_{captured_graph_info}, num_image_tokens_{num_image_tokens} { input_ids_.Add(); image_features_.Add(); @@ -92,7 +91,6 @@ EmbeddingState::EmbeddingState(const MultiModalVisionModel& model, const Generat void EmbeddingState::UpdateInputsAndOutputs(RoamingArray next_tokens) { input_ids_.Update(next_tokens); image_features_.Update(); - inputs_embeds_.UpdateSequenceLength(); } RoamingArray EmbeddingState::Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) { @@ -134,10 +132,11 @@ RoamingArray DecoderState::Run(int current_length, RoamingArray return logits_.Get(); } -void DecoderState::UpdateInputsOutputs(int current_length, RoamingArray beam_indices) { +void DecoderState::UpdateInputsAndOutputs(int current_length, RoamingArray beam_indices) { position_inputs_.Update(current_length); kv_cache_.Update(beam_indices.GetCPU(), current_length); logits_.Update(); + inputs_embeds_.UpdateSequenceLength(); } MultiModalPipelineState::MultiModalPipelineState(const MultiModalVisionModel& model, @@ -147,7 +146,7 @@ MultiModalPipelineState::MultiModalPipelineState(const MultiModalVisionModel& mo model_{model}, num_image_tokens_{GetNumImageTokens(params_->extra_inputs, model_.config_->model.vision.inputs.pixel_values, model_.config_->model.vision.inputs.image_sizes)}, captured_graph_info_{model.GetCapturedGraphPool()->ReserveCapturedGraph(model, params)} { - embedding_state_ = std::make_unique(model, params, nullptr, num_image_tokens_); + embedding_state_ = std::make_unique(model, params, num_image_tokens_); vision_state_ = std::make_unique(model_, params, num_image_tokens_); decoder_state_ = std::make_unique(model_, sequence_lengths_unk, params, captured_graph_info_.get()); } @@ -167,9 +166,9 @@ RoamingArray MultiModalPipelineState::Run(int current_length, RoamingArra vision_state_->Run(current_length, next_tokens, next_indices); } embedding_state_->image_features_.ReuseImageFeaturesBuffer(vision_state_->image_features_); + embedding_state_->inputs_embeds_.ReuseEmbeddingsBuffer(decoder_state_->inputs_embeds_); embedding_state_->Run(current_length, next_tokens, next_indices); - decoder_state_->inputs_embeds_.ReuseEmbeddingsBuffer(embedding_state_->inputs_embeds_); auto logits = decoder_state_->Run(current_length, next_tokens, next_indices); is_prompt_ = false; @@ -179,10 +178,11 @@ RoamingArray MultiModalPipelineState::Run(int current_length, RoamingArra } embedding_state_->UpdateInputsAndOutputs(next_tokens); - decoder_state_->UpdateInputsOutputs(current_length, next_indices); + decoder_state_->UpdateInputsAndOutputs(current_length, next_indices); + embedding_state_->inputs_embeds_.ReuseEmbeddingsBuffer(decoder_state_->inputs_embeds_); embedding_state_->Run(current_length, next_tokens, next_indices); - decoder_state_->inputs_embeds_.ReuseEmbeddingsBuffer(embedding_state_->inputs_embeds_); + return decoder_state_->Run(current_length, next_tokens, next_indices); } diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h index fef1a9f36..65527105f 100644 --- a/src/models/multi_modal_vision_model.h +++ b/src/models/multi_modal_vision_model.h @@ -27,22 +27,19 @@ struct MultiModalVisionModel : Model { }; struct EmbeddingState : State { - EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info, const int64_t num_image_tokens); + EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const int64_t num_image_tokens); EmbeddingState(const EmbeddingState&) = delete; EmbeddingState& operator=(const EmbeddingState&) = delete; RoamingArray Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices = {}) override; - const CapturedGraphInfo* GetCapturedGraphInfo() const override { return captured_graph_info_; }; - private: friend struct MultiModalPipelineState; void UpdateInputsAndOutputs(RoamingArray next_tokens); const MultiModalVisionModel& model_; - const CapturedGraphInfo* captured_graph_info_; int64_t num_image_tokens_; InputIDs input_ids_{*this}; // Model input @@ -86,7 +83,7 @@ struct DecoderState : State { private: friend struct MultiModalPipelineState; - void UpdateInputsOutputs(int current_length, RoamingArray beam_indices); + void UpdateInputsAndOutputs(int current_length, RoamingArray beam_indices); const MultiModalVisionModel& model_; const CapturedGraphInfo* captured_graph_info_;