diff --git a/src/python/py/models/README.md b/src/python/py/models/README.md
index f39c1d941..8dee988ff 100644
--- a/src/python/py/models/README.md
+++ b/src/python/py/models/README.md
@@ -11,6 +11,7 @@ This folder contains the model builder tool, which greatly accelerates creating
   - [Customized or Finetuned PyTorch Model](#customized-or-finetuned-pytorch-model)
   - [GGUF Model](#gguf-model)
   - [Extra Options](#extra-options)
+  - [Config Only](#config-only)
   - [Unit Testing Models](#unit-testing-models)
     - [Option 1: Use the model builder tool directly](#option-1-use-the-model-builder-tool-directly)
     - [Option 2: Edit the config.json file](#option-2-edit-the-configjson-file-on-disk-and-then-run-the-model-builder-tool)
@@ -86,6 +87,18 @@ python3 builder.py -m model_name -o path_to_output_folder -p precision -e execut
 ```
 To see all available options through `--extra_options`, please use the `help` commands in the `Full Usage` section above.
 
+### Config Only
+This scenario is for when you already have your optimized and/or quantized ONNX model and you need to create the config files to run with ONNX Runtime GenAI.
+```
+# From wheel:
+python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options config_only=true
+
+# From source:
+python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider -c cache_dir_for_hf_files --extra_options config_only=true
+```
+
+Afterwards, please open the `genai_config.json` file in the output folder and modify the fields as needed for your model. You should store your ONNX model in the output folder as well.
+
 ### Unit Testing Models
 This scenario is where your PyTorch model is already downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk). If it is not already downloaded locally, here is an example of how you can download it.
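Note on the new Config Only workflow above: the post-generation edit of `genai_config.json` can also be scripted. Below is a minimal sketch (not part of this change) that loads the generated config and points it at an existing ONNX model; the top-level `model` key and the filename `model.onnx` are illustrative assumptions, so adjust them to the actual layout of your generated file.

```python
import json
import os

output_dir = "path_to_output_folder"  # the same folder passed via -o above
config_path = os.path.join(output_dir, "genai_config.json")

with open(config_path) as f:
    genai_config = json.load(f)

# Assumption: the decoder settings live under a top-level "model" key; adjust to
# the layout of your generated file. Point "filename" at the ONNX model you
# copied into the output folder.
genai_config["model"]["decoder"]["filename"] = "model.onnx"

with open(config_path, "w") as f:
    json.dump(genai_config, f, indent=4)
```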
diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index ca5a521d3..1652d5778 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -25,7 +25,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         self.window_size = config.sliding_window if hasattr(config, "sliding_window") else -1  # default is -1 in GroupQueryAttention kernel
         self.intermediate_size = config.intermediate_size
         self.hidden_size = config.hidden_size
-        self.num_kv_heads = config.num_key_value_heads
+        self.num_kv_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads
         self.num_attn_heads = config.num_attention_heads
         self.head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         self.num_layers = int(extra_options["num_hidden_layers"]) if "num_hidden_layers" in extra_options else config.num_hidden_layers
@@ -102,12 +102,13 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         # RotaryEmbedding-specific variables
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+        rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000
         self.rotemb_attrs = {
             "create_rotary_embedding_caches": True,          # Create cos/sin caches for rotary embeddings
             "partial_rotary_factor": partial_rotary_factor,  # Factor for partial rotary embeddings
             "num_heads": 0,                                  # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
             "rotary_embedding_dim": 0,                       # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
-            "theta": config.rope_theta,                      # Base value if calculating cos/sin caches from scratch
+            "theta": rope_theta,                             # Base value if calculating cos/sin caches from scratch
         }
 
         # Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
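The two fallbacks added above (for `num_key_value_heads` and `rope_theta`) let the builder handle Hugging Face configs that omit those attributes. A small illustrative sketch of the resulting behavior, using a hypothetical bare-bones config object rather than a real `transformers` config:

```python
# Illustrative only, not code from this change: a config that lacks
# num_key_value_heads and rope_theta now falls back to sensible defaults.
from types import SimpleNamespace

config = SimpleNamespace(num_attention_heads=32, hidden_size=4096)

# Same fallback pattern as the diff above.
num_kv_heads = config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads
rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000

print(num_kv_heads, rope_theta)  # 32 10000 -> treated as standard multi-head attention with the default theta
```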
@@ -139,8 +140,8 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
             "context_length": self.context_length,
             "decoder": {
                 "session_options" : {
-                    "provider_options" : [
-                    ]
+                    "log_id": "onnxruntime-genai",
+                    "provider_options" : []
                 },
                 "filename": self.filename,
                 "head_size": self.head_size,
@@ -1400,23 +1401,26 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
     # Set input/output precision of ONNX model
     io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") else TensorProto.FLOAT16
 
-    # List architecture options in alphabetical order
-    if config.architectures[0] == "GemmaForCausalLM":
-        onnx_model = GemmaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
-    elif config.architectures[0] == "LlamaForCausalLM":
-        onnx_model = LlamaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
-    elif config.architectures[0] == "MistralForCausalLM":
-        onnx_model = MistralModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
-    elif config.architectures[0] == "PhiForCausalLM":
-        onnx_model = PhiModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
-    else:
-        raise NotImplementedError(f"The {hf_name} model is not currently supported.")
+    if "config_only" not in extra_options:
+        # List architecture options in alphabetical order
+        if config.architectures[0] == "GemmaForCausalLM":
+            onnx_model = GemmaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+        elif config.architectures[0] == "LlamaForCausalLM":
+            onnx_model = LlamaModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+        elif config.architectures[0] == "MistralForCausalLM":
+            onnx_model = MistralModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+        elif config.architectures[0] == "PhiForCausalLM":
+            onnx_model = PhiModel(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+        else:
+            raise NotImplementedError(f"The {hf_name} model is not currently supported.")
 
-    # Make ONNX model
-    onnx_model.make_model(input_path)
+        # Make ONNX model
+        onnx_model.make_model(input_path)
 
-    # Save ONNX model
-    onnx_model.save_model(output_dir)
+        # Save ONNX model
+        onnx_model.save_model(output_dir)
+    else:
+        onnx_model = Model(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
 
     # Make GenAI config
     onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir)
@@ -1497,6 +1501,8 @@ def get_args():
         filename = Filename for ONNX model (default is 'model.onnx').
             For models with multiple components, each component is exported to its own ONNX model.
             The filename for each component will be '<filename>_<component>.onnx' (ex: '<filename>_encoder.onnx', '<filename>_decoder.onnx').
+        config_only = Generate config and pre/post processing files only.
+            Use this option when you already have your optimized and/or quantized ONNX model.
     """),
 )
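One detail worth calling out in the `create_model` change: the branch tests for the presence of the `config_only` key, not its value, so any `config_only=<value>` passed via `--extra_options` enables config-only mode. A minimal sketch of that gating (the parsing of `--extra_options` into a dict is assumed from the surrounding CLI code and not shown here):

```python
# extra_options as it would look after "--extra_options config_only=true" is parsed.
extra_options = {"config_only": "true"}

if "config_only" not in extra_options:
    print("export the ONNX model, save it, then write genai_config.json")
else:
    # New path added by this change: build the base Model only to emit the config files.
    print("skip export; write genai_config.json and pre/post processing files only")
```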