From c61aaa6b2349b39ca63509914b4c02105b462a4a Mon Sep 17 00:00:00 2001
From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Date: Tue, 10 Dec 2024 22:18:46 -0800
Subject: [PATCH] Add embeddings output to model builder (#1127)

### Description

This PR adds support for outputting the last hidden state in addition to the logits in ONNX models. Users can run their models with ONNX Runtime GenAI and use the generator's `GetOutput` API to obtain the hidden states.

C/C++:
```c
std::unique_ptr<OgaTensor> embeddings = generator->GetOutput("hidden_states");
```

C#:
```csharp
using var embeddings = generator.GetOutput("hidden_states");
```

Java:
```java
Tensor embeddings = generator.getOutput("hidden_states");
```

Python:
```python
embeddings = generator.get_output("hidden_states")
```

### Motivation and Context

In SLMs and LLMs, the last hidden state represents a model's embeddings for a particular input before the language modeling head is applied. Generating embeddings for a model is a popular task. These embeddings can be used for many scenarios such as text classification, sequence labeling, information retrieval using [retrieval-augmented generation (RAG)](https://en.wikipedia.org/wiki/Retrieval-augmented_generation), and more.

This PR helps address the following issues:
- https://github.com/microsoft/onnxruntime/issues/20969
- https://github.com/microsoft/onnxruntime-genai/issues/442
- https://github.com/microsoft/onnxruntime-genai/discussions/474
- https://github.com/microsoft/onnxruntime-genai/discussions/713

---
 src/python/py/models/README.md            |  61 ++++---
 src/python/py/models/builder.py           | 163 ++++++++++--------
 test/python/_test_utils.py                |   2 +-
 test/python/test_onnxruntime_genai_api.py |   4 +-
 .../pipeline-model/genai_config.json      |   6 +-
 5 files changed, 132 insertions(+), 104 deletions(-)

diff --git a/src/python/py/models/README.md b/src/python/py/models/README.md
index 089160e66..94a9fafec 100644
--- a/src/python/py/models/README.md
+++ b/src/python/py/models/README.md
@@ -14,11 +14,12 @@ This folder contains the model builder for quickly creating optimized and quanti
   - [GGUF Model](#gguf-model)
   - [Extra Options](#extra-options)
     - [Config Only](#config-only)
+    - [Hugging Face Authentication](#hugging-face-authentication)
     - [Exclude Embedding Layer](#exclude-embedding-layer)
     - [Exclude Language Modeling Head](#exclude-language-modeling-head)
-    - [Enable Cuda Graph](#enable-cuda-graph)
+    - [Include Last Hidden States Output](#include-last-hidden-states-output)
+    - [Enable CUDA Graph](#enable-cuda-graph)
     - [Use 8 Bits Quantization in QMoE](#use-8-bits-quantization-in-qmoe)
-    - [Hugging Face Authentication](#hugging-face-authentication)
     - [Use QDQ Pattern for Quantization](#use-qdq-pattern-for-quantization)
     - [LoRA Models](#lora-models)
   - [Unit Testing Models](#unit-testing-models)
@@ -30,12 +31,13 @@ This folder contains the model builder for quickly creating optimized and quanti

 The tool currently supports the following model architectures.

+- ChatGLM
 - Gemma
 - LLaMA
 - Mistral
+- Nemotron
 - Phi
 - Qwen
-- Nemotron

 It is intended for supporting the latest, popular state-of-the-art models.

@@ -141,6 +143,18 @@ python3 builder.py -m model_name -o path_to_output_folder -p precision -e execut

 Afterwards, please open the `genai_config.json` file in the output folder and modify the fields as needed for your model. You should store your ONNX model in the output folder as well.
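As a quick sanity check after hand-editing the config, the output folder can be loaded with the ONNX Runtime GenAI Python API. This is a minimal sketch; the folder path is a placeholder.

```python
import onnxruntime_genai as og

# The output folder should contain genai_config.json, the ONNX model, and the tokenizer files.
model = og.Model("path_to_output_folder")
tokenizer = og.Tokenizer(model)
print("Loaded model and tokenizer from the output folder")
```
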
+#### Hugging Face Authentication + +This scenario is for when you need to disable the Hugging Face authentication or use a different authentication token than the one stored in [huggingface-cli login](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#huggingface-cli-login). + +``` +# From wheel: +python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options hf_token=false + +# From source: +python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options hf_token=false +``` + #### Exclude Embedding Layer This scenario is for when you want to exclude the embedding layer from your ONNX model. @@ -165,65 +179,60 @@ python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o p python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options exclude_lm_head=true ``` -#### Enable Cuda Graph +#### Include Last Hidden States Output -This scenario is for when you want to enable cuda graph for your ONNX model. +This scenario is for when you want to include the last hidden states as an output to your ONNX model. ``` # From wheel: -python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options enable_cuda_graph=1 +python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options include_hidden_states=true # From source: -python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options enable_cuda_graph=1 +python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options include_hidden_states=true ``` -#### Use 8 Bits Quantization in QMoE +Note that this is the same as outputting embeddings since the last hidden states are also known as the embeddings. -This scenario is for when you want to use 8-bit quantization for MoE layers. Default is using 4-bit quantization. +#### Enable CUDA Graph + +This scenario is for when you want to enable CUDA graph for your ONNX model. 
``` # From wheel: -python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_8bits_moe=1 +python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options enable_cuda_graph=true # From source: -python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_8bits_moe=1 +python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options enable_cuda_graph=true ``` -#### Hugging Face Authentication - -This scenario is for when you need to disable the Hugging Face authentication or use a different authentication token than the one stored in [huggingface-cli login](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#huggingface-cli-login). - -Possible values : +#### Use 8 Bits Quantization in QMoE -- hf_token=False -- hf_token= +This scenario is for when you want to use 8-bit quantization for MoE layers. Default is using 4-bit quantization. ``` # From wheel: -python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options hf_token=False +python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_8bits_moe=true # From source: -python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options hf_token=False +python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_8bits_moe=true ``` #### Use QDQ Pattern for Quantization -This scenario is for when you want to use the QDQ pattern (DequantizeLinear + MatMul) instead of the MatMulNBits operator when quantizing the model to 4 bits. +This scenario is for when you want to use the QDQ pattern when quantizing the model to 4 bits. ``` # From wheel: -python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_qdq=1 +python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_qdq=true # From source: -python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_qdq=1 +python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options use_qdq=true ``` #### LoRA Models This scenario is where you have a finetuned model with LoRA adapters and your model can be loaded in the Hugging Face style via [PEFT](https://github.com/huggingface/peft). 
-- path_to_local_folder_on_disk = location where base_model's weights are present - ``` # From wheel: python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p fp16 -e execution_provider -c cache_dir_to_store_temp_files --extra_options adapter_path=path_to_adapter_files @@ -232,6 +241,8 @@ python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o p python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p fp16 -e execution_provider -c cache_dir_to_store_temp_files --extra_options adapter_path=path_to_adapter_files ``` +Base weights should be located in `path_to_local_folder_on_disk` and adapter weights should be located in `path_to_adapter_files`. + ### Unit Testing Models This scenario is where your PyTorch model is already downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk). If it is not already downloaded locally, here is an example of how you can download it. diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 1dd88ccf2..e027fabda 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -86,7 +86,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): "past_key_values.key": ["batch_size", self.num_kv_heads, "past_sequence_length", self.head_size], # For standard models (note that `past_key_values.key` is written this way to match Hugging Face format) "past_key_values.value": ["batch_size", self.num_kv_heads, "past_sequence_length", self.head_size], # For standard models (note that `past_key_values.value` is written this way to match Hugging Face format) } - self.exclude_embeds = "exclude_embeds" in extra_options + self.exclude_embeds = extra_options.get("exclude_embeds", False) if self.exclude_embeds: self.input_names = [name.replace("input_ids", "inputs_embeds") for name in self.input_names] @@ -104,9 +104,12 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): "present.key": ["batch_size", self.num_kv_heads, "total_sequence_length", self.head_size], # For standard models (note that `present.key` is written this way to match Hugging Face format) "present.value": ["batch_size", self.num_kv_heads, "total_sequence_length", self.head_size], # For standard models (note that `present.value` is written this way to match Hugging Face format) } - self.exclude_lm_head = "exclude_lm_head" in extra_options + self.exclude_lm_head = extra_options.get("exclude_lm_head", False) + self.include_hidden_states = extra_options.get("include_hidden_states", False) if self.exclude_lm_head: self.output_names = [name.replace("logits", "hidden_states") for name in self.output_names] + elif self.include_hidden_states: + self.output_names = ["hidden_states"] + self.output_names # Store names of nodes already created self.node_names = set() @@ -310,11 +313,21 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): config = GenerationConfig.from_pretrained(model_name_or_path, token=self.hf_token, trust_remote_code=True, **extra_kwargs) except: config = AutoConfig.from_pretrained(model_name_or_path, token=self.hf_token, trust_remote_code=True, **extra_kwargs) + inputs = dict(zip(self.input_names, self.input_names)) inputs.update({ "past_key_names": "past_key_values.%d.key", "past_value_names": "past_key_values.%d.value", }) + outputs = dict(zip(self.output_names, self.output_names)) + outputs.update({ + "present_key_names": "present.%d.key", + 
"present_value_names": "present.%d.value", + }) + if "hidden_states" in outputs: + # Remove 'hidden_states' from 'outputs' entry in config since ORT GenAI doesn't use it + del outputs["hidden_states"] + genai_config = { "model": { "bos_token_id": config.bos_token_id if hasattr(config, "bos_token_id") else 1, # config.bos_token_id not present in ChatGLM model configs. @@ -328,11 +341,7 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): "head_size": self.head_size, "hidden_size": self.hidden_size, "inputs": inputs, - "outputs": { - "logits": "logits", - "present_key_names": "present.%d.key", - "present_value_names": "present.%d.value", - }, + "outputs": outputs, "num_attention_heads": self.num_attn_heads, "num_hidden_layers": self.num_layers, "num_key_value_heads": self.num_kv_heads, @@ -364,7 +373,7 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): ep_options = { self.ep : self.ep_attrs[self.ep] } genai_config["model"]["decoder"]["session_options"]["provider_options"].append(ep_options) - if self.extra_options.get("prompt_templates", "0") == "1": + if self.extra_options.get("include_prompt_templates", False): prompt_templates = self._get_prompt_templates(model_name_or_path, extra_kwargs) if prompt_templates is not None: genai_config["model"]["prompt_templates"] = prompt_templates @@ -1012,7 +1021,7 @@ def make_layernorm(self, layer_id, layernorm, skip, simple, location): output_0 = f"/model/layers.{layer_id}/{location}_layernorm/output_0" output_3 = f"/model/layers.{layer_id}/{location}_layernorm/output_3" - if self.layernorm_attrs["last_layernorm"] and self.exclude_lm_head: + if self.layernorm_attrs["last_layernorm"] and (self.include_hidden_states or self.exclude_lm_head): output_0 = "hidden_states" outputs = [output_0, "", "", output_3] if skip and not self.layernorm_attrs["last_layernorm"] else [output_0] @@ -2026,17 +2035,7 @@ def make_model(self, input_path): from onnxruntime_genai.models.quantized_model import QuantModel q_size = self.num_attn_heads * self.head_size kv_size = self.num_kv_heads * self.head_size - model = QuantModel.from_pretrained( - self.quant_type, - input_path, - self.quant_attrs["bits"], - self.quant_attrs["group_size"], - self.quant_attrs["use_g_idx"], - q_size, - kv_size, - self.intermediate_size, - self.num_layers, - ) + model = QuantModel.from_pretrained(self.quant_type, input_path, self.quant_attrs["bits"], self.quant_attrs["group_size"], self.quant_attrs["use_g_idx"], q_size, kv_size, self.intermediate_size, self.num_layers) else: # Load PyTorch model extra_kwargs = {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {} @@ -2665,43 +2664,6 @@ def make_mlp_proj(self, layer_id, mlp, root_input): super().make_mlp_unpacked(layer_id, mlp, root_input) super().make_mlp_proj(layer_id, mlp, root_input) -class NemotronModel(LlamaModel): - def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): - super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) - self.layernorm_attrs["simple"] = False - self.layernorm_attrs["add_offset"] = 1 - self.rotemb_attrs["rotary_embedding_dim"] = int(self.head_size * self.rotemb_attrs["partial_rotary_factor"]) - - def make_mlp_proj(self, layer_id, mlp, root_input): - # Make nodes for the MLP subgraph - # - # root_input - # | - # UpProjMatMul - # | - # ActFunc - # | - # DownProjMatMul - - up_basename = f"/model/layers.{layer_id}/mlp/up_proj/MatMul" - up_name = self.make_matmul(mlp.up_proj, up_basename, 
root_input) - - act_fn_name = self.make_activation(layer_id, root_input=f"{up_name}/output_0") - - # Make output MatMul node - down_basename = f"/model/layers.{layer_id}/mlp/down_proj/MatMul" - down_name = self.make_matmul(mlp.down_proj, down_basename, f"{act_fn_name}/output_0") - - # Assign output 0 of previous MatMul as skip input to next SkipLayerNorm - self.layernorm_attrs["skip_input"] = f"{down_name}/output_0" - - def make_attention(self, layer_id, attention, root_input, **kwargs): - attention.rotary_emb = type("RotaryEmbedding", (object,), {'content':{}})() - return super().make_attention(layer_id, attention, root_input, **kwargs) - - def make_rotary_embedding(self, rotemb, name, root_input, **kwargs): - num_heads = self.num_kv_heads if "k_rotary" in name else self.num_attn_heads - super().make_rotary_embedding(rotemb, name, root_input, num_heads=num_heads, rotary_embedding_dim=self.rotemb_attrs["rotary_embedding_dim"], **kwargs) class Phi3Mini128KModel(Phi3Mini4KModel): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): @@ -2772,6 +2734,7 @@ def make_rotary_embedding_caches(self, rotemb, **kwargs): return cos_cache_name, sin_cache_name + class Phi3Small8KModel(Model): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) @@ -3012,6 +2975,45 @@ def make_layer(self, layer_id, layer): self.layernorm_attrs["last_layernorm"] = True +class NemotronModel(LlamaModel): + def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): + super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) + self.layernorm_attrs["simple"] = False + self.layernorm_attrs["add_offset"] = 1 + self.rotemb_attrs["rotary_embedding_dim"] = int(self.head_size * self.rotemb_attrs["partial_rotary_factor"]) + + def make_mlp_proj(self, layer_id, mlp, root_input): + # Make nodes for the MLP subgraph + # + # root_input + # | + # UpProjMatMul + # | + # ActFunc + # | + # DownProjMatMul + + up_basename = f"/model/layers.{layer_id}/mlp/up_proj/MatMul" + up_name = self.make_matmul(mlp.up_proj, up_basename, root_input) + + act_fn_name = self.make_activation(layer_id, root_input=f"{up_name}/output_0") + + # Make output MatMul node + down_basename = f"/model/layers.{layer_id}/mlp/down_proj/MatMul" + down_name = self.make_matmul(mlp.down_proj, down_basename, f"{act_fn_name}/output_0") + + # Assign output 0 of previous MatMul as skip input to next SkipLayerNorm + self.layernorm_attrs["skip_input"] = f"{down_name}/output_0" + + def make_attention(self, layer_id, attention, root_input, **kwargs): + attention.rotary_emb = type("RotaryEmbedding", (object,), {'content':{}})() + return super().make_attention(layer_id, attention, root_input, **kwargs) + + def make_rotary_embedding(self, rotemb, name, root_input, **kwargs): + num_heads = self.num_kv_heads if "k_rotary" in name else self.num_attn_heads + super().make_rotary_embedding(rotemb, name, root_input, num_heads=num_heads, rotary_embedding_dim=self.rotemb_attrs["rotary_embedding_dim"], **kwargs) + + class ChatGLMModel(Model): def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) @@ -3045,7 +3047,7 @@ def check_extra_options(kv_pairs): """ Check key-value pairs and set values correctly """ - bools = ["int4_is_symmetric", "use_qdq", "use_8bits_moe", "enable_cuda_graph"] + bools = ["int4_is_symmetric", 
"exclude_embeds", "exclude_lm_head", "include_hidden_states", "enable_cuda_graph", "use_8bits_moe", "use_qdq", "include_prompt_templates"] for key in bools: if key in kv_pairs: if kv_pairs[key] in {"false", "False", "0"}: @@ -3061,6 +3063,11 @@ def check_extra_options(kv_pairs): op_types_to_quantize += (op_type, ) kv_pairs["int4_op_types_to_quantize"] = op_types_to_quantize + if "exclude_lm_head" in kv_pairs and "include_hidden_states" in kv_pairs: + # 'exclude_lm_head' is for when 'hidden_states' are outputted and 'logits' are not outputted + # 'include_hidden_states' is for when 'hidden_states' are outputted and 'logits' are outputted + raise ValueError(f"Both 'exclude_lm_head' and 'include_hidden_states' cannot be used together. Please use only one of them at once.") + def parse_extra_options(kv_items): """ @@ -3116,7 +3123,11 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid if "config_only" not in extra_options: # List architecture options in alphabetical order - if config.architectures[0] == "GemmaForCausalLM": + if config.architectures[0] == "ChatGLMForConditionalGeneration" or config.architectures[0] == "ChatGLMModel": + # Quantized ChatGLM model has ChatGLMForConditionalGeneration as architecture whereas HF model as the latter + config.hidden_act = "swiglu" + onnx_model = ChatGLMModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) + elif config.architectures[0] == "GemmaForCausalLM": onnx_model = GemmaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "Gemma2ForCausalLM": onnx_model = Gemma2Model(config, io_dtype, precision, execution_provider, cache_dir, extra_options) @@ -3124,6 +3135,8 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid onnx_model = LlamaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "MistralForCausalLM": onnx_model = MistralModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) + elif config.architectures[0] == "NemotronForCausalLM": + onnx_model = NemotronModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "PhiForCausalLM": onnx_model = PhiModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "Phi3ForCausalLM" and config.max_position_embeddings == 4096: @@ -3146,12 +3159,6 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid onnx_model = Phi3VModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "Qwen2ForCausalLM": onnx_model = QwenModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) - elif config.architectures[0] == "NemotronForCausalLM": - onnx_model = NemotronModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) - elif config.architectures[0] == "ChatGLMForConditionalGeneration" or config.architectures[0] == "ChatGLMModel": - # Quantized ChatGLM model has ChatGLMForConditionalGeneration as architecture whereas HF model as the latter - config.hidden_act = "swiglu" - onnx_model = ChatGLMModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options) else: raise NotImplementedError(f"The {hf_name} model is not currently supported.") @@ -3249,21 +3256,31 @@ def get_args(): The filename for each component will be '_.onnx' (ex: 
'_encoder.onnx', '_decoder.onnx'). config_only = Generate config and pre/post processing files only. Use this option when you already have your optimized and/or quantized ONNX model. + hf_token = false/token: Use this to manage authentication with Hugging Face. + Default behavior is to use the authentication token stored by `huggingface-cli login`. + If false, authentication with Hugging Face will be disabled. + If token, you can provide a custom authentication token that differs from the one stored in your environment. + If you have already authenticated via `huggingface-cli login`, you do not need to use this flag because Hugging Face has already stored your authentication token for you. exclude_embeds = Remove embedding layer from your ONNX model. Use this option when you want to remove the embedding layer from within your ONNX model. Instead of `input_ids`, you will have `inputs_embeds` as the input to your ONNX model. exclude_lm_head = Remove language modeling head from your ONNX model. Use this option when you want to remove the language modeling head from within your ONNX model. Instead of `logits`, you will have `hidden_states` as the output to your ONNX model. - enable_cuda_graph = 1: The model can use CUDA graph capture for CUDA execution provider. If enabled, all nodes being placed on the CUDA EP - is the prerequisite for the CUDA graph to be used correctly. It is not guaranteed that cuda graph be enabled as it depends on the model - and the graph structure. - use_8bits_moe = 1: Use 8-bit quantization for MoE layers. Default is using 4-bit quantization. - hf_token = false/token: Use this to disable authentication with Hugging Face or provide a custom authentication token that differs from the one stored in your environment. Default behavior is to use the authentication token stored by `huggingface-cli login`. - If you have already authenticated via `huggingface-cli login`, you do not need to use this flag because Hugging Face has already stored your authentication token for you. - use_qdq = 1: Use the QDQ decomposition for quantized MatMul instead of the MatMulNBits operator. + include_hidden_states = Include hidden states as output from your ONNX model. + Use this option when you want to have the hidden states as an output from your ONNX model. + In addition to `logits`, you will have `hidden_states` as an output to your ONNX model. + enable_cuda_graph = Enable CUDA graph capture during inference. Default is false. + If enabled, all nodes being placed on the CUDA EP is the prerequisite for the CUDA graph to be used correctly. + It is not guaranteed that CUDA graph be enabled as it depends on the model and the graph structure. + use_8bits_moe = Use 8-bit quantization for MoE layers. Default is false. + If true, the QMoE op will use 4-bit quantization. If false, the QMoE op will use 8-bits quantization. + use_qdq = Use the QDQ decomposition for ops. + Use this option when you want to use quantize-dequantize ops. For example, you will have a quantized MatMul op instead of the MatMulNBits op. adapter_path = Path to folder on disk containing the adapter files (adapter_config.json and adapter model weights). - prompt_templates = 1: Include per-role prompt templates in the GenAI config file. Default is 0 (not to include). + Use this option for LoRA models. + include_prompt_templates = Include prompt templates in the GenAI config file. Default is false. + Use this option to include per-role prompt templates in the `genai_config.json` file. 
"""), ) diff --git a/test/python/_test_utils.py b/test/python/_test_utils.py index 808f8930e..0a44fa0e2 100644 --- a/test/python/_test_utils.py +++ b/test/python/_test_utils.py @@ -104,7 +104,7 @@ def download_model(model_name, input_path, output_path, precision, device, one_l device, ] - extra_options = ["--extra_options"] + extra_options = ["--extra_options", "include_hidden_states=true"] if device == "cpu" and precision == "int4": extra_options += ["int4_accuracy_level=4"] if one_layer: diff --git a/test/python/test_onnxruntime_genai_api.py b/test/python/test_onnxruntime_genai_api.py index 0626d7fda..5a66b5735 100644 --- a/test/python/test_onnxruntime_genai_api.py +++ b/test/python/test_onnxruntime_genai_api.py @@ -395,14 +395,14 @@ def _split(onnx_model_path: os.PathLike, output_dir: os.PathLike): for kv in ["key", "value"] for i in range(num_layers) ], - [f"/model/layers.{num_layers}/final_norm_layernorm/output_0"] + ["hidden_states"] + [ f"present.{i}.{kv}" for kv in ["key", "value"] for i in range(num_layers) ], ), - ([f"/model/layers.{num_layers}/final_norm_layernorm/output_0"], ["logits"]), + ([f"hidden_states"], ["logits"]), ] for i, split_name in enumerate(["embeds", "transformer", "lm_head"]): diff --git a/test/test_models/pipeline-model/genai_config.json b/test/test_models/pipeline-model/genai_config.json index 060edf05d..0331092f8 100644 --- a/test/test_models/pipeline-model/genai_config.json +++ b/test/test_models/pipeline-model/genai_config.json @@ -45,7 +45,7 @@ "past_key_values.0.value" ], "outputs": [ - "/model/layers.1/final_norm_layernorm/output_0", + "hidden_states", "present.0.key", "present.0.value" ], @@ -63,7 +63,7 @@ "language_model_head": { "filename": "lm_head.onnx", "inputs": [ - "/model/layers.1/final_norm_layernorm/output_0" + "hidden_states" ], "outputs": [ "logits" @@ -93,4 +93,4 @@ "top_k": 1, "top_p": 1.0 } -} \ No newline at end of file +}