-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add re-designed support for Phi-3 vision and Phi-3.5 vision (#882)
### Description This PR supports the re-designed export of Phi-3 vision and Phi-3.5 vision. The new design natively supports multi-image and the `select` logic inside the ONNX models. ### Motivation and Context With the re-designed export, some of the logic inside ONNX Runtime GenAI is no longer needed as it is now inside the ONNX model. This allows other models to more easily re-use the vision and embedding components within ONNX Runtime GenAI.
- Loading branch information
1 parent
ad6a02f
commit b49e3b1
Showing
19 changed files
with
861 additions
and
709 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
#include "../generators.h" | ||
#include "model.h" | ||
#include "image_features.h" | ||
|
||
namespace Generators { | ||
|
||
ImageFeatures::ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens) | ||
: model_{model}, | ||
state_{state}, | ||
shape_{num_image_tokens, state_.params_->hidden_size}, | ||
type_{mode == ImageFeatures::Mode::Input | ||
? model_.session_info_->GetInputDataType(name) | ||
: model_.session_info_->GetOutputDataType(name)}, | ||
mode_{mode}, | ||
name_{name} { | ||
// There are four cases for ImageFeatures: | ||
// 1) Created as an output for vision model (num_image_tokens > 0) | ||
// The tensor needs to be pre-allocated to store the output. | ||
// It will be transferred to an input for the embedding model. | ||
// 2) Created as an output for vision model (num_image_tokens = 0) | ||
// The tensor will be pre-allocated to store the empty output. | ||
// It will be transferred to an input for the embedding model. | ||
// 3) Created as an input for embedding model (num_image_tokens > 0) | ||
// The tensor does not need to be pre-allocated because it will be created during (1). | ||
// 4) Created as an input for embedding model (num_image_tokens = 0) | ||
// The tensor does not need to be pre-allocated because it will be created during (2). | ||
if (mode == ImageFeatures::Mode::Output) { | ||
image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); | ||
} | ||
} | ||
|
||
void ImageFeatures::Add() { | ||
if (mode_ == ImageFeatures::Mode::Input) { | ||
// In case the image_features are an input to a model, they are added | ||
// as a nullptr to reserve a slot in the inputs. The image_features | ||
// input will be overwritten when ReuseImageFeaturesBuffer is invoked. | ||
index_ = state_.inputs_.size(); | ||
state_.inputs_.push_back(nullptr); | ||
state_.input_names_.push_back(name_.c_str()); | ||
} else { | ||
index_ = state_.outputs_.size(); | ||
state_.outputs_.push_back(image_features_.get()); | ||
state_.output_names_.push_back(name_.c_str()); | ||
} | ||
} | ||
|
||
void ImageFeatures::Update() { | ||
// Initialize empty image_features tensor for after-prompt input scenarios | ||
// num_image_tokens will be 0 when no image is provided | ||
if (shape_[0] > 0) { // if num_image_tokens > 0 | ||
shape_[0] = 0; | ||
image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); | ||
state_.inputs_[index_] = image_features_.get(); | ||
} | ||
} | ||
|
||
void ImageFeatures::ReuseImageFeaturesBuffer(ImageFeatures& other) { | ||
if (mode_ == ImageFeatures::Mode::Output || other.mode_ == ImageFeatures::Mode::Input) { | ||
throw std::runtime_error("Incorrect usage of the ImageFeatures inputs and outputs."); | ||
} | ||
|
||
// Share the output ImageFeatures OrtValue* from other with the input ImageFeatures for this. | ||
image_features_ = std::move(other.image_features_); | ||
state_.inputs_[index_] = other.state_.outputs_[other.index_]; | ||
} | ||
|
||
} // namespace Generators |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
namespace Generators { | ||
|
||
struct ImageFeatures { | ||
enum struct Mode { | ||
Input = 0, | ||
Output | ||
}; | ||
|
||
ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens); | ||
ImageFeatures(const ImageFeatures&) = delete; | ||
ImageFeatures& operator=(const ImageFeatures&) = delete; | ||
|
||
void Add(); | ||
void Update(); | ||
void ReuseImageFeaturesBuffer(ImageFeatures& other); | ||
|
||
auto& GetShape() const { return shape_; } | ||
OrtValue* Get() { return image_features_.get(); } | ||
|
||
private: | ||
const Model& model_; | ||
State& state_; | ||
|
||
std::array<int64_t, 2> shape_{}; // [num_image_tokens, hidden_size] | ||
ONNXTensorElementDataType type_; | ||
|
||
const Mode mode_{}; | ||
const std::string name_; | ||
|
||
std::unique_ptr<OrtValue> image_features_; | ||
size_t index_{~0U}; | ||
}; | ||
|
||
} // namespace Generators |
Oops, something went wrong.