diff --git a/examples/python/phi3v.py b/examples/python/phi3v.py index df5dddcfb..fd92dbd93 100644 --- a/examples/python/phi3v.py +++ b/examples/python/phi3v.py @@ -8,7 +8,6 @@ import onnxruntime_genai as og - def _complete(text, state): return (glob.glob(text + "*") + [None])[state] @@ -29,9 +28,10 @@ def run(args: argparse.Namespace): "Image Path (comma separated; leave empty if no image): " ).split(",") ] + image_paths = [image_path for image_path in image_paths if len(image_path)] print(image_paths) - image = None + images = None prompt = "<|user|>\n" if len(image_paths) == 0: print("No image provided") diff --git a/src/config.cpp b/src/config.cpp index 00708faae..ec426e5ce 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -137,10 +137,6 @@ struct Inputs_Element : JSON::Element { v_.position_ids = value; } else if (name == "attention_mask") { v_.attention_mask = value; - } else if (name == "seqlens_k") { - v_.seqlens_k = value; - } else if (name == "total_seq_len") { - v_.total_sequence_length = value; } else if (name == "past_key_names") { v_.past_key_names = value; } else if (name == "past_value_names") { @@ -248,8 +244,8 @@ struct VisionOutputs_Element : JSON::Element { explicit VisionOutputs_Element(Config::Model::Vision::Outputs& v) : v_{v} {} void OnString(std::string_view name, std::string_view value) override { - if (name == "visual_features") { - v_.visual_features = value; + if (name == "image_features") { + v_.image_features = value; } else throw JSON::unknown_value_error{}; } @@ -312,6 +308,8 @@ struct EmbeddingInputs_Element : JSON::Element { void OnString(std::string_view name, std::string_view value) override { if (name == "input_ids") { v_.input_ids = value; + } else if (name == "image_features") { + v_.image_features = value; } else throw JSON::unknown_value_error{}; } diff --git a/src/config.h b/src/config.h index 7263dbda3..f511c0885 100644 --- a/src/config.h +++ b/src/config.h @@ -12,6 +12,7 @@ struct Config { static constexpr std::string_view InputIdsName = "input_ids"; static constexpr std::string_view PixelValuesName = "pixel_values"; static constexpr std::string_view ImageSizesName = "image_sizes"; + static constexpr std::string_view ImageFeaturesName = "image_features"; }; fs::path config_path; // Path of the config directory @@ -62,6 +63,7 @@ struct Config { struct Inputs { std::string input_ids{Defaults::InputIdsName}; + std::string image_features{Defaults::ImageFeaturesName}; } inputs; struct Outputs { @@ -78,7 +80,7 @@ struct Config { } inputs; struct Outputs { - std::string visual_features{"visual_features"}; + std::string image_features{Defaults::ImageFeaturesName}; } outputs; } vision; @@ -97,8 +99,6 @@ struct Config { std::string embeddings{"inputs_embeds"}; std::string position_ids{"position_ids"}; std::string attention_mask{"attention_mask"}; - std::string seqlens_k{"seqlens_k"}; - std::string total_sequence_length{"total_seq_len"}; std::string past_key_names{"past_key_values.%d.key"}, past_value_names{"past_key_values.%d.value"}; std::string past_names; // When key/value pairs are combined std::string cross_past_key_names, cross_past_value_names; diff --git a/src/models/image_features.cpp b/src/models/image_features.cpp new file mode 100644 index 000000000..b5693a10a --- /dev/null +++ b/src/models/image_features.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "../generators.h" +#include "model.h" +#include "image_features.h" + +namespace Generators { + +ImageFeatures::ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens) + : model_{model}, + state_{state}, + shape_{num_image_tokens, state_.params_->hidden_size}, + type_{mode == ImageFeatures::Mode::Input + ? model_.session_info_->GetInputDataType(name) + : model_.session_info_->GetOutputDataType(name)}, + mode_{mode}, + name_{name} { + // There are four cases for ImageFeatures: + // 1) Created as an output for vision model (num_image_tokens > 0) + // The tensor needs to be pre-allocated to store the output. + // It will be transferred to an input for the embedding model. + // 2) Created as an output for vision model (num_image_tokens = 0) + // The tensor will be pre-allocated to store the empty output. + // It will be transferred to an input for the embedding model. + // 3) Created as an input for embedding model (num_image_tokens > 0) + // The tensor does not need to be pre-allocated because it will be created during (1). + // 4) Created as an input for embedding model (num_image_tokens = 0) + // The tensor does not need to be pre-allocated because it will be created during (2). + if (mode == ImageFeatures::Mode::Output) { + image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + } +} + +void ImageFeatures::Add() { + if (mode_ == ImageFeatures::Mode::Input) { + // In case the image_features are an input to a model, they are added + // as a nullptr to reserve a slot in the inputs. The image_features + // input will be overwritten when ReuseImageFeaturesBuffer is invoked. + index_ = state_.inputs_.size(); + state_.inputs_.push_back(nullptr); + state_.input_names_.push_back(name_.c_str()); + } else { + index_ = state_.outputs_.size(); + state_.outputs_.push_back(image_features_.get()); + state_.output_names_.push_back(name_.c_str()); + } +} + +void ImageFeatures::Update() { + // Initialize empty image_features tensor for after-prompt input scenarios + // num_image_tokens will be 0 when no image is provided + if (shape_[0] > 0) { // if num_image_tokens > 0 + shape_[0] = 0; + image_features_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_); + state_.inputs_[index_] = image_features_.get(); + } +} + +void ImageFeatures::ReuseImageFeaturesBuffer(ImageFeatures& other) { + if (mode_ == ImageFeatures::Mode::Output || other.mode_ == ImageFeatures::Mode::Input) { + throw std::runtime_error("Incorrect usage of the ImageFeatures inputs and outputs."); + } + + // Share the output ImageFeatures OrtValue* from other with the input ImageFeatures for this. + image_features_ = std::move(other.image_features_); + state_.inputs_[index_] = other.state_.outputs_[other.index_]; +} + +} // namespace Generators diff --git a/src/models/image_features.h b/src/models/image_features.h new file mode 100644 index 000000000..f75d8dfc0 --- /dev/null +++ b/src/models/image_features.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace Generators { + +struct ImageFeatures { + enum struct Mode { + Input = 0, + Output + }; + + ImageFeatures(const Model& model, State& state, ImageFeatures::Mode mode, const std::string& name, int64_t num_image_tokens); + ImageFeatures(const ImageFeatures&) = delete; + ImageFeatures& operator=(const ImageFeatures&) = delete; + + void Add(); + void Update(); + void ReuseImageFeaturesBuffer(ImageFeatures& other); + + auto& GetShape() const { return shape_; } + OrtValue* Get() { return image_features_.get(); } + + private: + const Model& model_; + State& state_; + + std::array shape_{}; // [num_image_tokens, hidden_size] + ONNXTensorElementDataType type_; + + const Mode mode_{}; + const std::string name_; + + std::unique_ptr image_features_; + size_t index_{~0U}; +}; + +} // namespace Generators diff --git a/src/models/multi_modal_vision_model.cpp b/src/models/multi_modal_vision_model.cpp index ecfc01e6d..ddd5f2618 100644 --- a/src/models/multi_modal_vision_model.cpp +++ b/src/models/multi_modal_vision_model.cpp @@ -12,90 +12,16 @@ RoamingArray MakeDummy() { return RoamingArray(); } -#pragma warning(push) -#pragma warning(disable : 4189) // local variable is initialized but not referenced - -void Select(const Model& model, std::span input_ids, OrtValue* hidden_states, - OrtValue* visual_features, int32_t num_img_tokens, int32_t hidden_size, DeviceType device_type, - cudaStream_t cuda_stream) { - // Assme batch_size = 1 - constexpr int32_t min_input_id = -1000000000; - constexpr int64_t expected_batch_size = 1; - - // Find the position in the input_ids that corresponds to the start of the image tokens. - // Image tokens are represented by negative values in the input_ids. - const int64_t sequence_length = input_ids.size(); - int32_t image_position_start{}; - for (int64_t idx = 0; idx < sequence_length; ++idx) { - if (input_ids[idx] < 0 && input_ids[idx] > min_input_id) { - image_position_start = static_cast(idx); - break; - } - } - - // Replace the positions in the hidden_states tensor that correspond to the image tokens - // with the visual features tensor. - const int32_t start_pos = image_position_start * hidden_size; - const int32_t element_count = num_img_tokens * hidden_size; - const int32_t hidden_states_element_count = static_cast(sequence_length) * hidden_size; - - switch (device_type) { - case DeviceType::CPU: { - auto target = cpu_span(hidden_states->GetTensorMutableData(), hidden_states_element_count) - .subspan(start_pos, element_count); - auto source = cpu_span(visual_features->GetTensorMutableData(), element_count); - std::copy(source.begin(), source.end(), target.begin()); - break; - } -#if USE_CUDA - case DeviceType::CUDA: { - auto target = gpu_span(hidden_states->GetTensorMutableData(), hidden_states_element_count) - .subspan(start_pos, element_count); - auto source = gpu_span(visual_features->GetTensorMutableData(), element_count); - CudaCheck() == cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(), - cudaMemcpyDeviceToDevice, cuda_stream); - break; - } -#endif - -#if USE_DML - case DeviceType::DML: { - ComPtr source_resource; - Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model.allocator_device_, visual_features->GetTensorMutableRawData(), &source_resource)); - - ComPtr target_resource; - Ort::ThrowOnError(model.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model.allocator_device_, hidden_states->GetTensorMutableRawData(), &target_resource)); - - model.GetDmlExecutionContext()->CopyBufferRegion( - target_resource.Get(), - start_pos * sizeof(uint16_t), - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - source_resource.Get(), - 0, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - element_count * sizeof(uint16_t)); - - // Execute the cached command list - ComPtr fence; - uint64_t completion_value; - model.GetDmlExecutionContext()->ExecuteCommandList(nullptr, &fence, &completion_value); - break; - } -#endif - default: - throw std::runtime_error("Unsupported device type for Select."); - } -} - -#pragma warning(pop) - int64_t GetNumImageTokens(const std::vector& extra_inputs, + const std::string& pixel_values_name, const std::string& image_sizes_name) { + std::shared_ptr pixel_values; std::shared_ptr image_sizes; for (size_t i = 0; i < extra_inputs.size(); ++i) { - if (extra_inputs[i].name == image_sizes_name) { + if (extra_inputs[i].name == pixel_values_name) { + pixel_values = extra_inputs[i].tensor; + } else if (extra_inputs[i].name == image_sizes_name) { image_sizes = extra_inputs[i].tensor; - break; } } @@ -104,41 +30,26 @@ int64_t GetNumImageTokens(const std::vector& extra_input return 0; } - if (image_sizes->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape() != std::vector{1, 2}) { - throw std::runtime_error("image_sizes tensor must have 2 elements"); + auto image_sizes_shape = image_sizes->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape(); + auto num_images = pixel_values->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape()[0]; + if (image_sizes_shape != std::vector{num_images, 2}) { + std::string wrong_image_sizes_shape = "("; + for (int i = 0; i < image_sizes_shape.size(); i++) { + wrong_image_sizes_shape += std::to_string(image_sizes_shape[i]); + std::string eos = (i != image_sizes_shape.size() - 1) ? ", " : ")"; + wrong_image_sizes_shape += eos; + } + throw std::runtime_error("image_sizes tensor must be of shape (num_images, 2), got " + wrong_image_sizes_shape); } auto image_sizes_data = image_sizes->ort_tensor_->GetTensorMutableData(); - const int64_t h = image_sizes_data[0] / 336; - const int64_t w = image_sizes_data[1] / 336; - return ((h * w + 1) * 144) + 1 + ((h + 1) * 12); -} - -std::unique_ptr GetVisualFeatures(OrtAllocator& device_allocator, const SessionInfo& session_info, - const std::string& visual_features_name, int32_t hidden_size, - int64_t num_image_tokens) { - constexpr int32_t batch_size = 1; - if (!session_info.HasOutput(visual_features_name)) { - throw std::runtime_error("Visual features output not found in the model"); - } - - auto type = session_info.GetOutputDataType(visual_features_name); - - std::vector shape = {batch_size, num_image_tokens, hidden_size}; - std::unique_ptr visual_features; - - switch (type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - visual_features = OrtValue::CreateTensor(device_allocator, shape); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - visual_features = OrtValue::CreateTensor(device_allocator, shape); - break; - default: - throw std::runtime_error("Unsupported data type for visual features: " + std::to_string(type)); + int64_t num_image_tokens = 0; + for (int i = 0; i < num_images; i++) { + int64_t h = image_sizes_data[i * num_images] / 336; + int64_t w = image_sizes_data[i * num_images + 1] / 336; + num_image_tokens += ((h * w + 1) * 144) + 1 + ((h + 1) * 12); } - - return visual_features; + return num_image_tokens; } } // namespace @@ -165,16 +76,19 @@ std::unique_ptr MultiModalVisionModel::CreateState(RoamingArray return std::make_unique(*this, sequence_lengths, params); } -EmbeddingState::EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info) +EmbeddingState::EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info, const int64_t num_image_tokens) : State{params, model}, model_{model}, - captured_graph_info_{captured_graph_info} { + captured_graph_info_{captured_graph_info}, + num_image_tokens_{num_image_tokens} { input_ids_.Add(); + image_features_.Add(); inputs_embeds_.Add(); } void EmbeddingState::UpdateInputsAndOutputs(RoamingArray next_tokens) { input_ids_.Update(next_tokens); + image_features_.Update(); inputs_embeds_.UpdateSequenceLength(); } @@ -185,22 +99,17 @@ RoamingArray EmbeddingState::Run(int current_length, RoamingArray(GetNumImageTokens(params_->extra_inputs, model_.config_->model.vision.inputs.image_sizes)); - if (num_image_tokens_ > 0) { - visual_features_ = GetVisualFeatures(*model_.allocator_device_, *model_.session_info_, - model_.config_->model.vision.outputs.visual_features, - params_->hidden_size, num_image_tokens_); - output_names_.push_back(model_.config_->model.vision.outputs.visual_features.c_str()); - outputs_.push_back(visual_features_.get()); - } + image_features_.Add(); } RoamingArray VisionState::Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) { - State::Run(*model_.vision_session_, *model_.run_options_, 1); + const int num_images = static_cast(inputs_[0]->GetTensorTypeAndShapeInfo()->GetShape()[0]); + State::Run(*model_.vision_session_, *model_.run_options_, num_images); return MakeDummy(); } @@ -233,33 +142,29 @@ MultiModalPipelineState::MultiModalPipelineState(const MultiModalVisionModel& mo const GeneratorParams& params) : State{params, model}, model_{model}, - captured_graph_info_{model.GetCapturedGraphPool()->ReserveCapturedGraph(model, params)}, - embedding_state_{std::make_unique(model, params, captured_graph_info_.get())}, - vision_state_{std::make_unique(model_, params)}, - decoder_state_{std::make_unique(model_, sequence_lengths_unk, params, captured_graph_info_.get())} { + num_image_tokens_{GetNumImageTokens(params_->extra_inputs, model_.config_->model.vision.inputs.pixel_values, model_.config_->model.vision.inputs.image_sizes)}, + captured_graph_info_{model.GetCapturedGraphPool()->ReserveCapturedGraph(model, params)} { + embedding_state_ = std::make_unique(model, params, captured_graph_info_.get(), num_image_tokens_); + vision_state_ = std::make_unique(model_, params, num_image_tokens_); + decoder_state_ = std::make_unique(model_, sequence_lengths_unk, params, captured_graph_info_.get()); } RoamingArray MultiModalPipelineState::Run(int current_length, RoamingArray next_tokens, RoamingArray next_indices) { // Pipeline state defines the pipeline of the execution of the models // Prompt stage: - // - input_ids -> |embeddings_model| -> |inputs_embeds| - // - pixel_values, img_sizes -> |vision_model| -> |inputs_embeds| - // - inputs_embeds, visual_features -> |Select| -> |inputs_embeds| - // - inputs_embeds -> |decoder_model| -> |logits| + // - pixel_values, image_sizes -> |vision_model| -> image_features + // - input_ids, image_features -> |embeddings_model| -> inputs_embeds + // - inputs_embeds -> |decoder_model| -> logits // Generation stage: - // - input_ids -> |embeddings_model| -> |inputs_embeds| - // - inputs_embeds -> |decoder_model| -> |logits| + // - input_ids, image_features -> |embeddings_model| -> inputs_embeds + // - inputs_embeds -> |decoder_model| -> logits if (is_prompt_) { - embedding_state_->Run(current_length, next_tokens, next_indices); - if (vision_state_->num_image_tokens_ > 0) { + if (num_image_tokens_ > 0) { vision_state_->Run(current_length, next_tokens, next_indices); - - // Run the select logic - Select(model_, params_->input_ids, embedding_state_->inputs_embeds_.Get(), - vision_state_->visual_features_.get(), vision_state_->num_image_tokens_, - params_->hidden_size, params_->device_type, params_->cuda_stream); } + embedding_state_->image_features_.ReuseImageFeaturesBuffer(vision_state_->image_features_); + embedding_state_->Run(current_length, next_tokens, next_indices); decoder_state_->inputs_embeds_.ReuseEmbeddingsBuffer(embedding_state_->inputs_embeds_); auto logits = decoder_state_->Run(current_length, next_tokens, next_indices); diff --git a/src/models/multi_modal_vision_model.h b/src/models/multi_modal_vision_model.h index 9b3e62646..34111c8b3 100644 --- a/src/models/multi_modal_vision_model.h +++ b/src/models/multi_modal_vision_model.h @@ -4,6 +4,7 @@ #pragma once #include "model.h" #include "input_ids.h" +#include "image_features.h" #include "embeddings.h" #include "extra_inputs.h" #include "logits.h" @@ -20,13 +21,13 @@ struct MultiModalVisionModel : Model { std::unique_ptr CreateState(RoamingArray sequence_lengths, const GeneratorParams& params) const override; - std::unique_ptr embedding_session_; // input_ids -> inputs_embeds - std::unique_ptr vision_session_; // pixel_values, img_sizes -> visual_features + std::unique_ptr vision_session_; // pixel_values, image_sizes -> image_features + std::unique_ptr embedding_session_; // input_ids, image_features -> inputs_embeds std::unique_ptr decoder_session_; // inputs_embeds, attention_mask, kv_cache -> logits }; struct EmbeddingState : State { - EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info); + EmbeddingState(const MultiModalVisionModel& model, const GeneratorParams& params, const CapturedGraphInfo* captured_graph_info, const int64_t num_image_tokens); EmbeddingState(const EmbeddingState&) = delete; EmbeddingState& operator=(const EmbeddingState&) = delete; @@ -42,13 +43,18 @@ struct EmbeddingState : State { const MultiModalVisionModel& model_; const CapturedGraphInfo* captured_graph_info_; - InputIDs input_ids_{model_, *this}; // Model input + int64_t num_image_tokens_; + + InputIDs input_ids_{model_, *this}; // Model input + ImageFeatures image_features_{model_, *this, ImageFeatures::Mode::Input, // Optional model input + model_.config_->model.embedding.inputs.image_features, + num_image_tokens_}; Embeddings inputs_embeds_{model_, *this, Embeddings::Mode::Output, // Model output model_.config_->model.embedding.outputs.embeddings}; }; struct VisionState : State { - VisionState(const MultiModalVisionModel& model, const GeneratorParams& params); + VisionState(const MultiModalVisionModel& model, const GeneratorParams& params, const int64_t num_image_tokens); VisionState(const VisionState&) = delete; VisionState& operator=(const VisionState&) = delete; @@ -59,9 +65,11 @@ struct VisionState : State { friend struct MultiModalPipelineState; const MultiModalVisionModel& model_; - ExtraInputs extra_inputs_{model_, *this}; // Model inputs - std::unique_ptr visual_features_; // Model output - int32_t num_image_tokens_{}; + int64_t num_image_tokens_; + ExtraInputs extra_inputs_{model_, *this}; // Model inputs + ImageFeatures image_features_{model_, *this, ImageFeatures::Mode::Output, // Model output + model_.config_->model.vision.outputs.image_features, + num_image_tokens_}; }; struct DecoderState : State { @@ -103,6 +111,7 @@ struct MultiModalPipelineState : State { int current_length); const MultiModalVisionModel& model_; + int64_t num_image_tokens_{0}; const CapturedGraphInfoPtr captured_graph_info_; std::unique_ptr embedding_state_; std::unique_ptr vision_state_; diff --git a/src/models/prompt_image_processor.cpp b/src/models/prompt_image_processor.cpp index 8b791af1e..a56a66362 100644 --- a/src/models/prompt_image_processor.cpp +++ b/src/models/prompt_image_processor.cpp @@ -15,7 +15,7 @@ std::unique_ptr ProcessImagePrompt(const Generators::Tokenizer& tokeni const size_t num_images = num_img_tokens ? num_img_tokens->NumberOfElement() : 0U; auto* num_img_tokens_data = num_img_tokens ? num_img_tokens->Data() : nullptr; - // Split the prompt string based on the occurrences of the pattern "<|image_|>" + // Split the prompt string based on the occurrences of the pattern "<|image_|>" // Here the represents the image id. const std::regex pattern("<\\|image_\\d+\\|>"); const std::vector prompt_chunks( diff --git a/test/test_models/create_dummy_model.py b/test/test_models/create_dummy_model.py new file mode 100644 index 000000000..117354115 --- /dev/null +++ b/test/test_models/create_dummy_model.py @@ -0,0 +1,130 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +""" +Create dummy ONNX models that contain only inputs and outputs. +This is helpful for creating ONNX models to run simple API tests (e.g. pre-processing) +where the contents of the ONNX models don't matter. + +Example usage: +1) python3 create_dummy_model.py \ + --inputs "pixel_values; TensorProto.FLOAT16; ['num_images', 'max_num_crops', 3, 'height', 'width']" "image_sizes; TensorProto.INT64; ['num_images', 2]" \ + --outputs "image_features; TensorProto.FLOAT16; ['num_image_tokens', 3072]" \ + --filename "dummy_vision.onnx" +2) python3 create_dummy_model.py \ + --inputs "input_ids; TensorProto.INT64; ['batch_size', 'sequence_length']" "image_features; TensorProto.FLOAT16; ['num_image_tokens', 3072]" \ + --outputs "inputs_embeds; TensorProto.FLOAT; ['batch_size', 'sequence_length', 3072]" \ + --filename "dummy_embedding.onnx" +3) python3 create_dummy_model.py \ + --inputs "inputs_embeds; TensorProto.FLOAT; ['batch_size', 'sequence_length', 3072]" "attention_mask; TensorProto.INT64; ['batch_size', 'total_sequence_length']" "past_key_values.0.key; TensorProto.FLOAT; ['batch_size', 32, 'past_sequence_length', 96]" "past_key_values.0.value; TensorProto.FLOAT; ['batch_size', 32, 'past_sequence_length', 96]" \ + --outputs "logits; TensorProto.FLOAT; ['batch_size', 'sequence_length', 32064]" "present.0.key; TensorProto.FLOAT; ['batch_size', 32, 'total_sequence_length', 96]" "present.0.value; TensorProto.FLOAT; ['batch_size', 32, 'total_sequence_length', 96]" \ + --filename "dummy_text.onnx" +""" + +import argparse +import numpy as np +import onnx +from onnx import helper, numpy_helper, TensorProto + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--inputs", + metavar="(NAME; DTYPE; SHAPE)", + nargs='+', + help="Inputs of the form '(input_name; input_dtype; input_shape)' for model" + ) + parser.add_argument( + "-o", + "--outputs", + metavar="(NAME; DTYPE; SHAPE)", + nargs='+', + help="Outputs of the form '(output_name; output_dtype; output_shape)' for model" + ) + parser.add_argument( + "-f", + "--filename", + type=str, + help="Filename to save dummy model as", + ) + + args = parser.parse_args() + return args + +def parse_args(input_or_output): + list_of_inputs_or_outputs = [] + for input_str in input_or_output: + input_or_output_to_add = input_str.split("; ") + input_or_output_to_add = [elm.strip() for elm in input_or_output_to_add] + list_of_inputs_or_outputs.append(input_or_output_to_add) + return list_of_inputs_or_outputs + +def get_input_or_output_value_infos(input_or_outputs): + value_infos = [] + for input_or_output in input_or_outputs: + print(input_or_output) + name, dtype, shape = input_or_output[0], eval(input_or_output[1]), eval(input_or_output[2]) + value_info = helper.make_tensor_value_info(name, dtype, shape) + value_infos.append(value_info) + return value_infos + +def get_dummy_tensor_shape(shape): + np_shape = () + for dim in shape: + if type(dim) == str: + np_shape += (2,) + elif type(dim) == int: + np_shape += (dim,) + else: + raise NotImplementedError(f"Unknown dim type: {type(dim)}") + return np_shape + +def get_output_initializers(outputs): + initializers = [] + for output in outputs: + name, dtype, shape = output[0], eval(output[1]), eval(output[2]) + np_shape = get_dummy_tensor_shape(shape) + np_dtype = to_numpy_dtype[dtype] + tensor = numpy_helper.from_array(np.zeros(np_shape, dtype=np_dtype)) + tensor.name = name + initializers.append(tensor) + return initializers + +def main(): + args = get_args() + args.inputs = parse_args(args.inputs) + args.outputs = parse_args(args.outputs) + + # Create dummy model + model = helper.make_model( + opset_imports=[helper.make_operatorsetid('', 14)], + ir_version=7, + producer_name="onnxruntime-genai", + producer_version="0.0.0", + graph=helper.make_graph( + name="main_graph", + inputs=get_input_or_output_value_infos(args.inputs), + outputs=get_input_or_output_value_infos(args.outputs), + initializer=get_output_initializers(args.outputs), + value_info=[], + nodes=[], + ) + ) + onnx.save_model( + model, + args.filename, + ) + +if __name__ == "__main__": + # Map TensorProto dtypes to NumPy dtypes + to_numpy_dtype = { + TensorProto.INT8: np.uint8, + TensorProto.INT32: np.int32, + TensorProto.INT64: np.int64, + TensorProto.FLOAT16: np.float16, + TensorProto.FLOAT: np.float32, + } + main() diff --git a/test/test_models/images/10809054.jpg b/test/test_models/images/10809054.jpg new file mode 100644 index 000000000..117ca64be Binary files /dev/null and b/test/test_models/images/10809054.jpg differ diff --git a/test/test_models/vision-preprocessing/dummy.onnx b/test/test_models/vision-preprocessing/dummy.onnx deleted file mode 100644 index c47a93275..000000000 Binary files a/test/test_models/vision-preprocessing/dummy.onnx and /dev/null differ diff --git a/test/test_models/vision-preprocessing/dummy_embedding.onnx b/test/test_models/vision-preprocessing/dummy_embedding.onnx new file mode 100644 index 000000000..9059638bd Binary files /dev/null and b/test/test_models/vision-preprocessing/dummy_embedding.onnx differ diff --git a/test/test_models/vision-preprocessing/dummy_text.onnx b/test/test_models/vision-preprocessing/dummy_text.onnx new file mode 100644 index 000000000..8b731bf5c Binary files /dev/null and b/test/test_models/vision-preprocessing/dummy_text.onnx differ diff --git a/test/test_models/vision-preprocessing/dummy_vision.onnx b/test/test_models/vision-preprocessing/dummy_vision.onnx new file mode 100644 index 000000000..981bb31be Binary files /dev/null and b/test/test_models/vision-preprocessing/dummy_vision.onnx differ diff --git a/test/test_models/vision-preprocessing/genai_config.json b/test/test_models/vision-preprocessing/genai_config.json index bb5852720..f05d77ea0 100644 --- a/test/test_models/vision-preprocessing/genai_config.json +++ b/test/test_models/vision-preprocessing/genai_config.json @@ -1,73 +1,69 @@ -{ - "model": { - "bos_token_id": 1, - "context_length": 131072, - "decoder": { - "session_options": { - "log_id": "onnxruntime-genai", - "provider_options": [] - }, - "filename": "dummy.onnx", - "head_size": 96, - "hidden_size": 3072, - "inputs": { - "inputs_embeds": "inputs_embeds", - "attention_mask": "attention_mask", - "past_key_names": "past_key_values.%d.key", - "past_value_names": "past_key_values.%d.value" - }, - "outputs": { - "logits": "logits", - "present_key_names": "present.%d.key", - "present_value_names": "present.%d.value" - }, - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 32 - }, - "embedding": { - "filename": "dummy.onnx", - "inputs": { - "input_ids": "input_ids" - }, - "outputs": { - "inputs_embeds": "inputs_embeds" - } - }, - "vision": { - "filename": "dummy.onnx", - "inputs": { - "pixel_values": "past_0", - "image_sizes": "image_sizes" - }, - "outputs": { - "visual_features": "visual_features" - } - }, - "eos_token_id": [ - 2, - 32000, - 32001, - 32007 - ], - "pad_token_id": 32000, - "type": "phi3v", - "vocab_size": 32064 - }, - "search": { - "diversity_penalty": 0.0, - "do_sample": false, - "early_stopping": true, - "length_penalty": 1.0, - "max_length": 131072, - "min_length": 0, - "no_repeat_ngram_size": 0, - "num_beams": 1, - "num_return_sequences": 1, - "past_present_share_buffer": true, - "repetition_penalty": 1.0, - "temperature": 1.0, - "top_k": 1, - "top_p": 1.0 - } +{ + "model": { + "bos_token_id": 1, + "context_length": 131072, + "decoder": { + "session_options": { + "log_id": "onnxruntime-genai", + "provider_options": [] + }, + "filename": "dummy_text.onnx", + "head_size": 96, + "hidden_size": 3072, + "inputs": { + "inputs_embeds": "inputs_embeds", + "attention_mask": "attention_mask", + "past_key_names": "past_key_values.%d.key", + "past_value_names": "past_key_values.%d.value" + }, + "outputs": { + "logits": "logits", + "present_key_names": "present.%d.key", + "present_value_names": "present.%d.value" + }, + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 32 + }, + "embedding": { + "filename": "dummy_embedding.onnx", + "inputs": { + "input_ids": "input_ids", + "image_features": "image_features" + }, + "outputs": { + "inputs_embeds": "inputs_embeds" + } + }, + "vision": { + "filename": "dummy_vision.onnx", + "inputs": { + "pixel_values": "pixel_values", + "image_sizes": "image_sizes" + }, + "outputs": { + "image_features": "image_features" + } + }, + "eos_token_id": 32007, + "pad_token_id": 32000, + "type": "phi3v", + "vocab_size": 32064 + }, + "search": { + "diversity_penalty": 0.0, + "do_sample": false, + "early_stopping": true, + "length_penalty": 1.0, + "max_length": 131072, + "min_length": 0, + "no_repeat_ngram_size": 0, + "num_beams": 1, + "num_return_sequences": 1, + "past_present_share_buffer": true, + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_k": 1, + "top_p": 1.0 + } } \ No newline at end of file diff --git a/test/test_models/vision-preprocessing/processor_config.json b/test/test_models/vision-preprocessing/processor_config.json index c217a6003..22adf04d8 100644 --- a/test/test_models/vision-preprocessing/processor_config.json +++ b/test/test_models/vision-preprocessing/processor_config.json @@ -25,7 +25,7 @@ "domain": "com.microsoft.extensions", "type": "Phi3ImageTransform", "attrs": { - "num_crops": 16, + "num_crops": 4, "num_img_tokens": 144 } } diff --git a/test/test_models/vision-preprocessing/special_tokens_map.json b/test/test_models/vision-preprocessing/special_tokens_map.json index f3809994e..0616f20a3 100644 --- a/test/test_models/vision-preprocessing/special_tokens_map.json +++ b/test/test_models/vision-preprocessing/special_tokens_map.json @@ -1,36 +1,36 @@ -{ - "additional_special_tokens": [ - "<|system|>", - "<|end|>", - "<|user|>", - "<|end|>" - ], - "bos_token": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - }, - "eos_token": { - "content": "<|endoftext|>", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - }, - "pad_token": { - "content": "<|endoftext|>", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - }, - "unk_token": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - } -} +{ + "additional_special_tokens": [ + "<|system|>", + "<|end|>", + "<|user|>", + "<|end|>" + ], + "bos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "<|endoftext|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "<|endoftext|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/test/test_models/vision-preprocessing/tokenizer.json b/test/test_models/vision-preprocessing/tokenizer.json index 93333dd68..7f2088c90 100644 --- a/test/test_models/vision-preprocessing/tokenizer.json +++ b/test/test_models/vision-preprocessing/tokenizer.json @@ -131,7 +131,7 @@ }, { "id": 32011, - "content": "<|step|>", + "content": "<|placeholder7|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -140,7 +140,7 @@ }, { "id": 32012, - "content": "<|function_output|>", + "content": "<|placeholder8|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -149,7 +149,7 @@ }, { "id": 32013, - "content": "<|tag|>", + "content": "<|placeholder9|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -158,7 +158,7 @@ }, { "id": 32014, - "content": "<|function_call|>", + "content": "<|placeholder10|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -167,7 +167,7 @@ }, { "id": 32015, - "content": "<|raw|>", + "content": "<|placeholder11|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -176,7 +176,7 @@ }, { "id": 32016, - "content": "<|continue|>", + "content": "<|placeholder12|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -185,7 +185,7 @@ }, { "id": 32017, - "content": "<|function_list|>", + "content": "<|placeholder13|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -194,7 +194,7 @@ }, { "id": 32018, - "content": "<|calc|>", + "content": "<|placeholder14|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -203,7 +203,7 @@ }, { "id": 32019, - "content": "<|code|>", + "content": "<|placeholder15|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -212,7 +212,7 @@ }, { "id": 32020, - "content": "<|/code|>", + "content": "<|placeholder16|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -221,7 +221,7 @@ }, { "id": 32021, - "content": "<|summary|>", + "content": "<|placeholder17|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -230,7 +230,7 @@ }, { "id": 32022, - "content": "<|resource|>", + "content": "<|placeholder18|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -239,7 +239,7 @@ }, { "id": 32023, - "content": "<|assistant_mask|>", + "content": "<|placeholder19|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -248,7 +248,7 @@ }, { "id": 32024, - "content": "<|start|>", + "content": "<|placeholder20|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -257,7 +257,7 @@ }, { "id": 32025, - "content": "<|message|>", + "content": "<|placeholder21|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -266,7 +266,7 @@ }, { "id": 32026, - "content": "<|fim_prefix|>", + "content": "<|placeholder22|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -275,7 +275,7 @@ }, { "id": 32027, - "content": "<|fim_middle|>", + "content": "<|placeholder23|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -284,7 +284,7 @@ }, { "id": 32028, - "content": "<|fim_suffix|>", + "content": "<|placeholder24|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -293,7 +293,7 @@ }, { "id": 32029, - "content": "<|meta_start|>", + "content": "<|placeholder25|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -302,7 +302,7 @@ }, { "id": 32030, - "content": "<|ipynb_marker|>", + "content": "<|placeholder26|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -311,7 +311,7 @@ }, { "id": 32031, - "content": "<|diff_marker|>", + "content": "<|placeholder27|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -320,7 +320,7 @@ }, { "id": 32032, - "content": "<|ghissue|>", + "content": "<|placeholder28|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -329,7 +329,7 @@ }, { "id": 32033, - "content": "<|ghreview|>", + "content": "<|placeholder29|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -338,7 +338,7 @@ }, { "id": 32034, - "content": "<|disc_start|>", + "content": "<|placeholder30|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -347,7 +347,7 @@ }, { "id": 32035, - "content": "<|disc_sep|>", + "content": "<|placeholder31|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -356,7 +356,7 @@ }, { "id": 32036, - "content": "<|disc_thread|><|query|>", + "content": "<|placeholder32|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -365,7 +365,7 @@ }, { "id": 32037, - "content": "<|/query|>", + "content": "<|placeholder33|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -374,7 +374,7 @@ }, { "id": 32038, - "content": "<|data|>", + "content": "<|placeholder34|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -383,7 +383,7 @@ }, { "id": 32039, - "content": "<|/data|>", + "content": "<|placeholder35|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -392,7 +392,7 @@ }, { "id": 32040, - "content": "<|sys|>", + "content": "<|placeholder36|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -401,7 +401,7 @@ }, { "id": 32041, - "content": "<|/sys|>", + "content": "<|placeholder37|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -410,7 +410,7 @@ }, { "id": 32042, - "content": "<|inst|>", + "content": "<|placeholder38|>", "single_word": false, "lstrip": false, "rstrip": true, @@ -419,7 +419,7 @@ }, { "id": 32043, - "content": "<|/inst|>", + "content": "<|placeholder39|>", "single_word": false, "lstrip": false, "rstrip": true, diff --git a/test/test_models/vision-preprocessing/tokenizer_config.json b/test/test_models/vision-preprocessing/tokenizer_config.json index 995bec9e0..bda28a156 100644 --- a/test/test_models/vision-preprocessing/tokenizer_config.json +++ b/test/test_models/vision-preprocessing/tokenizer_config.json @@ -1,407 +1,413 @@ -{ - "add_bos_token": true, - "add_eos_token": false, - "added_tokens_decoder": { - "0": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false, - "special": true - }, - "1": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false, - "special": true - }, - "2": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": false - }, - "32000": { - "content": "<|endoftext|>", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false, - "special": true - }, - "32001": { - "content": "<|assistant|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32002": { - "content": "<|placeholder1|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32003": { - "content": "<|placeholder2|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32004": { - "content": "<|placeholder3|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32005": { - "content": "<|placeholder4|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32006": { - "content": "<|system|>", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false, - "special": true - }, - "32007": { - "content": "<|end|>", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false, - "special": true - }, - "32008": { - "content": "<|placeholder5|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32009": { - "content": "<|placeholder6|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32010": { - "content": "<|user|>", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false, - "special": true - }, - "32011": { - "content": "<|step|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32012": { - "content": "<|function_output|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32013": { - "content": "<|tag|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32014": { - "content": "<|function_call|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32015": { - "content": "<|raw|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32016": { - "content": "<|continue|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32017": { - "content": "<|function_list|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32018": { - "content": "<|calc|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32019": { - "content": "<|code|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32020": { - "content": "<|/code|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32021": { - "content": "<|summary|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32022": { - "content": "<|resource|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32023": { - "content": "<|assistant_mask|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32024": { - "content": "<|start|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32025": { - "content": "<|message|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32026": { - "content": "<|fim_prefix|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32027": { - "content": "<|fim_middle|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32028": { - "content": "<|fim_suffix|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32029": { - "content": "<|meta_start|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32030": { - "content": "<|ipynb_marker|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32031": { - "content": "<|diff_marker|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32032": { - "content": "<|ghissue|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32033": { - "content": "<|ghreview|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32034": { - "content": "<|disc_start|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32035": { - "content": "<|disc_sep|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32036": { - "content": "<|disc_thread|><|query|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32037": { - "content": "<|/query|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32038": { - "content": "<|data|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32039": { - "content": "<|/data|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32040": { - "content": "<|sys|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32041": { - "content": "<|/sys|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32042": { - "content": "<|inst|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32043": { - "content": "<|/inst|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - }, - "32044": { - "content": "<|image|>", - "lstrip": false, - "normalized": false, - "rstrip": true, - "single_word": false, - "special": true - } - }, - "additional_special_tokens": [ - "<|system|>", - "<|end|>", - "<|user|>", - "<|end|>" - ], - "bos_token": "", - "chat_template": "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}", - "clean_up_tokenization_spaces": false, - "eos_token": "<|endoftext|>", - "model_max_length": 131072, - "pad_token": "<|endoftext|>", - "padding_side": "right", - "sp_model_kwargs": {}, - "tokenizer_class": "LlamaTokenizer", - "unk_token": "", - "use_default_system_prompt": false -} +{ + "add_bos_token": true, + "add_eos_token": false, + "add_prefix_space": null, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": false + }, + "32000": { + "content": "<|endoftext|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32001": { + "content": "<|assistant|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32002": { + "content": "<|placeholder1|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32003": { + "content": "<|placeholder2|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32004": { + "content": "<|placeholder3|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32005": { + "content": "<|placeholder4|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32006": { + "content": "<|system|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32007": { + "content": "<|end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32008": { + "content": "<|placeholder5|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32009": { + "content": "<|placeholder6|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32010": { + "content": "<|user|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32011": { + "content": "<|placeholder7|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32012": { + "content": "<|placeholder8|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32013": { + "content": "<|placeholder9|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32014": { + "content": "<|placeholder10|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32015": { + "content": "<|placeholder11|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32016": { + "content": "<|placeholder12|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32017": { + "content": "<|placeholder13|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32018": { + "content": "<|placeholder14|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32019": { + "content": "<|placeholder15|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32020": { + "content": "<|placeholder16|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32021": { + "content": "<|placeholder17|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32022": { + "content": "<|placeholder18|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32023": { + "content": "<|placeholder19|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32024": { + "content": "<|placeholder20|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32025": { + "content": "<|placeholder21|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32026": { + "content": "<|placeholder22|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32027": { + "content": "<|placeholder23|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32028": { + "content": "<|placeholder24|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32029": { + "content": "<|placeholder25|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32030": { + "content": "<|placeholder26|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32031": { + "content": "<|placeholder27|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32032": { + "content": "<|placeholder28|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32033": { + "content": "<|placeholder29|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32034": { + "content": "<|placeholder30|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32035": { + "content": "<|placeholder31|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32036": { + "content": "<|placeholder32|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32037": { + "content": "<|placeholder33|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32038": { + "content": "<|placeholder34|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32039": { + "content": "<|placeholder35|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32040": { + "content": "<|placeholder36|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32041": { + "content": "<|placeholder37|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32042": { + "content": "<|placeholder38|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32043": { + "content": "<|placeholder39|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32044": { + "content": "<|image|>", + "lstrip": false, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "<|system|>", + "<|end|>", + "<|user|>", + "<|end|>" + ], + "auto_map": { + "AutoProcessor": "processing_phi3_v.Phi3VProcessor" + }, + "bos_token": "", + "chat_template": "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}", + "clean_up_tokenization_spaces": false, + "eos_token": "<|endoftext|>", + "legacy": false, + "model_max_length": 131072, + "pad_token": "<|endoftext|>", + "padding_side": "right", + "processor_class": "Phi3VProcessor", + "sp_model_kwargs": {}, + "tokenizer_class": "LlamaTokenizer", + "unk_token": "", + "use_default_system_prompt": false +}