diff --git a/VERSION_INFO b/VERSION_INFO
index 70426f852..b4f09dd42 100644
--- a/VERSION_INFO
+++ b/VERSION_INFO
@@ -1 +1 @@
-0.2.0-dev
+0.2.0-dev
\ No newline at end of file
diff --git a/src/python/py/models/README.md b/src/python/py/models/README.md
index 34f24083e..ad233fa2d 100644
--- a/src/python/py/models/README.md
+++ b/src/python/py/models/README.md
@@ -3,18 +3,21 @@
 This folder contains the model builder for quickly creating optimized and quantized ONNX models within a few minutes that run with ONNX Runtime GenAI.
 
 # Contents
- - [Current Support](#current-support)
- - [Usage](#usage)
- - [Full Usage](#full-usage)
- - [Original PyTorch Model from Hugging Face](#original-pytorch-model-from-hugging-face)
- - [Original PyTorch Model from Disk](#original-pytorch-model-from-disk)
- - [Customized or Finetuned PyTorch Model](#customized-or-finetuned-pytorch-model)
- - [GGUF Model](#gguf-model)
- - [Extra Options](#extra-options)
- - [Config Only](#config-only)
- - [Unit Testing Models](#unit-testing-models)
- - [Option 1: Use the model builder tool directly](#option-1-use-the-model-builder-tool-directly)
- - [Option 2: Edit the config.json file](#option-2-edit-the-configjson-file-on-disk-and-then-run-the-model-builder-tool)
+- [Current Support](#current-support)
+- [Usage](#usage)
+  - [Full Usage](#full-usage)
+  - [Original PyTorch Model from Hugging Face](#original-pytorch-model-from-hugging-face)
+  - [Original PyTorch Model from Disk](#original-pytorch-model-from-disk)
+  - [Customized or Finetuned PyTorch Model](#customized-or-finetuned-pytorch-model)
+  - [GGUF Model](#gguf-model)
+  - [Extra Options](#extra-options)
+    - [Config Only](#config-only)
+    - [Exclude Embedding Layer](#exclude-embedding-layer)
+    - [Exclude Language Modeling Head](#exclude-language-modeling-head)
+  - [Unit Testing Models](#unit-testing-models)
+    - [Option 1: Use the model builder directly](#option-1-use-the-model-builder-directly)
+    - [Option 2: Edit the config.json file](#option-2-edit-the-configjson-file-on-disk-and-then-run-the-model-builder)
+- [Design](#design)
 
 ## Current Support
 The tool currently supports the following model architectures.
@@ -89,7 +92,7 @@ python3 builder.py -m model_name -o path_to_output_folder -p precision -e execut
 ```
 To see all available options through `--extra_options`, please use the `help` commands in the `Full Usage` section above.
 
-### Config Only
+#### Config Only
 This scenario is for when you already have your optimized and/or quantized ONNX model and you need to create the config files to run with ONNX Runtime GenAI.
 ```
 # From wheel:
@@ -101,6 +104,28 @@ python3 builder.py -m model_name -o path_to_output_folder -p precision -e execut
 Afterwards, please open the `genai_config.json` file in the output folder and modify the fields as needed for your model. You should store your ONNX model in the output folder as well.
 
+#### Exclude Embedding Layer
+This scenario is for when you want to exclude the embedding layer from your ONNX model.
+
+```
+# From wheel:
+python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options exclude_embeds=true
+
+# From source:
+python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options exclude_embeds=true
+```
+
+#### Exclude Language Modeling Head
+This scenario is for when you want to exclude the language modeling head from your ONNX model.
+
+```
+# From wheel:
+python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options exclude_lm_head=true
+
+# From source:
+python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files --extra_options exclude_lm_head=true
+```
+
 ### Unit Testing Models
 This scenario is where your PyTorch model is already downloaded locally (either in the default Hugging Face cache directory or in a local folder on disk). If it is not already downloaded locally, here is an example of how you can download it.
 
@@ -117,7 +142,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
 tokenizer.save_pretrained(cache_dir)
 ```
 
-#### Option 1: Use the model builder tool directly
+#### Option 1: Use the model builder directly
 This option is the simplest but it will download another copy of the PyTorch model onto disk to accommodate the change in the number of hidden layers.
 ```
 # From wheel:
@@ -127,11 +152,11 @@ python3 -m onnxruntime_genai.models.builder -m model_name -o path_to_output_fold
 python3 builder.py -m model_name -o path_to_output_folder -p precision -e execution_provider --extra_options num_hidden_layers=4
 ```
 
-#### Option 2: Edit the config.json file on disk and then run the model builder tool
+#### Option 2: Edit the config.json file on disk and then run the model builder
 1. Navigate to where the PyTorch model and its associated files are saved on disk.
 2. Modify `num_hidden_layers` in `config.json` to your desired target (e.g. 4 layers).
-3. Run the below command for the model builder tool.
+3. Run the below command for the model builder.
 ```
 # From wheel:
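Editor's note on the new `exclude_embeds=true` scenario documented above: it only moves the graph boundary, so the embedding lookup has to happen outside the exported model. A minimal sketch of how that model could then be fed, assuming the placeholder names from the README commands (`model_name`, `path_to_output_folder`) and not any real paths:

```
# Sketch only: placeholders taken from the README commands above.
# With exclude_embeds=true the exported graph takes `inputs_embeds`
# of shape (batch_size, sequence_length, hidden_size) instead of `input_ids`.
import onnxruntime as ort
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("model_name")
hf_model = AutoModelForCausalLM.from_pretrained("model_name")

input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
with torch.no_grad():
    # Run only the embedding layer; this replaces the lookup removed from the ONNX graph
    inputs_embeds = hf_model.get_input_embeddings()(input_ids)

session = ort.InferenceSession("path_to_output_folder/model.onnx")
print([i.name for i in session.get_inputs()])  # expect `inputs_embeds`, `attention_mask`, past KV inputs
```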
diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index 16f864c38..e0c6d28aa 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -37,7 +37,14 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         self.model_type = config.architectures[0]
         self.io_dtype = io_dtype      # {'fp16', 'fp32'}
         self.onnx_dtype = onnx_dtype  # {"int4", "fp16", "fp32"}
+
+        # EP-specific variables
         self.ep = ep
+        self.ep_attrs = {
+            "cpu": {},
+            "cuda": {},
+            "dml": {},
+        }
 
         self.cache_dir = cache_dir
         self.filename = extra_options["filename"] if "filename" in extra_options else "model.onnx"
@@ -52,15 +59,42 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
 
         # Map input names to their types and shapes
         self.input_names = ["input_ids", "attention_mask", "position_ids"]
         self.input_types = {
-            "input_ids": TensorProto.INT64,
-            "attention_mask": TensorProto.INT64,
-            "position_ids": TensorProto.INT64,
+            "input_ids": TensorProto.INT64,          # For standard models
+            "attention_mask": TensorProto.INT64,     # For standard models
+            "position_ids": TensorProto.INT64,       # For standard models
+            "inputs_embeds": self.io_dtype,          # For standard models where you want to remove the embedding layer from the model (note that `inputs_embeds` is written this way to match Hugging Face format)
+            "past_key_values.key": self.io_dtype,    # For standard models (note that `past_key_values.key` is written this way to match Hugging Face format)
+            "past_key_values.value": self.io_dtype,  # For standard models (note that `past_key_values.value` is written this way to match Hugging Face format)
         }
         self.input_shapes = {
-            "input_ids": ["batch_size", "sequence_length"],
-            "attention_mask": ["batch_size", "total_sequence_length"],
-            "position_ids": ["batch_size", "sequence_length"],
+            "input_ids": ["batch_size", "sequence_length"],                                                      # For standard models
+            "attention_mask": ["batch_size", "total_sequence_length"],                                           # For standard models
+            "position_ids": ["batch_size", "sequence_length"],                                                   # For standard models
+            "inputs_embeds": ["batch_size", "sequence_length", self.hidden_size],                                # For standard models where you want to remove the embedding layer from the model (note that `inputs_embeds` is written this way to match Hugging Face format)
+            "past_key_values.key": ["batch_size", self.num_kv_heads, "past_sequence_length", self.head_size],    # For standard models (note that `past_key_values.key` is written this way to match Hugging Face format)
+            "past_key_values.value": ["batch_size", self.num_kv_heads, "past_sequence_length", self.head_size],  # For standard models (note that `past_key_values.value` is written this way to match Hugging Face format)
+        }
+        self.exclude_embeds = "exclude_embeds" in extra_options
+        if self.exclude_embeds:
+            self.input_names = [name.replace("input_ids", "inputs_embeds") for name in self.input_names]
+
+        # Map output names to their types and shapes
+        self.output_names = ["logits"]
+        self.output_types = {
+            "hidden_states": self.io_dtype,  # For standard models where you want to remove the language modeling head from the model (note that `hidden_states` is written this way to match Hugging Face format)
+            "logits": self.io_dtype,         # For standard models
+            "present.key": self.io_dtype,    # For standard models (note that `present.key` is written this way to match Hugging Face format)
+            "present.value": self.io_dtype,  # For standard models (note that `present.value` is written this way to match Hugging Face format)
+        }
+        self.output_shapes = {
+            "hidden_states": ["batch_size", "sequence_length", self.hidden_size],                               # For standard models where you want to remove the language modeling head from the model (note that `hidden_states` is written this way to match Hugging Face format)
+            "logits": ["batch_size", "sequence_length", self.vocab_size],                                       # For standard models
+            "present.key": ["batch_size", self.num_kv_heads, "total_sequence_length", self.head_size],          # For standard models (note that `present.key` is written this way to match Hugging Face format)
+            "present.value": ["batch_size", self.num_kv_heads, "total_sequence_length", self.head_size],        # For standard models (note that `present.value` is written this way to match Hugging Face format)
         }
+        self.exclude_lm_head = "exclude_lm_head" in extra_options
+        if self.exclude_lm_head:
+            self.output_names = [name.replace("logits", "hidden_states") for name in self.output_names]
 
         # Store names of Constant ops already created
         self.constants = set()
@@ -122,17 +156,18 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
 
         # Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
         self.attention_attrs = {
             "op_type": "MultiHeadAttention",  # Attention op to use
-            "use_rotemb_in_gqa": False,       # Use rotary embeddings within GroupQueryAttention (instead of a separate RotaryEmbedding op)
+            "use_rotemb_in_attn": False,      # Use rotary embeddings within attention op (instead of a separate RotaryEmbedding op)
             "use_packed_matmul": False,       # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V)
         }
-        if ep == "cuda" and io_dtype == TensorProto.FLOAT16:
+        if self.ep == "cuda" and self.io_dtype == TensorProto.FLOAT16:
+            # Change model settings for GroupQueryAttention
             self.attention_attrs["op_type"] = "GroupQueryAttention"
             print("GroupQueryAttention (GQA) is used in this model. GQA is currently supported only for INT4 CUDA and FP16 CUDA.")
 
             self.attention_attrs["use_packed_matmul"] = self.num_attn_heads == self.num_kv_heads
 
             # GQA + Rot.Emb. does not require `position ids` as input
-            self.attention_attrs["use_rotemb_in_gqa"] = True
+            self.attention_attrs["use_rotemb_in_attn"] = True
             self.input_names.remove("position_ids")
 
         # MLP-specific variables
@@ -180,7 +215,7 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
                     "num_key_value_heads": self.num_kv_heads,
                 },
                 "eos_token_id": config.eos_token_id,
-                "pad_token_id": config.pad_token_id if hasattr(config, "pad_token_id") and config.pad_token_id != None else config.eos_token_id,
+                "pad_token_id": config.pad_token_id if hasattr(config, "pad_token_id") and config.pad_token_id is not None else config.eos_token_id,
                 "type": self.model_type[ : self.model_type.find("For")].lower(),
                 "vocab_size": self.vocab_size,
             },
@@ -202,9 +237,9 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
             },
         }
 
-        if self.ep == "cuda":
-            cuda_options = { "cuda" : { } }
-            genai_config["model"]["decoder"]["session_options"]["provider_options"].append(cuda_options)
+        if self.ep != "cpu":
+            ep_options = { self.ep : self.ep_attrs[self.ep] }
+            genai_config["model"]["decoder"]["session_options"]["provider_options"].append(ep_options)
 
         print(f"Saving GenAI config in {out_dir}")
         with open(os.path.join(out_dir,"genai_config.json"), "w") as f:
@@ -335,23 +370,25 @@ def make_inputs_and_outputs(self):
             inputs.append(helper.make_tensor_value_info(name, dtype, shape=shape))
 
         # Add model-specific outputs to list of model outputs
-        outputs = [
-            helper.make_tensor_value_info("logits", self.io_dtype, shape=["batch_size", "sequence_length", self.vocab_size])
-        ]
+        outputs = []
+        for name in self.output_names:
+            dtype = self.output_types[name]
+            shape = self.output_shapes[name]
+            outputs.append(helper.make_tensor_value_info(name, dtype, shape=shape))
 
         # Add KV cache to inputs and outputs
         for i in range(self.num_layers):
             # Add KV cache to inputs
             key_name = f"past_key_values.{i}.key"
-            inputs.append(helper.make_tensor_value_info(key_name, self.io_dtype, shape=["batch_size", self.num_kv_heads, "past_sequence_length", self.head_size]))
+            inputs.append(helper.make_tensor_value_info(key_name, self.input_types["past_key_values.key"], shape=self.input_shapes["past_key_values.key"]))
             value_name = f"past_key_values.{i}.value"
-            inputs.append(helper.make_tensor_value_info(value_name, self.io_dtype, shape=["batch_size", self.num_kv_heads, "past_sequence_length", self.head_size]))
+            inputs.append(helper.make_tensor_value_info(value_name, self.input_types["past_key_values.value"], shape=self.input_shapes["past_key_values.value"]))
 
             # Add KV cache to outputs
             key_name = f"present.{i}.key"
-            outputs.append(helper.make_tensor_value_info(key_name, self.io_dtype, shape=["batch_size", self.num_kv_heads, "total_sequence_length", self.head_size]))
+            outputs.append(helper.make_tensor_value_info(key_name, self.output_types["present.key"], shape=self.output_shapes["present.key"]))
             value_name = f"present.{i}.value"
-            outputs.append(helper.make_tensor_value_info(value_name, self.io_dtype, shape=["batch_size", self.num_kv_heads, "total_sequence_length", self.head_size]))
+            outputs.append(helper.make_tensor_value_info(value_name, self.output_types["present.value"], shape=self.output_shapes["present.value"]))
 
         self.inputs = inputs
         self.outputs = outputs
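Editor's note: the refactored `make_inputs_and_outputs` above builds every graph input/output with `onnx.helper.make_tensor_value_info`, looking the dtype and symbolic shape up in the new maps. A small illustrative sketch of what gets declared for the two new tensor names (FP16 I/O and a hidden size of 4096 are assumptions for the example, not values from this diff):

```
from onnx import TensorProto, helper

# What the builder would declare for `inputs_embeds` when exclude_embeds is set
inputs_embeds = helper.make_tensor_value_info(
    "inputs_embeds", TensorProto.FLOAT16, shape=["batch_size", "sequence_length", 4096]
)
# ...and for `hidden_states` when exclude_lm_head is set
hidden_states = helper.make_tensor_value_info(
    "hidden_states", TensorProto.FLOAT16, shape=["batch_size", "sequence_length", 4096]
)
print(inputs_embeds)
```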
@@ -573,6 +610,8 @@ def make_layernorm(self, layer_id, layernorm, skip, simple, location):
 
         output_0 = f"/model/layers.{layer_id}/{location}_layernorm/output_0"
         output_3 = f"/model/layers.{layer_id}/{location}_layernorm/output_3"
+        if self.layernorm_attrs["last_layernorm"] and self.exclude_lm_head:
+            output_0 = "hidden_states"
         outputs = [output_0, "", "", output_3] if skip and not self.layernorm_attrs["last_layernorm"] else [output_0]
 
         self.make_node(op_type, inputs=inputs, outputs=outputs, name=name, domain=("com.microsoft" if skip else None), **kwargs)
@@ -829,14 +868,14 @@ def make_group_query_attention(self, name, **kwargs):
             kwargs["q_path"], kwargs["k_path"], kwargs["v_path"],
             kwargs.get("past_k", ""), kwargs.get("past_v", ""),
             kwargs.get("seqlens_k", ""), kwargs.get("total_seq_len", ""),
-            kwargs.get("cos_cache", ""), kwargs.get("sin_cache", "")
+            kwargs.get("cos_cache", ""), kwargs.get("sin_cache", ""),
         ]
         output = f"{name}/output_0"
         outputs = [output, kwargs.get("present_k", ""), kwargs.get("present_v", "")]
         self.make_node(
             "GroupQueryAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft",
             num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, local_window_size=self.window_size,
-            do_rotary=self.attention_attrs["use_rotemb_in_gqa"], rotary_interleaved=self.rotemb_attrs["interleaved"],
+            do_rotary=self.attention_attrs["use_rotemb_in_attn"], rotary_interleaved=self.rotemb_attrs["interleaved"],
         )
 
         self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * self.num_attn_heads])
@@ -923,7 +962,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
 
         # Make RotaryEmbedding nodes
         cos_cache_name, sin_cache_name = "", ""
-        if self.attention_attrs["use_rotemb_in_gqa"]:
+        if self.attention_attrs["use_rotemb_in_attn"]:
             cos_cache_name, sin_cache_name = self.make_rotary_embedding_caches(attention.rotary_emb)
         else:
             q_rotary_name = f"/model/layers.{layer_id}/attn/q_rotary/RotaryEmbedding"
@@ -938,7 +977,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         past_v = f"past_key_values.{layer_id}.value"
         present_k = f"present.{layer_id}.key"
         present_v = f"present.{layer_id}.value"
-        if self.num_attn_heads != self.num_kv_heads and self.attention_attrs["op_type"] != "GroupQueryAttention":
+        if self.num_attn_heads != self.num_kv_heads and self.attention_attrs["op_type"] == "MultiHeadAttention":
             k_input_to_attention = self.make_repeat_kv(layer_id, root_input=k_input_to_attention, past_kv=past_k, present_kv=present_k)
             v_input_to_attention = self.make_repeat_kv(layer_id, root_input=v_input_to_attention, past_kv=past_v, present_kv=present_v)
             past_k, past_v, present_k, present_v = "", "", "", ""
@@ -1112,14 +1151,8 @@ def make_model(self, input_path):
         # Make inputs and outputs to ONNX model
         self.make_inputs_and_outputs()
 
-        # Make attention mask reformatting nodes
-        #
-        #           2D attention mask
-        #                   |
-        #    attention mask reformatting subgraph
-        #                   |
-        #         4D causal attention mask
-        self.make_attention_mask_reformatting()
+        # Make pre-processing nodes
+        self.make_preprocessing_nodes()
 
         # Load weights of original model
         if input_path.endswith(".gguf"):
@@ -1129,7 +1162,7 @@
             self.layernorm_attrs["add_offset"] = 0  # add offset already done for GGUF models
         else:
             # Load PyTorch model
-            extra_kwargs = {} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir, "use_auth_token": True}
+            extra_kwargs = {"trust_remote_code": True} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir, "use_auth_token": True}
             model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, **extra_kwargs)
 
         # Loop through model and map each module to ONNX/ORT ops
@@ -1137,23 +1170,32 @@
         for module in model.modules():
             if isinstance(module, torch.nn.Embedding) or (hasattr(model, "embedding") and module == model.embedding):  # Checks (Hugging Face logic) or (GGUF logic)
-                # Embedding layer
-                print("Reading embedding layer")
-                self.make_embedding(module.weight.detach().numpy())
+                if not self.exclude_embeds:
+                    # Embedding layer
+                    print("Reading embedding layer")
+                    self.make_embedding(module.weight.detach().numpy())
+                else:
+                    # Exclude embedding layer from model
+                    self.layernorm_attrs["root_input"] = "inputs_embeds"
+                    self.layernorm_attrs["skip_input"] = "inputs_embeds"
+
             elif module.__class__.__name__.endswith("DecoderLayer"):
                 # Each decoder layer of model
                 print(f"Reading decoder layer {self.layer_id}")
                 self.make_layer(self.layer_id, module)
                 self.layer_id += 1
+
             elif self.layer_id == self.num_layers and self.has_final_norm(module, model):
                 # SkipLayerNorm after last decoder layer (MatMul --> SkipLayerNorm)
                 print("Reading final norm")
                 self.make_layernorm(self.layer_id, module, skip=True, simple=self.layernorm_attrs["simple"], location="final_norm")
+
             elif (isinstance(module, torch.nn.Linear) and module.out_features == self.vocab_size) or (hasattr(model, "lm_head") and module == model.lm_head):  # Checks (Hugging Face logic) or (GGUF logic)
-                # Language modeling head (SkipLayerNorm --> logits)
-                print("Reading LM head")
-                self.make_lm_head(module)
+                if not self.exclude_lm_head:
+                    # Language modeling head (SkipLayerNorm --> logits)
+                    print("Reading LM head")
+                    self.make_lm_head(module)
 
         del model
@@ -1165,10 +1207,21 @@ def has_final_norm(self, module, model):
         gguf_final_norm = hasattr(model, "final_norm") and module == model.final_norm
         return hf_norm or hf_final_layernorm or gguf_final_norm
 
+    def make_preprocessing_nodes(self):
+        self.make_attention_mask_reformatting()
+        # TODO: add make_position_ids_reformatting() here
+
     def make_attention_mask_reformatting(self):
         if self.attention_attrs["op_type"] == "GroupQueryAttention":
             self.make_attention_mask_reformatting_for_gqa()
-        else:
+        elif self.attention_attrs["op_type"] == "MultiHeadAttention":
+            # Make attention mask reformatting nodes
+            #
+            #           2D attention mask
+            #                   |
+            #    attention mask reformatting subgraph
+            #                   |
+            #         4D causal attention mask
             self.make_attention_mask_reformatting_2d_to_4d()
 
     def make_attention_mask_reformatting_2d_to_4d(self):
@@ -1352,7 +1405,7 @@ def make_input_ids_subgraph(self, basename, past_key_gather_name):
         unsqueeze_9_inputs = [f"{unsqueeze_8_name}/output_0", "/model/constants/TensorProto.INT64/1D/1"]
         self.make_unsqueeze(unsqueeze_9_name, unsqueeze_9_inputs, dtype=self.io_dtype, shape=None)
 
-        expand_name = self.make_common_mask_reformat_subgraph(basename, root_input="input_ids", unsqueeze_for_concat=unsqueeze_3_name, unsqueeze_for_expand=unsqueeze_9_name, input_ids_subgraph=True)
+        expand_name = self.make_common_mask_reformat_subgraph(basename, root_input="input_ids" if not self.exclude_embeds else "inputs_embeds", unsqueeze_for_concat=unsqueeze_3_name, unsqueeze_for_expand=unsqueeze_9_name, input_ids_subgraph=True)
         return unsqueeze_6_name, expand_name
 
     def make_attention_mask_subgraph(self, basename, unsqueeze_for_concat):
@@ -1427,9 +1480,9 @@ def make_common_mask_reformat_subgraph(self, basename, root_input, unsqueeze_for
 
         # Expand
         shape_1_name = f"{basename}/Shape_1"
-        self.make_shape(shape_1_name, root_input, shape=[2])
+        self.make_shape(shape_1_name, root_input, shape=[3] if self.exclude_embeds and input_ids_subgraph else [2])
         shape_2_name = f"{basename}/Shape_2"
-        self.make_shape(shape_2_name, root_input, shape=[2])
+        self.make_shape(shape_2_name, root_input, shape=[3] if self.exclude_embeds and input_ids_subgraph else [2])
         gather_1_name = f"{basename}/Gather_1"
         gather_1_inputs = [f"{shape_1_name}/output_0", "/model/constants/TensorProto.INT64/0D/0"]
         self.make_gather(gather_1_name, gather_1_inputs, axis=0)
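Editor's note: the `shape=[2]` to `shape=[3]` switch above reflects the rank of the subgraph's root input. `input_ids` is 2-D while `inputs_embeds` is 3-D, and the Shape/Gather pair still needs to pick out the batch and sequence dimensions. A quick numpy illustration with toy sizes (not values from this diff):

```
import numpy as np

input_ids = np.zeros((2, 8), dtype=np.int64)               # (batch_size, sequence_length)
inputs_embeds = np.zeros((2, 8, 4096), dtype=np.float16)   # (batch_size, sequence_length, hidden_size)

print(np.shape(input_ids))      # (2, 8)       -> Shape output has 2 elements
print(np.shape(inputs_embeds))  # (2, 8, 4096) -> Shape output has 3 elements
# Gather at index 0 / 1 still returns batch_size / sequence_length in both cases
```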
@@ -1555,7 +1608,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
 class MistralModel(Model):
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
-        self.position_ids_name = f"{self.make_position_ids_reformatting()}/output_0" if not self.attention_attrs["use_rotemb_in_gqa"] else "position_ids"
+        self.position_ids_name = f"{self.make_position_ids_reformatting()}/output_0" if not self.attention_attrs["use_rotemb_in_attn"] else "position_ids"
 
     def make_attention(self, layer_id, attention, root_input, **kwargs):
         super().make_attention(layer_id, attention, root_input, position_ids=self.position_ids_name, **kwargs)
@@ -1621,7 +1674,7 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
     os.makedirs(cache_dir, exist_ok=True)
 
     # Load model config
-    extra_kwargs = {} if os.path.isdir(input_path) else {"cache_dir": cache_dir, "use_auth_token": True}
+    extra_kwargs = {"trust_remote_code": True} if os.path.isdir(input_path) else {"cache_dir": cache_dir, "use_auth_token": True}
     hf_name = input_path if os.path.isdir(input_path) else model_name
     config = AutoConfig.from_pretrained(hf_name, **extra_kwargs)
 
@@ -1730,6 +1783,12 @@ def get_args():
                 The filename for each component will be '_.onnx' (ex: '_encoder.onnx', '_decoder.onnx').
                 config_only = Generate config and pre/post processing files only.
                     Use this option when you already have your optimized and/or quantized ONNX model.
+                exclude_embeds = Remove embedding layer from your ONNX model.
+                    Use this option when you want to remove the embedding layer from within your ONNX model.
+                    Instead of `input_ids`, you will have `inputs_embeds` as the input to your ONNX model.
+                exclude_lm_head = Remove language modeling head from your ONNX model.
+                    Use this option when you want to remove the language modeling head from within your ONNX model.
+                    Instead of `logits`, you will have `hidden_states` as the output to your ONNX model.
             """),
     )
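Editor's note: since `exclude_lm_head=true` leaves the exported graph ending at `hidden_states` (as the help text above describes), logits have to be recovered outside the model. A hedged sketch of one way to do that, reusing the original Hugging Face language-modeling-head weights; the `lm_head` attribute name and the toy shapes are assumptions about the source model, not guaranteed by this change:

```
import numpy as np
import torch
from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained("model_name")   # placeholder name from the README
lm_head_weight = hf_model.lm_head.weight.detach().numpy()       # (vocab_size, hidden_size), if the model exposes `lm_head`

# Stand-in for the `hidden_states` output of the exported ONNX model
hidden_states = np.zeros((1, 8, lm_head_weight.shape[1]), dtype=np.float32)
logits = hidden_states @ lm_head_weight.T                       # (batch_size, sequence_length, vocab_size)
print(logits.shape)
```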