diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index e7b7074783162..66c78f80f7910 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -213,6 +213,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "NhwcConv": self._infer_NhwcConv,
             "PackedAttention": self._infer_PackedAttention,
             "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
+            "PagedAttention": self._infer_PagedAttention,
             "PythonOp": self._infer_PythonOp,
             "QuantizeLinear": self._infer_QuantizeLinear,
             "QuickGelu": self._infer_FastGelu,
@@ -470,6 +471,7 @@ def _onnx_infer_single_node(self, node):
             "SkipLayerNormalization",
             "SkipSimplifiedLayerNormalization",
             "PackedAttention",
+            "PagedAttention",
             "PythonOp",
             "MultiHeadAttention",
             "GroupNorm",
@@ -2412,6 +2414,9 @@ def _infer_SkipLayerNormalization(self, node):  # noqa: N802
     def _infer_GroupNorm(self, node):  # noqa: N802
         self._propagate_shape_and_type(node)
 
+    def _infer_PagedAttention(self, node):  # noqa: N802
+        self._propagate_shape_and_type(node)
+
     def _infer_GroupQueryAttention(self, node):  # noqa: N802
         output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
 
diff --git a/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py b/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py
index bca5ace916082..9a66afe3ad4f9 100644
--- a/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py
+++ b/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py
@@ -73,20 +73,32 @@ def unroll_function(self, func_name: str) -> None:
 
         return self.update_edges(edge_mapping)
 
-    def remove_dropout_layer(self) -> None:
+    def remove_function(self, func_name: str, input_id: int, output_id: int) -> None:
         """
-        Removes the dropout layer in the model.
+        Removes the function in the model.
         """
-        logging.info("Removing dropout layer...")
         edge_mapping = {}
         nodes_to_remove = []
         for node in self.model.graph.node:
-            if node.op_type.find("Dropout") != -1:
-                assert len(node.input) == 1
-                assert len(node.output) == 1
-                edge_mapping[node.output[0]] = node.input[0]
+            if node.op_type.find(func_name) != -1:
+                edge_mapping[node.input[input_id]] = node.output[output_id]
                 nodes_to_remove.append(node)
         for node in nodes_to_remove:
             self.model.graph.node.remove(node)
 
         self.update_edges(edge_mapping)
+
+    def remove_dropout_layer(self) -> None:
+        """
+        Removes the dropout layer in the model.
+        """
+        logging.info("Removing dropout layer...")
+        self.remove_function("Dropout", 0, 0)
+
+    def remove_lm_head_layer(self) -> None:
+        """
+        Removes the LM head layer in the model.
+ """ + logging.info("Removing LM head layer...") + # bugbug: need to copy the right vi over + self.remove_function("Linear_lm_head", 2, 0) diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index c65464a3069c5..4c43e4487bfb1 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -24,6 +24,7 @@ class AttentionOpType(Enum): Attention = "Attention" MultiHeadAttention = "MultiHeadAttention" GroupQueryAttention = "GroupQueryAttention" + PagedAttention = "PagedAttention" def __str__(self): return self.value diff --git a/onnxruntime/python/tools/transformers/models/phi2/README.md b/onnxruntime/python/tools/transformers/models/phi2/README.md index 526fdc3dd7863..da62bba0f02fb 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/README.md +++ b/onnxruntime/python/tools/transformers/models/phi2/README.md @@ -11,6 +11,7 @@ To export ONNX, PyTorch version 2.2.0 or higher is required. The [official websi **There are two options to run the conversion script:**\ _From source:_ ```bash +# Default onnxruntime package is built with CUDA 11.8. For CUDA 12.x, refer to https://onnxruntime.ai/docs/install/#python-installs pip install onnxruntime-gpu==1.17.0 # or onnxruntime==1.17.0 if using cpu git clone git@github.com:microsoft/onnxruntime.git cd onnxruntime/onnxruntime/python/tools/transformers diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py index ac3ca40e41be0..b7881d064067d 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -136,14 +136,18 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str): self.precision == Precision.FLOAT16 or self.precision == Precision.INT4 ) and self.attn_op_type != AttentionOpType.MultiHeadAttention: # We keep last three layers of Attention as float32 or bfloat16 to avoid overflow. 
-            node_block_list = [
-                "GroupQueryAttention_29",
-                "GroupQueryAttention_30",
-                "GroupQueryAttention_31",
-                "Attention_29",
-                "Attention_30",
-                "Attention_31",
-            ]
+            node_block_list = (
+                [
+                    "GroupQueryAttention_29",
+                    "GroupQueryAttention_30",
+                    "GroupQueryAttention_31",
+                    "Attention_29",
+                    "Attention_30",
+                    "Attention_31",
+                ]
+                if self.attn_op_type != AttentionOpType.PagedAttention
+                else []
+            )  # TODO: temp setting for paged attention
             logging.info("Converting onnx model to float16/bfloat16...")
             optimizer.convert_float_to_float16(
                 keep_io_types=False,
@@ -220,6 +224,20 @@ def parse_arguments():
         help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
     )
 
+    parser.add_argument(
+        "--fp16_vllm",
+        required=False,
+        action="store_true",
+        help="Generate fp16 ONNX model for ORT VLLM",
+    )
+
+    parser.add_argument(
+        "--int4_vllm",
+        required=False,
+        action="store_true",
+        help="Generate int4 ONNX model for ORT VLLM",
+    )
+
     parser.add_argument(
         "--overwrite",
         required=False,
@@ -336,6 +354,16 @@ def main():
             Precision.INT4,
             os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"),
         ),
+        "fp16_vllm": (
+            AttentionOpType.PagedAttention,
+            Precision.FLOAT16,
+            os.path.join(output_dir, "phi2_decoder_fp16_vllm.onnx"),
+        ),
+        "int4_vllm": (
+            AttentionOpType.PagedAttention,
+            Precision.INT4,
+            os.path.join(output_dir, "phi2_decoder_int4_vllm.onnx"),
+        ),
     }
 
     if not args.skip_export:
@@ -403,6 +431,22 @@ def run_optimize_phi2_onnx(
             )
         )
 
+    if args.fp16_vllm:
+        processes.append(
+            Process(
+                target=run_optimize_phi2_onnx,
+                args=(converter, original_onnx_path, *model_type_to_args["fp16_vllm"]),
+            )
+        )
+
+    if args.int4_vllm:
+        processes.append(
+            Process(
+                target=run_optimize_phi2_onnx,
+                args=(converter, original_onnx_path, *model_type_to_args["int4_vllm"]),
+            )
+        )
+
     [p.start() for p in processes]
     [p.join() for p in processes]
 
@@ -450,8 +494,8 @@ def run_optimize_phi2_onnx(
             device_id=args.device_id,
             packed_kv=True,
         )
-        if args.fp32_cpu or args.int4_cpu:
-            raise NotImplementedError("CPU inference example is not implemented yet.")
+        if args.fp32_cpu or args.int4_cpu or args.fp16_vllm or args.int4_vllm:
+            raise NotImplementedError("CPU/vllm inference example is not implemented yet.")
 
 
 if __name__ == "__main__":
diff --git a/onnxruntime/python/tools/transformers/onnx_model_phi.py b/onnxruntime/python/tools/transformers/onnx_model_phi.py
index df8830b0d0495..e68c3120e3f09 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_phi.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_phi.py
@@ -255,6 +255,30 @@ def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num
         )
         return [node]
 
+    def paged_attn(
+        self,
+        inputs: List[str],
+        outputs: List[str],
+        prefix: str = "",
+        num_heads=32,
+        head_size=80,
+        scale=0.11180339753627777,
+    ):
+        assert len(inputs) == 6
+        assert len(outputs) == 1
+        node = helper.make_node(
+            "PagedAttention",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "PagedAttention",
+            domain="vllm.ort.ext",
+            num_heads=num_heads,
+            num_kv_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+        )
+        return [node]
+
 
 class Phi2PreProcessor(DynamoOnnxHelper):
     def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
@@ -288,32 +312,46 @@ def simplify_phi2_op_type(self):
 
     def process_graph_io(self, attn_op_type: AttentionOpType):
         self.use_attn = attn_op_type == AttentionOpType.Attention
+        self.use_vllm = attn_op_type == AttentionOpType.PagedAttention
         graph = self.model.graph
         new_inputs = []
         for vi in graph.input:
             if "input_ids" in vi.name:
"input_ids" in vi.name: vi_iid = helper.make_tensor_value_info( vi.name, - elem_type=TensorProto.INT32, + elem_type=TensorProto.INT32 if not self.use_vllm else TensorProto.INT64, shape=["batch_size", "seq_len"], ) - vi_pid = helper.make_tensor_value_info( + vi_step = helper.make_tensor_value_info( "step", elem_type=TensorProto.INT64, shape=[1], ) + vi_pid = helper.make_tensor_value_info( + "position_ids", + elem_type=TensorProto.INT64, + shape=["batch_size", "seq_len"], + ) vi_mask = helper.make_tensor_value_info( "attention_mask", elem_type=TensorProto.INT32, shape=["batch_size", "seq_len"], ) - new_inputs.extend([vi_iid, vi_pid, vi_mask]) - if not self.use_attn: - if "past_key" in vi.name or "past_value" in vi.name: + vi_meta = helper.make_tensor_value_info( + "input_metadata", + elem_type=TensorProto.INT64, + shape=[1], + ) + new_inputs.extend([vi_iid, vi_step, vi_mask]) if not self.use_vllm else new_inputs.extend( + [vi_iid, vi_pid, vi_meta] + ) + if self.use_attn: + if "past_key" in vi.name: vi_cache = helper.make_tensor_value_info( - vi.name, + vi.name.replace("past_key", "past"), elem_type=vi.type.tensor_type.elem_type, shape=[ + 2, "batch_size", self.num_attention_heads, "past_seq_len", @@ -321,13 +359,32 @@ def process_graph_io(self, attn_op_type: AttentionOpType): ], ) new_inputs.extend([vi_cache]) - else: + elif self.use_vllm: if "past_key" in vi.name: vi_cache = helper.make_tensor_value_info( - vi.name.replace("past_key", "past"), + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=["num_blocks", "num_heads", "head_size_x", "block_size", "block_x"], + ) + new_inputs.extend([vi_cache]) + if "past_value" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "num_blocks", + "num_heads", + "head_size", + "block_size", + ], + ) + new_inputs.extend([vi_cache]) + else: + if "past_key" in vi.name or "past_value" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name, elem_type=vi.type.tensor_type.elem_type, shape=[ - 2, "batch_size", self.num_attention_heads, "past_seq_len", @@ -344,19 +401,7 @@ def process_graph_io(self, attn_op_type: AttentionOpType): if i == 0: new_outputs.extend([vi]) else: - if not self.use_attn: - vi_cache = helper.make_tensor_value_info( - vi.name, - elem_type=vi.type.tensor_type.elem_type, - shape=[ - "batch_size", - self.num_attention_heads, - "total_seq_len", - self.hidden_size // self.num_attention_heads, - ], - ) - new_outputs.extend([vi_cache]) - else: + if self.use_attn: if "present_key" in vi.name: vi_cache = helper.make_tensor_value_info( vi.name.replace("present_key", "present"), @@ -370,6 +415,20 @@ def process_graph_io(self, attn_op_type: AttentionOpType): ], ) new_outputs.extend([vi_cache]) + elif self.use_vllm: + pass + else: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "batch_size", + self.num_attention_heads, + "total_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_outputs.extend([vi_cache]) graph.ClearField("output") graph.output.extend(new_outputs) @@ -385,6 +444,8 @@ def preprocess_onnx(self, attn_op_type: AttentionOpType): self.update_edges(self.phi2_edge_dict) self.simplify_phi2_op_type() self.remove_dropout_layer() + if attn_op_type == AttentionOpType.PagedAttention: + self.remove_lm_head_layer() self.process_graph_io(attn_op_type) @@ -694,7 +755,9 @@ def fuse( layer_known_edges_names.extend( [attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, 
         )
-        layer_known_edges_names.extend(["attention_mask", "step", "seqlens_k", "total_sequence_length"])
+        layer_known_edges_names.extend(
+            ["attention_mask", "step", "seqlens_k", "total_sequence_length", "input_metadata", "position_ids"]
+        )
 
         subgraph_nodes = []
         subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"]))
@@ -708,8 +771,9 @@ def fuse(
         subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
         subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
         subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
-        subgraph_nodes.extend(self.rotary(["query", "step", cos_cache, sin_cache], ["query_rot"], "Q_"))
-        subgraph_nodes.extend(self.rotary(["key", "step", cos_cache, sin_cache], ["key_rot"], "K_"))
+        pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step"
+        subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_"))
+        subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_"))
         if self.attn_op_type == AttentionOpType.MultiHeadAttention:
             subgraph_nodes.extend(
                 self.mha(
@@ -740,6 +804,13 @@ def fuse(
             self.model.add_initializer(
                 numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name
             )
+        elif self.attn_op_type == AttentionOpType.PagedAttention:
+            subgraph_nodes.extend(
+                self.paged_attn(
+                    ["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "input_metadata"],
+                    ["attn_out"],
+                )
+            )
         else:
             past_name = f"past_{layer_id}"
             present_name = f"present_{layer_id}"
@@ -798,6 +869,7 @@ def get_fused_operator_statistics(self):
             "Attention",
             "MultiHeadAttention",
             "GroupQueryAttention",
+            "PagedAttention",
             "Gelu",
             "BiasGelu",
             "FastGelu",
@@ -821,7 +893,12 @@ def is_fully_optimized(self, fused_op_count=None):
         def op_count(op_name: str):
             return fused_op_count.get(op_name) or 0
 
-        attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("GroupQueryAttention")
+        attention = (
+            op_count("Attention")
+            + op_count("MultiHeadAttention")
+            + op_count("GroupQueryAttention")
+            + op_count("PagedAttention")
+        )
         gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
         layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")
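The new `DynamoOnnxHelper.remove_function` removes a matched node by recording a rename from one of its input edges to one of its output edges and then rewriting every reference through `update_edges`; `remove_dropout_layer` and `remove_lm_head_layer` in this patch are thin wrappers around it. Below is a minimal, self-contained sketch of that edge-rewiring idea on a toy graph. The standalone helper `remove_passthrough`, the toy model, and the inlined rename loop are illustrative assumptions made for this sketch only, not code from the patch.

```python
# Illustrative sketch only -- not part of this patch. It mimics what
# DynamoOnnxHelper.remove_function does for a pass-through node: map an input
# edge name to an output edge name, drop the node, then rename edges everywhere.
import onnx
from onnx import TensorProto, helper


def remove_passthrough(model: onnx.ModelProto, op_type: str, input_id: int = 0, output_id: int = 0) -> None:
    """Remove every node whose op_type contains `op_type`, reconnecting the graph by edge renaming."""
    edge_mapping = {}
    nodes_to_remove = []
    for node in model.graph.node:
        if op_type in node.op_type:
            edge_mapping[node.input[input_id]] = node.output[output_id]
            nodes_to_remove.append(node)
    for node in nodes_to_remove:
        model.graph.node.remove(node)
    # Analogous to DynamoOnnxHelper.update_edges: rewrite every edge that still uses an old name.
    for node in model.graph.node:
        for i, name in enumerate(node.input):
            if name in edge_mapping:
                node.input[i] = edge_mapping[name]
        for i, name in enumerate(node.output):
            if name in edge_mapping:
                node.output[i] = edge_mapping[name]
    for vi in list(model.graph.input) + list(model.graph.output):
        if vi.name in edge_mapping:
            vi.name = edge_mapping[vi.name]


# Toy model: X -> Dropout -> Y -> Identity -> Z
dropout = helper.make_node("Dropout", inputs=["X"], outputs=["Y"], name="drop_0")
tail = helper.make_node("Identity", inputs=["Y"], outputs=["Z"], name="tail_0")
graph = helper.make_graph(
    [dropout, tail],
    "toy",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1])],
    [helper.make_tensor_value_info("Z", TensorProto.FLOAT, [1])],
)
model = helper.make_model(graph)

remove_passthrough(model, "Dropout", 0, 0)  # same (func_name, input_id, output_id) triple as remove_dropout_layer
onnx.checker.check_model(model)  # graph stays valid: the input (now named "Y") feeds the remaining Identity
```

Note the direction of the mapping: the old `remove_dropout_layer` renamed the node's output back to its input, while `remove_function` renames the input forward to the output. Either direction keeps producers and consumers connected, as long as every edge reference (including graph inputs and outputs) is rewritten consistently.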