diff --git a/.gitignore b/.gitignore
index 60b60827f..d42e707d4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,7 +10,7 @@
 /test/test_models/*
 /cache_models
 /onnxruntime-linux-x64-*
-/*.csv
+*.csv
 .idea
 cache_dir
 example-models
diff --git a/src/python/py/models/README.md b/src/python/py/models/README.md
index 0fdd2c818..34f24083e 100644
--- a/src/python/py/models/README.md
+++ b/src/python/py/models/README.md
@@ -62,10 +62,10 @@ python3 builder.py -m model_name -o path_to_output_folder -p precision -e execut
 This scenario is where your PyTorch model has been customized or finetuned for one of the currently supported model architectures and your model can be loaded in Hugging Face.
 ```
 # From wheel:
-python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider
+python3 -m onnxruntime_genai.models.builder -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files
 
 # From source:
-python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider
+python3 builder.py -i path_to_local_folder_on_disk -o path_to_output_folder -p precision -e execution_provider -c cache_dir_to_store_temp_files
 ```
 
 ### GGUF Model
diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index 740c7ca40..16f864c38 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -19,6 +19,7 @@
 import os
 import textwrap
 
+
 class Model:
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         self.context_length = config.max_position_embeddings
@@ -48,7 +49,13 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         self.value_infos = []
         self.nodes = []
 
-        # Map input names to input shapes
+        # Map input names to their types and shapes
+        self.input_names = ["input_ids", "attention_mask", "position_ids"]
+        self.input_types = {
+            "input_ids": TensorProto.INT64,
+            "attention_mask": TensorProto.INT64,
+            "position_ids": TensorProto.INT64,
+        }
         self.input_shapes = {
             "input_ids": ["batch_size", "sequence_length"],
             "attention_mask": ["batch_size", "total_sequence_length"],
@@ -105,19 +112,35 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000
         self.rotemb_attrs = {
             "create_rotary_embedding_caches": True,          # Create cos/sin caches for rotary embeddings
+            "theta": rope_theta,                             # Base value if calculating cos/sin caches from scratch
             "partial_rotary_factor": partial_rotary_factor,  # Factor for partial rotary embeddings
+            "interleaved": 0,                                # Interleave the rotary embeddings (e.g. [0, 0, 0, 1, 1, 1] to [0, 1, 0, 1, 0, 1], RotaryEmbedding kernel expects a default value of 0)
             "num_heads": 0,                                  # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
             "rotary_embedding_dim": 0,                       # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0)
-            "theta": rope_theta,                             # Base value if calculating cos/sin caches from scratch
         }
 
         # Attention-specific variables (MHA, GQA, GQA + Rot.Emb., etc.)
         self.attention_attrs = {
-            "op_type": "MultiHeadAttention",  # Attention op to use
-            "use_gqa": ep == "cuda" and io_dtype == TensorProto.FLOAT16  # Check if GroupQueryAttention can be used
+            "op_type": "MultiHeadAttention",  # Attention op to use
+            "use_rotemb_in_gqa": False,       # Use rotary embeddings within GroupQueryAttention (instead of a separate RotaryEmbedding op)
+            "use_packed_matmul": False,       # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V)
         }
-        if self.attention_attrs["use_gqa"]:
+        if ep == "cuda" and io_dtype == TensorProto.FLOAT16:
             self.attention_attrs["op_type"] = "GroupQueryAttention"
+            print("GroupQueryAttention (GQA) is used in this model. GQA is currently supported only for INT4 CUDA and FP16 CUDA.")
+
+            self.attention_attrs["use_packed_matmul"] = self.num_attn_heads == self.num_kv_heads
+
+            # GQA + Rot.Emb. does not require `position ids` as input
+            self.attention_attrs["use_rotemb_in_gqa"] = True
+            self.input_names.remove("position_ids")
+
+        # MLP-specific variables
+        self.mlp_attrs = {
+            "use_proj": True,   # Use projection style for MLP (GateProj/UpProj/DownProj)
+            "use_fc": False,    # Use fully-connected style for MLP (FC1/FC2)
+            "output_0": "",     # Output 0 for MLP layer
+        }
 
         # Quantization-specific variables (INT4, INT8, etc.)
         self.quant_attrs = {
@@ -129,7 +152,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
 
     def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
         config = GenerationConfig.from_pretrained(model_name_or_path, **extra_kwargs)
-        inputs = dict(zip(self.input_shapes.keys(), self.input_shapes.keys()))
+        inputs = dict(zip(self.input_names, self.input_names))
         inputs.update({
             "past_key_names": "past_key_values.%d.key",
             "past_value_names": "past_key_values.%d.value",
@@ -238,6 +261,7 @@ def save_model(self, out_dir):
         if os.path.exists(data_path):
             print(f"Overwriting {data_path}")
             os.remove(data_path)
+
         save_model(
             model,
             out_path,
@@ -305,9 +329,10 @@ def make_graph(self, *args, doc_string=None, **kwargs):
     def make_inputs_and_outputs(self):
         # Add model-specific inputs to list of model inputs
         inputs = []
-        for name in self.model_inputs:
+        for name in self.input_names:
+            dtype = self.input_types[name]
             shape = self.input_shapes[name]
-            inputs.append(helper.make_tensor_value_info(name, TensorProto.INT64, shape=shape))
+            inputs.append(helper.make_tensor_value_info(name, dtype, shape=shape))
 
         # Add model-specific outputs to list of model outputs
         outputs = [
@@ -474,9 +499,13 @@ def make_matmul_fp16_or_fp32(self, matmul, name, root_input, **kwargs):
     #     self.make_node("MatMulNBits", inputs=[root_input, weight, scales], outputs=[output], name=name)
     #     self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
 
-    # TODO: make packed QKV MatMul
-    # def make_packed_matmul(self, q_matmul, k_matmul, v_matmul, name, root_input, **kwargs):
-    #     pass
+    def make_packed_matmul(self, q_matmul, k_matmul, v_matmul, name, root_input, **kwargs):
+        # N = num_heads * head_size, H = hidden_size
+        # Combine 3 Matmuls of shape NxH into 1 packed MatMul of shape 3NxH
+        # Note: Packed MatMul is of shape 3NxH instead of Hx3N because `make_matmul` will apply a transpose before saving
+        N, H = q_matmul.shape
+        matmul = np.stack((q_matmul.transpose(), k_matmul.transpose(), v_matmul.transpose()), axis=1).reshape(H, 3*N).transpose()
+        self.make_matmul(matmul, name, root_input, **kwargs)
 
     def make_add_bias(self, add, name, root_input, **kwargs):
         bias = name[1:].replace("/", ".") + ".bias"
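The stacking in `make_packed_matmul` is easiest to verify with toy NumPy arrays. A minimal sketch (toy shapes and names, not taken from the diff) showing that the packed weight is just the Q/K/V weights concatenated row-wise, so a single MatMul reproduces the three separate projections:

```python
import numpy as np

H, N = 8, 8  # toy hidden_size and num_heads * head_size
rng = np.random.default_rng(0)
q_w, k_w, v_w = (rng.standard_normal((N, H)).astype(np.float32) for _ in range(3))

# Same stacking as make_packed_matmul: three (N, H) weights -> one (3N, H) weight
packed = np.stack((q_w.T, k_w.T, v_w.T), axis=1).reshape(H, 3 * N).T
assert np.array_equal(packed, np.concatenate((q_w, k_w, v_w), axis=0))

# One fused projection equals the three separate ones, concatenated along the last axis
x = rng.standard_normal((2, H)).astype(np.float32)
assert np.allclose(x @ packed.T, np.concatenate((x @ q_w.T, x @ k_w.T, x @ v_w.T), axis=-1))
```

The transpose/reshape detour exists only because `make_matmul` transposes the weight once more before saving it, as the comment in the diff notes.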
@@ -492,6 +521,11 @@ def make_add_bias(self, add, name, root_input, **kwargs):
         else:
             self.make_add(name, add_bias_inputs, dtype=self.io_dtype, shape=shape)
 
+    def make_packed_add(self, q_add, k_add, v_add, name, root_input, **kwargs):
+        # Combine 3 Adds of shape H into 1 packed Add of shape 3H
+        add = np.stack((q_add, k_add, v_add), axis=0).flatten()
+        self.make_add_bias(add, name, root_input, **kwargs)
+
     def make_embedding(self, embedding):
         weight = "model.embed_tokens.weight"
         self.make_external_tensor(embedding.astype(self.to_numpy_dtype[self.io_dtype]), weight)
@@ -587,7 +621,7 @@ def make_rotary_embedding(self, rotemb, name, root_input, **kwargs):
         inputs = [root_input, kwargs.pop("position_ids"), cos_cache_name, sin_cache_name]
         output = f"{name}/output_0"
-        self.make_node("RotaryEmbedding", inputs=inputs, outputs=[output], name=name, domain="com.microsoft", interleaved=0, **kwargs)
+        self.make_node("RotaryEmbedding", inputs=inputs, outputs=[output], name=name, domain="com.microsoft", interleaved=self.rotemb_attrs["interleaved"], **kwargs)
         self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * (self.num_kv_heads if "k_rotary" in name else self.num_attn_heads)])
 
     # TODO: This function and any corresponding changes to support it are temporary until ORT supports GQA for CPU
@@ -795,10 +829,15 @@ def make_group_query_attention(self, name, **kwargs):
             kwargs["q_path"], kwargs["k_path"], kwargs["v_path"],
             kwargs.get("past_k", ""), kwargs.get("past_v", ""),
             kwargs.get("seqlens_k", ""), kwargs.get("total_seq_len", ""),
+            kwargs.get("cos_cache", ""), kwargs.get("sin_cache", "")
         ]
         output = f"{name}/output_0"
         outputs = [output, kwargs.get("present_k", ""), kwargs.get("present_v", "")]
-        self.make_node("GroupQueryAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft", num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, local_window_size=self.window_size)
+        self.make_node(
+            "GroupQueryAttention", inputs=inputs, outputs=outputs, name=name, domain="com.microsoft",
+            num_heads=self.num_attn_heads, kv_num_heads=self.num_kv_heads, local_window_size=self.window_size,
+            do_rotary=self.attention_attrs["use_rotemb_in_gqa"], rotary_interleaved=self.rotemb_attrs["interleaved"],
+        )
         self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.head_size * self.num_attn_heads])
 
     def make_attention(self, layer_id, attention, root_input, **kwargs):
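For orientation, the fused node emitted by `make_group_query_attention` now looks roughly like the sketch below, built with `onnx.helper`. The tensor names, head counts, and `local_window_size` value are placeholders, not values taken from this diff.

```python
from onnx import helper

# Illustrative only: a GroupQueryAttention node with rotary embedding folded in,
# mirroring the input order and attributes set by make_group_query_attention.
gqa_node = helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "q_proj_out", "k_proj_out", "v_proj_out",
        "past_key_values.0.key", "past_key_values.0.value",
        "seqlens_k", "total_seq_len",
        "cos_cache", "sin_cache",
    ],
    outputs=["attn_out", "present.0.key", "present.0.value"],
    name="/model/layers.0/attn/GroupQueryAttention",
    domain="com.microsoft",
    num_heads=32, kv_num_heads=32, local_window_size=-1,  # placeholder values
    do_rotary=1, rotary_interleaved=0,
)
```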
@@ -841,60 +880,75 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         v_input_to_attention = ""
 
         # Make MatMul nodes
-        q_matmul_name = f"/model/layers.{layer_id}/attn/q_proj/MatMul"
-        self.make_matmul(attention.q_proj.weight.detach().numpy(), q_matmul_name, root_input)
-        q_input_to_attention = f"{q_matmul_name}/output_0"
-        k_matmul_name = f"/model/layers.{layer_id}/attn/k_proj/MatMul"
-        self.make_matmul(attention.k_proj.weight.detach().numpy(), k_matmul_name, root_input)
-        k_input_to_attention = f"{k_matmul_name}/output_0"
-        v_matmul_name = f"/model/layers.{layer_id}/attn/v_proj/MatMul"
-        self.make_matmul(attention.v_proj.weight.detach().numpy(), v_matmul_name, root_input)
-        v_input_to_attention = f"{v_matmul_name}/output_0"
+        if self.attention_attrs["use_packed_matmul"]:
+            # Combine 3 MatMuls into 1 packed MatMul
+            qkv_matmul_name = f"/model/layers.{layer_id}/attn/qkv_proj/MatMul"
+            self.make_packed_matmul(attention.q_proj.weight.detach().numpy(), attention.k_proj.weight.detach().numpy(), attention.v_proj.weight.detach().numpy(), qkv_matmul_name, root_input)
+            q_input_to_attention = f"{qkv_matmul_name}/output_0"
+        else:
+            q_matmul_name = f"/model/layers.{layer_id}/attn/q_proj/MatMul"
+            self.make_matmul(attention.q_proj.weight.detach().numpy(), q_matmul_name, root_input)
+            q_input_to_attention = f"{q_matmul_name}/output_0"
+            k_matmul_name = f"/model/layers.{layer_id}/attn/k_proj/MatMul"
+            self.make_matmul(attention.k_proj.weight.detach().numpy(), k_matmul_name, root_input)
+            k_input_to_attention = f"{k_matmul_name}/output_0"
+            v_matmul_name = f"/model/layers.{layer_id}/attn/v_proj/MatMul"
+            self.make_matmul(attention.v_proj.weight.detach().numpy(), v_matmul_name, root_input)
+            v_input_to_attention = f"{v_matmul_name}/output_0"
 
         # Make Add nodes (if bias exists)
         q_bias_exists = attention.q_proj.bias is not None
         k_bias_exists = attention.k_proj.bias is not None
         v_bias_exists = attention.v_proj.bias is not None
+        all_bias_exists = q_bias_exists and k_bias_exists and v_bias_exists
 
-        if q_bias_exists:
-            q_add_name = f"/model/layers.{layer_id}/attn/q_proj/Add"
-            self.make_add_bias(attention.q_proj.bias.detach().numpy(), q_add_name, root_input=f"{q_matmul_name}/output_0")
-            q_input_to_attention = f"{q_add_name}/output_0"
-        if k_bias_exists:
-            k_add_name = f"/model/layers.{layer_id}/attn/k_proj/Add"
-            self.make_add_bias(attention.k_proj.bias.detach().numpy(), k_add_name, root_input=f"{k_matmul_name}/output_0")
-            k_input_to_attention = f"{k_add_name}/output_0"
-        if v_bias_exists:
-            v_add_name = f"/model/layers.{layer_id}/attn/v_proj/Add"
-            self.make_add_bias(attention.v_proj.bias.detach().numpy(), v_add_name, root_input=f"{v_matmul_name}/output_0")
-            v_input_to_attention = f"{v_add_name}/output_0"
+        if all_bias_exists and self.attention_attrs["use_packed_matmul"]:
+            # Combine 3 Adds into 1 packed Add
+            qkv_add_name = f"/model/layers.{layer_id}/attn/qkv_proj/Add"
+            self.make_packed_add(attention.q_proj.bias.detach().numpy(), attention.k_proj.bias.detach().numpy(), attention.v_proj.bias.detach().numpy(), qkv_add_name, root_input=q_input_to_attention)
+            q_input_to_attention = f"{qkv_add_name}/output_0"
+        else:
+            if q_bias_exists:
+                q_add_name = f"/model/layers.{layer_id}/attn/q_proj/Add"
+                self.make_add_bias(attention.q_proj.bias.detach().numpy(), q_add_name, root_input=q_input_to_attention)
+                q_input_to_attention = f"{q_add_name}/output_0"
+            if k_bias_exists:
+                k_add_name = f"/model/layers.{layer_id}/attn/k_proj/Add"
+                self.make_add_bias(attention.k_proj.bias.detach().numpy(), k_add_name, root_input=k_input_to_attention)
+                k_input_to_attention = f"{k_add_name}/output_0"
+            if v_bias_exists:
+                v_add_name = f"/model/layers.{layer_id}/attn/v_proj/Add"
+                self.make_add_bias(attention.v_proj.bias.detach().numpy(), v_add_name, root_input=v_input_to_attention)
+                v_input_to_attention = f"{v_add_name}/output_0"
 
         # Make RotaryEmbedding nodes
-        q_rotary_name = f"/model/layers.{layer_id}/attn/q_rotary/RotaryEmbedding"
-        q_rotary_input = f"{q_matmul_name if not q_bias_exists else q_add_name}/output_0"
-        self.make_rotary_embedding(attention.rotary_emb, q_rotary_name, q_rotary_input, position_ids=kwargs.get("position_ids", "position_ids"))
-        q_input_to_attention = f"{q_rotary_name}/output_0"
-
-        k_rotary_name = f"/model/layers.{layer_id}/attn/k_rotary/RotaryEmbedding"
-        k_rotary_input = f"{k_matmul_name if not k_bias_exists else k_add_name}/output_0"
-        self.make_rotary_embedding(attention.rotary_emb, k_rotary_name, k_rotary_input, position_ids=kwargs.get("position_ids", "position_ids"))
-        k_input_to_attention = f"{k_rotary_name}/output_0"
+        cos_cache_name, sin_cache_name = "", ""
"", "" + if self.attention_attrs["use_rotemb_in_gqa"]: + cos_cache_name, sin_cache_name = self.make_rotary_embedding_caches(attention.rotary_emb) + else: + q_rotary_name = f"/model/layers.{layer_id}/attn/q_rotary/RotaryEmbedding" + self.make_rotary_embedding(attention.rotary_emb, q_rotary_name, root_input=q_input_to_attention, position_ids=kwargs.get("position_ids", "position_ids")) + q_input_to_attention = f"{q_rotary_name}/output_0" + k_rotary_name = f"/model/layers.{layer_id}/attn/k_rotary/RotaryEmbedding" + self.make_rotary_embedding(attention.rotary_emb, k_rotary_name, root_input=k_input_to_attention, position_ids=kwargs.get("position_ids", "position_ids")) + k_input_to_attention = f"{k_rotary_name}/output_0" # Make repeat KV nodes (TODO: remove once ORT supports GQA for CPU) past_k = f"past_key_values.{layer_id}.key" past_v = f"past_key_values.{layer_id}.value" present_k = f"present.{layer_id}.key" present_v = f"present.{layer_id}.value" - if self.num_attn_heads != self.num_kv_heads and not self.attention_attrs['use_gqa']: - k_input_to_attention = self.make_repeat_kv(layer_id, k_input_to_attention, past_k, present_k) - v_input_to_attention = self.make_repeat_kv(layer_id, v_input_to_attention, past_v, present_v) + if self.num_attn_heads != self.num_kv_heads and self.attention_attrs["op_type"] != "GroupQueryAttention": + k_input_to_attention = self.make_repeat_kv(layer_id, root_input=k_input_to_attention, past_kv=past_k, present_kv=present_k) + v_input_to_attention = self.make_repeat_kv(layer_id, root_input=v_input_to_attention, past_kv=past_v, present_kv=present_v) past_k, past_v, present_k, present_v = "", "", "", "" # Make attention node (e.g. MultiHeadAttention, GroupQueryAttention, etc.) attn_name = f"/model/layers.{layer_id}/attn/{self.attention_attrs['op_type']}" self.make_attention_op( attn_name, q_path=q_input_to_attention, k_path=k_input_to_attention, v_path=v_input_to_attention, - past_k=past_k, past_v=past_v, present_k=present_k, present_v=present_v, **kwargs, + past_k=past_k, past_v=past_v, present_k=present_k, present_v=present_v, + cos_cache=cos_cache_name, sin_cache=sin_cache_name, **kwargs, ) # Make MatMul node (output projection weight node) @@ -914,6 +968,14 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): self.layernorm_attrs["skip_input"] = f"{o_matmul_name if not o_bias_exists else o_add_name}/output_0" def make_mlp(self, layer_id, mlp, root_input): + if self.mlp_attrs["use_proj"]: + self.make_mlp_proj(layer_id, mlp, root_input) + elif self.mlp_attrs["use_fc"]: + self.make_mlp_fc(layer_id, mlp, root_input) + else: + raise NotImplementedError(f"The MLP layer type is not set.") + + def make_mlp_proj(self, layer_id, mlp, root_input): # Make nodes for the MLP subgraph # # root_input @@ -947,6 +1009,39 @@ def make_mlp(self, layer_id, mlp, root_input): # Assign output 0 of previous MatMul as skip input to next SkipLayerNorm self.layernorm_attrs["skip_input"] = f"{down_name}/output_0" + def make_mlp_fc(self, layer_id, mlp, root_input): + # Make nodes for the MLP subgraph + # + # root_input + # | + # FC1_MatMul + # | + # FC1_Add + # | + # ActFunc + # | + # FC2_MatMul + # | + # FC2_Add + + # Make first layer of fully connected nodes (FC1) + fc1_matmul_name = f"/model/layers.{layer_id}/mlp/fc1/MatMul" + self.make_matmul(mlp.fc1.weight.detach().numpy(), fc1_matmul_name, root_input) + fc1_add_name = f"/model/layers.{layer_id}/mlp/fc1/Add" + self.make_add_bias(mlp.fc1.bias.detach().numpy(), fc1_add_name, root_input=f"{fc1_matmul_name}/output_0") + + # 
+        act_fn_name = self.make_activation(layer_id, root_input=f"{fc1_add_name}/output_0")
+
+        # Make second layer of fully connected nodes (FC2)
+        fc2_matmul_name = f"/model/layers.{layer_id}/mlp/fc2/MatMul"
+        self.make_matmul(mlp.fc2.weight.detach().numpy(), fc2_matmul_name, root_input=f"{act_fn_name}/output_0")
+        fc2_add_name = f"/model/layers.{layer_id}/mlp/fc2/Add"
+        self.make_add_bias(mlp.fc2.bias.detach().numpy(), fc2_add_name, root_input=f"{fc2_matmul_name}/output_0")
+
+        # Assign output 0 of MLP layer as output of last layer
+        self.mlp_attrs["output_0"] = f"{fc2_add_name}/output_0"
+
     def make_activation_with_mul(self, layer_id, root_input, activation, domain):
         # Make nodes for this activation subgraph
         #
@@ -1071,6 +1166,12 @@ def has_final_norm(self, module, model):
         return hf_norm or hf_final_layernorm or gguf_final_norm
 
     def make_attention_mask_reformatting(self):
+        if self.attention_attrs["op_type"] == "GroupQueryAttention":
+            self.make_attention_mask_reformatting_for_gqa()
+        else:
+            self.make_attention_mask_reformatting_2d_to_4d()
+
+    def make_attention_mask_reformatting_2d_to_4d(self):
         # Make nodes for the attention mask subgraphs that reformat the
         # 2D attention mask (B, S) to 4D causal attention mask (B, N, S, T)
         #
@@ -1370,17 +1471,7 @@ def make_common_mask_reformat_subgraph(self, basename, root_input, unsqueeze_for
 
         return expand_name
 
-
-class LlamaModel(Model):
-    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
-        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
-        self.model_inputs = ["input_ids", "attention_mask", "position_ids"]
-
-    def make_attention_mask_reformatting(self):
-        if not self.attention_attrs["use_gqa"]:
-            super().make_attention_mask_reformatting()
-            return
-
+    def make_attention_mask_reformatting_for_gqa(self):
         # Make nodes for the attention mask subgraph that calculates
         # attributes about the 2D attention mask to use in GroupQueryAttention
         #
@@ -1420,12 +1511,6 @@ def make_attention_mask_reformatting(self):
         self.mask_attrs["seqlens_k"] = cast_1_name
         self.mask_attrs["total_seq_len"] = cast_2_name
 
-
-class MistralModel(LlamaModel):
-    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
-        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
-        self.position_ids_name = self.make_position_ids_reformatting()
-
     def make_position_ids_reformatting(self):
         # Make nodes for the position ids reformatting subgraph
         #
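The `seqlens_k` and `total_seq_len` values stored above are the two extra inputs GroupQueryAttention consumes in place of a dense attention mask. A rough NumPy rendering of what the subgraph computes, assuming the usual 0/1 padding mask (illustrative, not code from the diff):

```python
import numpy as np

# 0/1 padding mask of shape (batch_size, total_sequence_length)
attention_mask = np.array([[1, 1, 1, 1, 0, 0],
                           [1, 1, 0, 0, 0, 0]], dtype=np.int64)

# seqlens_k: per-row token count minus one, cast to int32 (ReduceSum -> Sub -> Cast)
seqlens_k = (attention_mask.sum(axis=1) - 1).astype(np.int32)   # -> [3, 1]

# total_seq_len: the padded total sequence length as an int32 scalar (Shape -> Gather -> Cast)
total_seq_len = np.int32(attention_mask.shape[1])               # -> 6
```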
@@ -1461,62 +1546,42 @@ def make_position_ids_reformatting(self):
 
         return reshape_name
 
+
+class LlamaModel(Model):
+    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
+        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
+
+
+class MistralModel(Model):
+    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
+        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
+        self.position_ids_name = f"{self.make_position_ids_reformatting()}/output_0" if not self.attention_attrs["use_rotemb_in_gqa"] else "position_ids"
+
     def make_attention(self, layer_id, attention, root_input, **kwargs):
-        super().make_attention(layer_id, attention, root_input, position_ids=f"{self.position_ids_name}/output_0", **kwargs)
+        super().make_attention(layer_id, attention, root_input, position_ids=self.position_ids_name, **kwargs)
 
 
-class PhiModel(LlamaModel):
+class PhiModel(Model):
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
         # self.input_shapes["position_ids"] = [1]  # Note: This is optional and only needed if you want position_ids to be an int instead of a 2D tensor
         self.layernorm_attrs["simple"] = False
         self.rotemb_attrs["num_heads"] = self.num_attn_heads
         self.rotemb_attrs["rotary_embedding_dim"] = int(self.head_size * self.rotemb_attrs["partial_rotary_factor"])
+        self.mlp_attrs["use_proj"], self.mlp_attrs["use_fc"] = False, True
 
     def make_rotary_embedding(self, rotemb, name, root_input, **kwargs):
         super().make_rotary_embedding(rotemb, name, root_input, num_heads=self.rotemb_attrs["num_heads"], rotary_embedding_dim=self.rotemb_attrs["rotary_embedding_dim"], **kwargs)
-
-    def make_mlp(self, layer_id, mlp, root_input):
-        # Make nodes for the MLP subgraph
-        #
-        #          root_input
-        #              |
-        #          FC1_MatMul
-        #              |
-        #           FC1_Add
-        #              |
-        #           FastGelu
-        #              |
-        #          FC2_MatMul
-        #              |
-        #           FC2_Add
-
-        # Make first layer of fully connected nodes (FC1)
-        fc1_matmul_name = f"/model/layers.{layer_id}/mlp/fc1/MatMul"
-        self.make_matmul(mlp.fc1.weight.detach().numpy(), fc1_matmul_name, root_input)
-        fc1_add_name = f"/model/layers.{layer_id}/mlp/fc1/Add"
-        self.make_add_bias(mlp.fc1.bias.detach().numpy(), fc1_add_name, root_input=f"{fc1_matmul_name}/output_0")
-
-        # Make activation function
-        fast_gelu_name = self.make_activation(layer_id, root_input=f"{fc1_add_name}/output_0")
-
-        # Make second layer of fully connected nodes (FC2)
-        fc2_matmul_name = f"/model/layers.{layer_id}/mlp/fc2/MatMul"
-        self.make_matmul(mlp.fc2.weight.detach().numpy(), fc2_matmul_name, root_input=f"{fast_gelu_name}/output_0")
-        fc2_add_name = f"/model/layers.{layer_id}/mlp/fc2/Add"
-        self.make_add_bias(mlp.fc2.bias.detach().numpy(), fc2_add_name, root_input=f"{fc2_matmul_name}/output_0")
-
-        return fc2_add_name
 
     def make_layer(self, layer_id, layer):
         # Each Phi decoder layer is defined as:
         # input_layernorm --> attention --> MLP --> residual_add
         self.make_layernorm(layer_id, layer.input_layernorm, skip=not self.layernorm_attrs["first_layernorm"], simple=self.layernorm_attrs["simple"], location="input")
         self.make_attention(layer_id, layer.self_attn, self.layernorm_attrs["output_0"])
-        fc2_add_name = self.make_mlp(layer_id, layer.mlp, self.layernorm_attrs["output_0"])
+        self.make_mlp(layer_id, layer.mlp, self.layernorm_attrs["output_0"])
 
         residual_add_name = f"/model/layers.{layer_id}/residual_add/Add"
-        residual_add_inputs = [self.layernorm_attrs['skip_input'], f"{fc2_add_name}/output_0"]
+        residual_add_inputs = [self.layernorm_attrs['skip_input'], self.mlp_attrs["output_0"]]
         self.make_add(residual_add_name, residual_add_inputs, dtype=self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
 
         self.layernorm_attrs["first_layernorm"] = False
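Taken together, the restructuring turns the per-model subclasses into thin configuration layers over `Model`: `PhiModel` only flips `mlp_attrs`, and `make_mlp` dispatches accordingly. A stripped-down sketch of that pattern (simplified stand-in classes, not the builder's real signatures):

```python
class Base:
    def __init__(self):
        # Defaults mirror Model.__init__: projection-style MLP unless a subclass flips the flags
        self.mlp_attrs = {"use_proj": True, "use_fc": False, "output_0": ""}

    def make_mlp(self, layer_id):
        # Same dispatch shape as Model.make_mlp in the diff
        if self.mlp_attrs["use_proj"]:
            return f"layers.{layer_id}: GateProj/UpProj/DownProj MLP"
        elif self.mlp_attrs["use_fc"]:
            return f"layers.{layer_id}: FC1 -> ActFunc -> FC2 MLP"
        raise NotImplementedError("The MLP layer type is not set.")


class Phi(Base):
    def __init__(self):
        super().__init__()
        # Same switch PhiModel.__init__ performs
        self.mlp_attrs["use_proj"], self.mlp_attrs["use_fc"] = False, True


print(Base().make_mlp(0))  # projection-style MLP
print(Phi().make_mlp(0))   # fully-connected MLP
```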