From c958bc6e31a0dba504273d3414f3c6b8caf5de77 Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Mon, 30 Oct 2023 21:30:34 +0000 Subject: [PATCH 01/13] add sharding support for llama model and update convert/benchmark script --- .../python/tools/symbolic_shape_infer.py | 1 + .../tools/transformers/convert_generation.py | 6 +- .../transformers/fusion_rotary_attention.py | 161 +++++- .../tools/transformers/models/llama/README.md | 9 + .../transformers/models/llama/benchmark.py | 98 ++-- .../models/llama/convert_to_onnx.py | 369 +++++++------ .../models/llama/dist_settings.py | 43 ++ .../transformers/models/llama/llama_inputs.py | 12 +- .../transformers/models/llama/llama_parity.py | 55 +- .../transformers/models/llama/llama_torch.py | 33 ++ .../models/llama/requirements.txt | 4 +- .../tools/transformers/models/llama/run.sh | 13 + .../transformers/models/llama/single_run.sh | 5 + .../python/tools/transformers/onnx_model.py | 12 + .../transformers/test_rotary_mha_fusion.py | 497 +++++++++++++++++- 15 files changed, 1040 insertions(+), 278 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/llama/dist_settings.py create mode 100644 onnxruntime/python/tools/transformers/models/llama/llama_torch.py create mode 100644 onnxruntime/python/tools/transformers/models/llama/run.sh create mode 100644 onnxruntime/python/tools/transformers/models/llama/single_run.sh diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index ef1c46b83946a..6f12e26d4e80c 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -147,6 +147,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GatherElements": self._infer_GatherElements, "GatherND": self._infer_GatherND, "Identity": self._pass_on_shape_and_type, + "AllReduce": self._pass_on_shape_and_type, "If": self._infer_If, "Loop": self._infer_Loop, "MatMul": self._infer_MatMul, diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index b32ae64c5b0c0..be8eb6a4c42e8 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1272,7 +1272,7 @@ def find_past_seq_len_usage(subg: GraphProto): return tensor_names_to_rename, nodes_to_remove -def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0): +def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1): past_seq_len = past_seq_len_input if past_seq_len not in model.get_graphs_input_names(): # Add model input for past sequence length @@ -1295,8 +1295,8 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads outputs=node.output, name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"), domain="com.microsoft", - num_heads=node.attribute[0].i, - kv_num_heads=node.attribute[0].i if kv_num_heads == 0 else kv_num_heads, + num_heads=node.attribute[0].i // world_size, + kv_num_heads=node.attribute[0].i // world_size if kv_num_heads == 0 else kv_num_heads // world_size, is_past_bsnh=0, ) model.model.graph.node.remove(node) diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index 3c5029ac5752f..84eef46d692b3 100644 --- 
a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -323,6 +323,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # qkv_nodes_1 is for LLaMA-2 Microsoft # qkv_nodes_2 is for LLaMA-2 Hugging Face + # qkv_nodes_3 is for LLaMA-2 distribute Hugging Face model qkv_nodes = None qkv_nodes_1 = self.model.match_parent_path( normalize_node, @@ -334,18 +335,27 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["MatMul", "Reshape", "Transpose", "MatMul"], [1, 0, 0, 0], ) + qkv_nodes_3 = self.model.match_parent_path( + normalize_node, + ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 0, 0, 0, 0], + ) if qkv_nodes_1 is not None: _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1 qkv_nodes = qkv_nodes_1 elif qkv_nodes_2 is not None: _, reshape_qkv, _, matmul_qkv = qkv_nodes_2 qkv_nodes = qkv_nodes_2 + elif qkv_nodes_3 is not None: + _, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3 + qkv_nodes = qkv_nodes_3 else: logger.debug("fuse_rotary_attention: failed to match qkv nodes") return # v_nodes_1 is for LLaMA-2 Microsoft # v_nodes_3 is for LLaMA-2 Hugging Face + # v_nodes_4 is for LLaMA-2 70B model past_v, present_v, past_seq_len = "", "", "" v_nodes = None v_nodes_1 = self.model.match_parent_path( @@ -363,6 +373,48 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape", "MatMul"], [1, 0, 0], ) + _,v_nodes_4,_ = self.model.match_parent_paths_all( + matmul_qkv, + [ + ( + ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Expand", "Where", "Equal", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Expand", "Where", "Equal", "Mul", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Expand", "Where", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Expand", "Where", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 0, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Mul", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 1, 0, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 2, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 3, 0, 0, 0, 1, 0, 0], + ), + ], + output_name_to_node = None + ) if v_nodes_1 is not None: reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1 v_nodes = v_nodes_1 @@ -388,6 +440,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): transpose_v, reshape_v, matmul_v = v_nodes_3 v_nodes = v_nodes_3 present_v = transpose_v.output[0] + elif v_nodes_4 is not None and len(v_nodes_4) == 9: + logger.debug('fuse_rotary_attention: v_nodes_4') + 
logger.debug('*' * 30) + for temp_path in v_nodes_4: + logger.debug('fuse_rotary_attention: path for v_nodes_4') + for temp_node in temp_path: + logger.debug(f'temp_node: {temp_node}') + logger.debug('*' * 30) + + concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:] + v_nodes = v_nodes_4 + past_v = concat_v.input[0] + present_v = concat_v.output[0] + logger.debug(f'transpose_v: {transpose_v}') + logger.debug(f'reshape_v: {reshape_v}') + logger.debug(f'matmul_v: {matmul_v}') + logger.debug(f'past_v: {past_v}') + logger.debug(f'present_v: {present_v}') else: logger.debug("fuse_rotary_attention: failed to match v path") return @@ -445,6 +515,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # k_nodes_1 is for LLaMA-2 Microsoft # k_nodes_2 is for LLaMA-2 Hugging Face + # k_nodes_4 is for distributed LLaMA-2 Hugging Face past_k, present_k = "", "" k_nodes = None k_nodes_1 = self.model.match_parent_path( @@ -462,6 +533,48 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0, 0], ) + _, k_nodes_4, _ = self.model.match_parent_paths_all( + matmul_qk, + [ + ( + ["Transpose", "Reshape", "Expand", "Unsqueeze", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + ["Transpose", "Reshape", "Expand", "Where", "Equal", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + ["Transpose", "Reshape", "Expand", "Where", "Equal", "Mul", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + ), + ( + ["Transpose", "Reshape", "Expand", "Where", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0], + ), + ( + ["Transpose", "Reshape", "Expand", "Where", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0], + ), + ( + ["Transpose", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + ["Transpose", "Reshape", "Concat", "Unsqueeze", "Mul", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + ["Transpose", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0], + ), + ( + ["Transpose", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0], + ), + ], + output_name_to_node = None + ) if k_nodes_1 is not None: reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1 k_nodes = k_nodes_1 @@ -489,6 +602,26 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): k_nodes = k_nodes_3 past_k = concat_k.input[0] present_k = concat_k.output[0] + elif k_nodes_4 is not None and len(k_nodes_4) == 9: + logger.debug('fuse_rotary_attention: k_nodes_4') + logger.debug('*' * 30) + for temp_path in k_nodes_4: + 
logger.debug('fuse_rotary_attention: path for k_nodes_4') + for temp_node in temp_path: + logger.debug(f'temp_node: {temp_node}') + logger.debug('*' * 30) + + reshape_k, matmul_k = k_nodes_4[0][-2:] + concat_k, rotary_k = k_nodes_4[0][-5:-3] + k_nodes = k_nodes_4 + past_k = concat_k.input[0] + present_k = concat_k.output[0] + logger.debug(f'reshape_k: {reshape_k}') + logger.debug(f'matmul_k: {matmul_k}') + logger.debug(f'concat_k: {concat_k}') + logger.debug(f'rotary_k: {rotary_k}') + logger.debug(f'past_k: {past_k}') + logger.debug(f'present_k: {present_k}') else: logger.debug("fuse_rotary_attention: failed to match k nodes") return @@ -536,7 +669,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return root_output = reshape_qkv_2.output[0] - elif qkv_nodes == qkv_nodes_2: + elif qkv_nodes == qkv_nodes_2 or qkv_nodes == qkv_nodes_3: if not self.check_runtime_shape_paths_for_nodes( reshape_qkv, reshape_q, @@ -557,6 +690,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key) rotary_k.output[0] = rotary_k.name + "_output_0" + if qkv_nodes == qkv_nodes_3: + qkv_nodes = qkv_nodes[1:] + new_node = self.create_mha_node( matmul_q.input[0], root_output, @@ -578,7 +714,18 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.node_name_to_graph_name[new_node.name] = self.this_graph_name self.nodes_to_remove.extend(qkv_nodes[1:]) - self.nodes_to_remove.extend(v_nodes[:-1]) + + if v_nodes != v_nodes_4: + self.nodes_to_remove.extend(v_nodes[:-1]) + else: + remove_dic = {} + node_keep_name = v_nodes[0][-1].name + for temp_path in v_nodes: + for temp_node in temp_path: + if temp_node.name not in remove_dic and temp_node.name != node_keep_name: + remove_dic[temp_node.name] = temp_node + self.nodes_to_remove.extend(list(remove_dic.values())) + self.nodes_to_remove.extend(qk_nodes) if k_nodes == k_nodes_1: @@ -592,6 +739,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.append(k_nodes[1]) self.nodes_to_remove.append(k_nodes[3]) self.nodes_to_remove.append(k_nodes[4]) + elif k_nodes == k_nodes_4: + remove_dic = {} + node_keep_names = [k_nodes[0][-1].name, k_nodes[0][-4].name] + for temp_path in k_nodes: + for temp_node in temp_path: + if temp_node.name not in remove_dic and temp_node.name not in node_keep_names: + remove_dic[temp_node.name] = temp_node + self.nodes_to_remove.extend(list(remove_dic.values())) if q_nodes == q_nodes_1: self.nodes_to_remove.extend(q_nodes[:-2]) @@ -1041,4 +1196,4 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): self.increase_counter(self.base_name) self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name self.nodes_to_add.append(rotary_emb_node) - self.prune_graph = True + self.prune_graph = True \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 9619e6cb52a91..7f484d659cb7b 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -153,6 +153,15 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise 
--execution_provider cpu ``` +Export a sharded model (e.g., LLaMA-2 70B split into 4 partitions) +``` +# From source: +$ 1. Get the Hugging Face transformers fork from https://github.com/frankdongms/transformers/tree/frdong/shard_llama, or +$ wait until PR https://github.com/huggingface/transformers/pull/27119 is merged into HF transformers +$ 2. Build OnnxRuntime from source with NCCL enabled; sample command: ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ +$ 3. Shard and export the model, e.g.: CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh 4 -m meta-llama/Llama-2-7b-hf --output llama2-7b-dis2 --precision fp16 --execution_provider cuda +``` + ## Benchmark LLaMA-2 Here are some examples of how you can benchmark LLaMA-2. diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index a721979eb0bcb..fe6f990a31741 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -11,7 +11,7 @@ import onnx import psutil import torch -from benchmark_helper import setup_logger +from onnxruntime.transformers.benchmark_helper import setup_logger from llama_inputs import ( convert_inputs_for_ort, get_merged_sample_with_past_kv_inputs, @@ -27,6 +27,8 @@ import onnxruntime as ort from onnxruntime.transformers.benchmark_helper import measure_memory +from dist_settings import get_rank, get_size + logger = logging.getLogger(__name__) @@ -118,6 +120,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=0, use_fp16=args.use_fp16, return_dict=True, + world_size = args.world_size, ) iter_inputs = get_merged_sample_with_past_kv_inputs( args.config, @@ -127,6 +130,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=args.sequence_length, use_fp16=args.use_fp16, return_dict=True, + world_size = args.world_size, ) init_inputs = convert_inputs_for_ort( init_inputs, @@ -135,7 +139,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=0, max_seq_len=max_seq_len, device=args.device, - device_id=args.device_id, + device_id=args.rank, ) iter_inputs = convert_inputs_for_ort( iter_inputs, @@ -144,7 +148,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=args.sequence_length, max_seq_len=max_seq_len, device=args.device, - device_id=args.device_id, + device_id=args.rank, ) elif args.benchmark_type == "ort-msft": @@ -174,7 +178,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=0, max_seq_len=max_seq_len, device=args.device, - device_id=args.device_id, + device_id=args.rank, ) iter_inputs = convert_inputs_for_ort( iter_inputs, @@ -183,7 +187,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=args.sequence_length, max_seq_len=max_seq_len, device=args.device, - device_id=args.device_id, + device_id=args.rank, ) else: @@ -261,10 +265,10 @@ def get_model(args: argparse.Namespace): if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx - logger.info(f"Loading model from {args.ort_model_path}") + logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}") start_time = time.time() model = 
ort.InferenceSession( - args.ort_model_path, + args.ort_model_path.format(args.rank), sess_options, providers=[args.execution_provider], ) @@ -286,56 +290,38 @@ def time_fn(args, fn, inputs): outputs = fn(inputs) logger.info(outputs) - input_sync = ( # noqa: E731 - lambda *kwargs: args.io_binding.synchronize_inputs() - if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize - else lambda *kwargs: torch.cuda.synchronize() - if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize - else lambda *kwargs: None # no-op function - ) - - output_sync = ( # noqa: E731 - lambda *kwargs: args.io_binding.synchronize_outputs() - if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize - else lambda *kwargs: torch.cuda.synchronize() - if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize - else lambda *kwargs: None # no-op function - ) - for _ in warmup_range: - input_sync() fn(inputs) - output_sync() # Benchmark - total_time = 0 + if args.device != "cpu": + torch.cuda.synchronize() + start_time = time.time() + bench_range = ( range(args.num_runs) if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} else trange(args.num_runs, file=sys.stdout, desc="Benchmark") ) for _ in bench_range: - input_sync() - start_time = time.time() - fn(inputs) - output_sync() - end_time = time.time() - - total_time += end_time - start_time + if args.device != "cpu": + torch.cuda.synchronize() + end_time = time.time() # Newline print after trange in order to print metrics on new lines without progress bar on same line if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}: logger.info("") - latency = total_time / args.num_runs + latency = (end_time - start_time) / args.num_runs throughput = args.batch_size / latency - logger.info(f"Batch Size: {args.batch_size}") - logger.info(f"Sequence Length: {args.sequence_length}") - logger.info(f"Latency: {latency} s") - logger.info(f"Throughput: {throughput} tps") + if args.rank == 0: + logger.info(f"Batch Size: {args.batch_size}") + logger.info(f"Sequence Length: {args.sequence_length}") + logger.info(f"Latency: {latency:.4f} s") + logger.info(f"Throughput: {throughput:.4f} tps") return @@ -375,7 +361,8 @@ def measure_fn(args, fn, inputs): process.cpu_percent(interval=0.1) fn(inputs) - logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%") + if args.rank == 0: + logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%") # Measure memory usage gc.collect() @@ -484,9 +471,8 @@ def prepare_ort_inputs(inputs): name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] ) else: - io_binding.bind_output(name, device_type=args.device, device_id=args.device_id) + io_binding.bind_output(name, device_type=args.device, device_id=args.rank) - setattr(args, "io_binding", io_binding) # noqa: B010 return io_binding return inputs @@ -523,12 +509,14 @@ def without_io_binding(inputs): return # ORT evaluations - logger.info("\nEvaluating `model(inputs)` step to get past_key_values") + if args.rank == 0: + logger.info("\nEvaluating `model(inputs)` step to get past_key_values") ort_init_inputs = prepare_ort_inputs(init_inputs) time_fn(args, generate_fn, ort_init_inputs) measure_fn(args, generate_fn, ort_init_inputs) - logger.info("\nEvaluating `model(inputs)` step with past_key_values") + if args.rank == 0: + logger.info("\nEvaluating `model(inputs)` step with past_key_values") 
ort_iter_inputs = prepare_ort_inputs(iter_inputs) time_fn(args, generate_fn, ort_iter_inputs) measure_fn(args, generate_fn, ort_iter_inputs) @@ -543,7 +531,7 @@ def run_inference(args, init_inputs, iter_inputs, model): raise Exception(f"Cannot recognize {args.benchmark_type}") -def get_args(): +def get_args(rank = 0): parser = argparse.ArgumentParser() parser.add_argument( "-bt", @@ -601,7 +589,7 @@ def get_args(): parser.add_argument( "-s", "--sequence-lengths", - default="8 16 32 64 128 256 512", + default="32 64 128 256 512", ) parser.add_argument( "-d", @@ -638,9 +626,9 @@ def get_args(): if "ort" in args.benchmark_type: setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010 if args.execution_provider == "CUDAExecutionProvider": - args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) + args.execution_provider = (args.execution_provider, {"device_id": rank}) elif args.execution_provider == "ROCMExecutionProvider": - args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) + args.execution_provider = (args.execution_provider, {"device_id": rank}) args.device = "cuda" # Check that paths have been specified for any benchmarking with ORT @@ -667,14 +655,19 @@ def get_args(): def main(): - args = get_args() + rank = get_rank() + world_size = get_size() + + args = get_args(rank) setup_logger(args.verbose) logger.info(args.__dict__) torch.backends.cudnn.benchmark = True + setattr(args, "rank", rank) + setattr(args, "world_size", world_size) tokenizer = LlamaTokenizer.from_pretrained(args.model_name) config = LlamaConfig.from_pretrained(args.model_name) - target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device + target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device use_fp16 = args.precision == "fp16" setattr(args, "tokenizer", tokenizer) # noqa: B010 @@ -688,7 +681,7 @@ def main(): # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA) if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}: - onnx_model = onnx.load_model(args.ort_model_path, load_external_data=False) + onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False) gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node)) use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu" @@ -698,7 +691,8 @@ def main(): # Measure prompt cost (init_inputs) and generated token cost (iter_inputs) for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths): - logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...") + if args.rank == 0: + logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...") setattr(args, "batch_size", int(batch_size)) # noqa: B010 setattr(args, "sequence_length", int(sequence_length)) # noqa: B010 @@ -707,4 +701,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 69603fd3ed488..41e743ff6045c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -1,13 +1,14 @@ import argparse import logging import os +import shutil import tempfile from itertools import chain from 
typing import List import onnx import torch -from benchmark_helper import Precision, prepare_environment, setup_logger +from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger from convert_generation import replace_mha_with_gqa from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check @@ -19,8 +20,11 @@ from onnxruntime import quantization as ort_quantization from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer -logger = logging.getLogger("") +from dist_settings import init_dist, get_rank, get_size, barrier +from llama_torch import setup_torch_model +logger = logging.getLogger("") +init_dist() def get_model_dynamic_axes(input_names: List[str], output_names: List[str]): dynamic_axes = {} @@ -129,7 +133,7 @@ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: st # del onnx_model # temp_dir.cleanup() # -def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int=0, world_size: int=1): from torch._dynamo import config config.capture_scalar_outputs = True @@ -150,9 +154,9 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) - save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_model_fp32.onnx.data") + save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data") del onnx_model os.system( f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" @@ -160,7 +164,7 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll # Export decoder_with_past_model.onnx input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs( - l_config, device, batch_size, sequence_length + l_config, device, batch_size, sequence_length, world_size = world_size ) temp_dir = args.output # tempfile.TemporaryDirectory() temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx") @@ -172,9 +176,9 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) - save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_with_past_model_fp32.onnx.data") + save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data") del onnx_model os.system( f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" @@ -182,11 +186,14 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll logger.info(f"The {args.model_name} ONNX model has been successfully created with 
the Dynamo exporter!") +def _prepare_dir(dir_path): + if not os.path.exists(dir_path): + os.makedirs(dir_path) -def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int=0, world_size: int=1): # Dummy values for export batch_size, sequence_length = 2, 8 - device = torch.device("cpu") + device = llama.device # Export decoder_model.onnx decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length) @@ -199,8 +206,12 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon ), ] dynamic_axes = get_model_dynamic_axes(input_names, output_names) - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp.onnx") + + # Avoid use system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. + temp_dir = f'./temp_{rank}' + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, args=decoder_inputs, @@ -218,18 +229,18 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model( onnx_model, output_path, - f"{args.model_name}_decoder_model_fp32.onnx.data", + f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data", ) del onnx_model - temp_dir.cleanup() + shutil.rmtree(temp_dir) # Export decoder_with_past_model.onnx - decoder_with_past_inputs = get_sample_with_past_kv_inputs(l_config, device, batch_size, sequence_length) + decoder_with_past_inputs = get_sample_with_past_kv_inputs(l_config, device, batch_size, sequence_length, use_fp16 = args.precision == Precision.FLOAT16, world_size = world_size) input_names = [ "input_ids", "attention_mask", @@ -247,8 +258,10 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon ), ] dynamic_axes = get_model_with_past_kv_dynamic_axes(input_names, output_names) - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp.onnx") + + # Avoid use system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. 
+ temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, args=decoder_with_past_inputs, @@ -266,27 +279,27 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model( onnx_model, output_path, - f"{args.model_name}_decoder_with_past_model_fp32.onnx.data", + f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data", ) del onnx_model - temp_dir.cleanup() + shutil.rmtree(temp_dir) - logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!") + logger.info(f"The {args.model_name} separate ONNX model has been successfully created with the TorchScript exporter!") -def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int=0, world_size: int=1): # Dummy values for export batch_size, sequence_length, past_sequence_length = 2, 8, 0 - device = torch.device("cpu") + device = llama.device # Export decoder_merged_model.onnx decoder_merged_inputs = get_merged_sample_with_past_kv_inputs( - l_config, device, batch_size, sequence_length, past_sequence_length + l_config, device, batch_size, sequence_length, past_sequence_length, use_fp16 = args.precision == Precision.FLOAT16, world_size = world_size ) input_names = [ "input_ids", @@ -305,8 +318,12 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi ), ] dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names) - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp.onnx") + + # Avoid use system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. 
+ temp_dir = f'./temp_{rank}' + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, args=decoder_merged_inputs, @@ -324,17 +341,17 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model( onnx_model, output_path, - f"{args.model_name}_decoder_merged_model_fp32.onnx.data", + f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx.data", ) del onnx_model - temp_dir.cleanup() + shutil.rmtree(temp_dir) - logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!") + logger.info(f"The {args.model_name} merged ONNX model has been successfully created with the TorchScript exporter!") # Optimize the model as FP32 @@ -357,12 +374,12 @@ def optimize_export(config: LlamaConfig, input_path: str, output_path: str): remove_existing_model(input_path) -def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: List[str]): - decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx") +def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: List[str], rank: int=0, world_size: int=1): + decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx") decoder_with_past_model_fp16_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx" + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp16.onnx" ) - decoder_merged_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp16.onnx") + decoder_merged_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp16.onnx") new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path] logger.info("Converting to float16...") @@ -370,7 +387,7 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: if os.path.exists(fp32_path): model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True)) model.convert_float_to_float16(keep_io_types=False) - model = use_group_query_attention(config, model) + model = use_group_query_attention(config, model, world_size) model.save_model_to_file(fp16_path, use_external_data_format=True) del model logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!") @@ -380,9 +397,9 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: return new_paths -def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel): +def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int=1): # Replace MultiHeadAttention with GroupQueryAttention and remove attention mask nodes - fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "past_sequence_length", config.num_key_value_heads) + fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "past_sequence_length", config.num_key_value_heads, world_size) fp16_model_opt.prune_graph() fp16_model_opt.update_graph(allow_remove_graph_inputs=True) return fp16_model_opt @@ -480,7 +497,6 
@@ def remove_existing_files(output_path: str): os.remove(filepath) logger.warning(f"Removed {filepath}") - def get_args(): parser = argparse.ArgumentParser() @@ -655,6 +671,14 @@ def get_args(): ) parser.set_defaults(use_dynamo_export=False) + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./model_cache", + help="model cache dir to override default HF cache dir to avoid overflood the /home dir", + ) + args = parser.parse_args() return args @@ -678,153 +702,161 @@ def main(): setattr(args, "use_auth_token", use_auth_token) # noqa: B010 location = args.model_name if use_auth_token else args.input - l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token) - llama = LlamaForCausalLM.from_pretrained(location, use_auth_token=use_auth_token, use_cache=True) + l_config, llama = setup_torch_model(args, location, use_auth_token) original_model_name = args.model_name setattr(args, "original_model_name", original_model_name) # noqa: B010 args.model_name = args.model_name.split("/")[-1] - # Set model paths for FP32 model - decoder_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") - decoder_with_past_model_fp32_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx" - ) - decoder_merged_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx") - old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] - - missing_separate_exports = ( - args.no_merged - and not os.path.exists(decoder_model_fp32_path) - and not os.path.exists(decoder_with_past_model_fp32_path) - ) - missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path) - - # Export to ONNX - if missing_separate_exports or missing_merged_export: - if args.use_dynamo_export and missing_separate_exports: - logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") - logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") - logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") - logger.warning( - "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" - ) - logger.warning( - "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." 
- ) - run_dynamo_export(args, l_config, llama) - elif args.no_merged: - run_torchscript_separate_export(args, l_config, llama) - else: - run_torchscript_merged_export(args, l_config, llama) - - # Set model paths to store FP32 optimized model - decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx") - decoder_with_past_model_fp32_opt_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp32_opt.onnx" - ) - decoder_merged_model_fp32_opt_path = os.path.join( - args.output, f"{args.model_name}_decoder_merged_model_fp32_opt.onnx" - ) - new_paths = [decoder_model_fp32_opt_path, decoder_with_past_model_fp32_opt_path, decoder_merged_model_fp32_opt_path] + world_size = get_size() + rank = get_rank() - # Run the optimizer script - logger.info("Optimizing models...") - for orig_path, opt_path in zip(old_paths, new_paths): - if os.path.exists(orig_path): - optimize_export(l_config, input_path=orig_path, output_path=opt_path) - - # Re-assign default FP32 model paths as their optimized versions - decoder_model_fp32_path = decoder_model_fp32_opt_path - decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path - decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path - old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + barrier() + for i in range(world_size): + if i == rank: + # Set model paths for FP32 model + decoder_model_fp32_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") + decoder_with_past_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx" + ) + decoder_merged_model_fp32_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx") + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] - logger.info( - f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" - ) - - # Change precision of exported models from FP32 - if args.precision == Precision.FLOAT16: - new_paths = convert_to_float16(args, l_config, old_paths) - - elif args.precision == Precision.INT8: - decoder_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int8.onnx") - decoder_with_past_model_int8_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_int8.onnx" - ) - decoder_merged_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int8.onnx") - new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] - - if args.quantization_method == "smooth_quant": - if not args.no_merged: - logger.error("SmoothQuant must be used on separately exported models") - else: - logger.info(f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8") - smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) - - elif args.quantization_method == "quantize_dynamic": - logger.warning( - "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." 
+ missing_separate_exports = ( + args.no_merged + and not os.path.exists(decoder_model_fp32_path) + and not os.path.exists(decoder_with_past_model_fp32_path) + ) + missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path) + + # Export to ONNX + if missing_separate_exports or missing_merged_export: + if args.use_dynamo_export and missing_separate_exports: + logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") + logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") + logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") + logger.warning( + "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" + ) + logger.warning( + "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." + ) + run_dynamo_export(args, l_config, llama) + elif args.no_merged: + run_torchscript_separate_export(args, l_config, llama, rank, world_size) + else: + run_torchscript_merged_export(args, l_config, llama, rank, world_size) + + # Set model paths to store FP32 optimized model + decoder_model_fp32_opt_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32_opt.onnx") + decoder_with_past_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32_opt.onnx" + ) + decoder_merged_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32_opt.onnx" + ) + new_paths = [decoder_model_fp32_opt_path, decoder_with_past_model_fp32_opt_path, decoder_merged_model_fp32_opt_path] + + # Run the optimizer script + logger.info("Optimizing models...") + for orig_path, opt_path in zip(old_paths, new_paths): + if os.path.exists(orig_path): + optimize_export(l_config, input_path=orig_path, output_path=opt_path) + + # Re-assign default FP32 model paths as their optimized versions + decoder_model_fp32_path = decoder_model_fp32_opt_path + decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path + decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + logger.info( + f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" 
) - logger.info("Quantizing to int8...") - for fp32_path, int8_path in zip(old_paths, new_paths): - if os.path.exists(fp32_path): - ort_quantization.quantize_dynamic( - fp32_path, - int8_path, - op_types_to_quantize=["MatMul", "Gemm", "Gather"] - if args.quantize_embedding_layer - else ["MatMul", "Gemm"], - per_channel=args.quantize_per_channel, - reduce_range=args.quantize_reduce_range, - use_external_data_format=True, - extra_options={"MatMulConstBOnly": True}, + # Change precision of exported models from FP32 + if args.precision == Precision.FLOAT16: + _ = convert_to_float16(args, l_config, old_paths, rank, world_size) + + elif args.precision == Precision.INT8: + decoder_model_int8_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx") + decoder_with_past_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx" + ) + decoder_merged_model_int8_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx") + new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] + + if args.quantization_method == "smooth_quant": + if not args.no_merged: + logger.error("SmoothQuant must be used on separately exported models") + else: + logger.info(f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8") + smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) + + elif args.quantization_method == "quantize_dynamic": + logger.warning( + "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." ) - logger.info(f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!") - remove_existing_model(decoder_model_fp32_path) - logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + logger.info("Quantizing to int8...") + for fp32_path, int8_path in zip(old_paths, new_paths): + if os.path.exists(fp32_path): + ort_quantization.quantize_dynamic( + fp32_path, + int8_path, + op_types_to_quantize=["MatMul", "Gemm", "Gather"] + if args.quantize_embedding_layer + else ["MatMul", "Gemm"], + per_channel=args.quantize_per_channel, + reduce_range=args.quantize_reduce_range, + use_external_data_format=True, + extra_options={"MatMulConstBOnly": True}, + ) + logger.info(f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!") + remove_existing_model(decoder_model_fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + + else: + raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") + + elif args.precision == Precision.INT4: + if args.execution_provider != "cpu": + old_paths = convert_to_float16(args, l_config, old_paths, rank, world_size) + + decoder_model_int4_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx") + decoder_with_past_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx" + ) + decoder_merged_model_int4_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx") + new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] + + for fp_path, int4_path in zip(old_paths, new_paths): + if os.path.exists(fp_path): + model = onnx.load_model(fp_path, 
load_external_data=True) + quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[]) + quant.process() + quant.model.save_model_to_file(int4_path, use_external_data_format=True) + del model + del quant + logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") + remove_existing_model(fp_path) + + # Delete LLaMA model from memory since it will be loaded again during parity + del llama + barrier() + - else: - raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") - - elif args.precision == Precision.INT4: - if args.execution_provider != "cpu": - old_paths = convert_to_float16(args, l_config, old_paths) - - decoder_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int4.onnx") - decoder_with_past_model_int4_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_int4.onnx" - ) - decoder_merged_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int4.onnx") - new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] - - for fp_path, int4_path in zip(old_paths, new_paths): - if os.path.exists(fp_path): - model = onnx.load_model(fp_path, load_external_data=True) - quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[]) - quant.process() - quant.model.save_model_to_file(int4_path, use_external_data_format=True) - del model - del quant - logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") - remove_existing_model(fp_path) - - del llama # Delete LLaMA model from memory since it will be loaded again during parity check logger.info("Verifying parity on all ONNX models created") # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models args.precision = ( "fp32" - if args.precision in {Precision.INT8, Precision.FLOAT32} - or (args.precision == Precision.INT4 and args.execution_provider == "cpu") + if args.precision in {"int8", "fp32"} or (args.precision == Precision.INT4 and args.execution_provider == "cpu") else "fp16" ) # Verify parity on all saved ONNX models for filename in os.listdir(args.output): - if ".data" in filename or ".onnx" not in filename: + if ".data" in filename or ".onnx" not in filename or args.precision not in filename or f'rank_{rank}' not in filename: continue parity_cmd = [ @@ -834,10 +866,10 @@ def main(): os.path.join(args.output, filename), "-ep", args.execution_provider, - "-id", - args.device_id, "-fp", args.precision, + "--cache_dir", + args.cache_dir ] if "with_past" in filename: parity_cmd.append("--use_past_kv") @@ -845,10 +877,11 @@ def main(): parity_cmd.append("--merged") try: + logger.debug(f'check parity with cmd: {parity_cmd}') parity_check(parity_cmd) except Exception as e: logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py new file mode 100644 index 0000000000000..3040482612644 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py @@ -0,0 +1,43 @@ +import os +import torch +import torch.distributed as dist + +from mpi4py import MPI + + +def init_dist(): + if "LOCAL_RANK" in os.environ: + local_rank = 
int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: + local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)) + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) + else: + local_rank = 0 + rank = 0 + world_size = 1 + + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank) + device = torch.device(local_rank) + return device + +comm = MPI.COMM_WORLD + + +def get_rank(): + return comm.Get_rank() + + +def get_size(): + return comm.Get_size() + + +def barrier(): + comm.Barrier() + + +def print_out(*args): + if get_rank() == 0: + print(*args) \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index 2652e9f0ca64e..bcededd463272 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -46,12 +46,13 @@ def get_sample_with_past_kv_inputs( past_seq_len: int, use_fp16: bool = False, return_dict: bool = False, + world_size: int=1 ): input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), device=device, dtype=torch.int64) attention_mask = torch.ones(batch_size, past_seq_len + 1, device=device, dtype=torch.int64) # position_ids is of shape (batch_size, 1) position_ids = get_position_ids(attention_mask, use_past_kv=True) - past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16) + past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16, world_size = world_size) if not return_dict: return (input_ids, attention_mask, position_ids, past_kv) @@ -74,6 +75,7 @@ def get_merged_sample_with_past_kv_inputs( past_seq_len: int, use_fp16: bool = False, return_dict: bool = False, + world_size: int=1 ): input_ids = torch.randint( low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 @@ -81,7 +83,7 @@ def get_merged_sample_with_past_kv_inputs( attention_mask = torch.ones(batch_size, past_seq_len + seq_len, device=device, dtype=torch.int64) # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) - past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16) + past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16, world_size = world_size) if not return_dict: return (input_ids, attention_mask, position_ids, past_kv) @@ -97,9 +99,9 @@ def get_merged_sample_with_past_kv_inputs( # Create past_key_values def get_sample_past_kv_inputs( - config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool + config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int=1 ): - num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads + num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads torch_dtype = torch.float16 if use_fp16 else torch.float32 past_kv = [ ( @@ -201,4 +203,4 @@ def get_msft_sample_inputs( } ) - return ort_inputs + return ort_inputs \ No newline at end of file diff --git 
a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 4353d0606803d..fb9b7f584edca 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -13,6 +13,8 @@ get_sample_inputs, get_sample_with_past_kv_inputs, ) +from llama_torch import setup_torch_model +from dist_settings import get_rank, get_size from transformers import LlamaConfig, LlamaForCausalLM import onnxruntime as ort @@ -28,6 +30,7 @@ def get_sequence_lengths(args: argparse.Namespace): def get_inputs(args: argparse.Namespace, config: LlamaConfig): # Dummy values for parity + world_size = get_size() batch_size = 2 past_sequence_length, sequence_length, _ = get_sequence_lengths(args) @@ -40,13 +43,14 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig): past_sequence_length, use_fp16=args.use_fp16, return_dict=True, + world_size = world_size ) elif args.use_past_kv: inputs = get_sample_with_past_kv_inputs( - config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True + config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True, world_size = world_size ) else: - inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True) + inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True, world_size = world_size) return inputs @@ -70,7 +74,7 @@ def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, input name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] ) else: - io_binding.bind_output(name, device_type=args.execution_provider, device_id=int(args.device_id)) + io_binding.bind_output(name, device_type=args.execution_provider, device_id=int(args.rank)) return io_binding @@ -78,6 +82,8 @@ def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, input def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM): inputs = get_inputs(args, config) + logger.debug(f'torch input: {inputs}') + # Run inference with PyTorch if args.execution_provider != "cpu": torch.cuda.synchronize() @@ -87,6 +93,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama torch.cuda.synchronize() end_time = time.time() logger.info(f"PyTorch took {end_time - start_time} s") + del pt_model # Run inference with ORT past_sequence_length, _, max_sequence_length = get_sequence_lengths(args) @@ -97,12 +104,14 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama past_seq_len=past_sequence_length, max_seq_len=max_sequence_length, device=args.execution_provider, - device_id=int(args.device_id), + device_id=int(args.rank), ) + logger.debug(f'ORT input: {inputs}') + ep = f"{args.execution_provider.upper()}ExecutionProvider" if ep == "CUDAExecutionProvider": - ep = (ep, {"device_id": args.device_id}) + ep = (ep, {"device_id": args.rank}) ort_model = ort.InferenceSession( args.onnx_model_path, sess_options=ort.SessionOptions(), @@ -113,13 +122,14 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama if args.execution_provider != "cpu": io_binding = add_io_bindings(args, ort_model, inputs) - io_binding.synchronize_inputs() + torch.cuda.synchronize() start_time = time.time() ort_model.run_with_iobinding(io_binding) - io_binding.synchronize_outputs() + 
torch.cuda.synchronize() end_time = time.time() ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits + del ort_model else: start_time = time.time() @@ -179,15 +189,6 @@ def get_args(argv: List[str]): help="Execution provider to verify parity with", ) - parser.add_argument( - "-id", - "--device-id", - required=False, - type=str, - default="0", - help="Device ID for GPUs", - ) - parser.add_argument( "-v", "--verbose", @@ -219,6 +220,14 @@ def get_args(argv: List[str]): help="Precision of model", ) + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./model_cache", + help="model cache dir to override default HF cache dir to avoid overflood the /home dir", + ) + args = parser.parse_args() if argv == [] else parser.parse_args(argv) # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models @@ -234,21 +243,17 @@ def main(argv: List[str] = []): # noqa: B006 args = get_args(argv) setup_logger(args.verbose) logger.info(f"Arguments: {args}") + rank = get_rank() # Load model and config setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010 - setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{args.device_id}") # noqa: B010 + setattr(args, "rank", rank) + setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010 setattr(args, "device", torch.device(args.device_name)) # noqa: B010 use_auth_token = args.torch_model_directory == os.path.join(".") location = args.model_name if use_auth_token else args.torch_model_directory - config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token) - llama = LlamaForCausalLM.from_pretrained( - location, - torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), - use_auth_token=use_auth_token, - use_cache=True, - ).to(args.device) + config, llama = setup_torch_model(args, location, use_auth_token, torch_dtype=(torch.float16 if args.use_fp16 else torch.float32)) if not args.merged: verify_parity(args, config, llama) @@ -266,4 +271,4 @@ def main(argv: List[str] = []): # noqa: B006 seed = 2 np.random.seed(seed) torch.manual_seed(seed) - main() + main() \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py new file mode 100644 index 0000000000000..312854de913ed --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py @@ -0,0 +1,33 @@ +import os +import logging + +import torch +from transformers import LlamaConfig, LlamaForCausalLM + +from dist_settings import get_rank, get_size, barrier + +logger = logging.getLogger("") + +def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, use_cuda=True): + world_size = get_size() + logger.info(f'world_size: {world_size}') + rank = get_rank() + barrier() + + if not os.path.exists(args.cache_dir): + os.makedirs(args.cache_dir) + + for i in range(world_size): + if i == rank: + l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir) + l_config.use_cache = True + llama = LlamaForCausalLM.from_pretrained(location, use_auth_token=use_auth_token, config=l_config, + torch_dtype=torch_dtype, cache_dir=args.cache_dir) + if world_size > 1: + llama.parallel_model() + if use_cuda: + llama.to(torch.device(rank)) + llama.eval() + llama.requires_grad_(False) + barrier() + return l_config, llama \ No newline at end of file diff --git 
a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index 4210f36982aef..5500bc0bdf2c1 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -3,4 +3,6 @@ transformers>=4.33.2 torch>=2.2.0.dev20230920 onnx>=1.14.0 datasets>=2.8.0 -protobuf==3.20.2 \ No newline at end of file +protobuf==3.20.2 +mpi4py +psutil \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/run.sh b/onnxruntime/python/tools/transformers/models/llama/run.sh new file mode 100644 index 0000000000000..b980601c02252 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/run.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +NUM_GPUS=${1:-1} + +MPI="mpirun --allow-run-as-root + -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 + --tag-output --npernode $NUM_GPUS --bind-to numa + -x MIOPEN_FIND_MODE=1" + +CMD="$MPI bash single_run.sh ${@:2}" + +set -x +$CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/single_run.sh b/onnxruntime/python/tools/transformers/models/llama/single_run.sh new file mode 100644 index 0000000000000..82f34f79b7bf4 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/single_run.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +CMD="python convert_to_onnx.py ${@}" + +$CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 07870373e90b0..66ec0de88b44c 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -337,6 +337,18 @@ def match_parent_paths(self, node, paths, output_name_to_node): return i, matched, return_indice return -1, None, None + def match_parent_paths_all(self, node, paths, output_name_to_node): + match_i, matches, return_indices = [], [], [] + for i, path in enumerate(paths): + assert isinstance(path, (List, Tuple)) + return_indice = [] + matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice) + if matched: + match_i.append(i) + matches.append(matched) + return_indices.append(return_indice) + return match_i, matches, return_indices + def match_parent_path( self, node, diff --git a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py index fedba2a25dfc2..90e43246b7555 100644 --- a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py +++ b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py @@ -96,7 +96,7 @@ def create_inputs_and_outputs(self, model_type: str): helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]), helper.make_tensor_value_info("attn_mask", TensorProto.INT64, attn_mask_size), ] - if model_type in {"past", "merged", "llama2_msft"}: + if model_type in {"past", "merged", "llama2_msft", "70b_distributed_merged"}: inputs.extend( [ helper.make_tensor_value_info( @@ -164,14 +164,14 @@ def get_first_rope_input(node_type: str): if is_fused or model_type == "llama2_msft": # q_out/k_out return f"{node_type}_out" - if model_type in {"no_past", "past", "merged"}: + if model_type in {"no_past", "past", "merged", "70b_distributed_merged"}: if node_type == "k": return "k_before_rope" return "q_before_rope" return 
"" def get_first_rope_output(node_type: str): - if is_fused or model_type in {"llama2_msft", "past", "merged"}: + if is_fused or model_type in {"llama2_msft", "past", "merged", "70b_distributed_merged"}: if node_type == "q": return "q_rope" return "k_rope" @@ -295,23 +295,244 @@ def create_k_path_hf(self, model_type: str): ) k_nodes = [reshape_k_node, transpose_k_1_node] - if model_type in {"past", "merged"}: + if model_type == "70b_distributed_merged": concat_k_node = helper.make_node( + "Concat", + inputs=["past_key", "k_rope"], + outputs=["present_key"], + axis=2, + ) + k_nodes.append(concat_k_node) + + shape_k1 = helper.make_node( + "Shape", + inputs=["present_value"], + outputs=["shape_k1_out"], + name="Shape_k1" + ) + k_nodes.append(shape_k1) + shape_k2 = helper.make_node( + "Shape", + inputs=["present_value"], + outputs=["shape_k2_out"], + name="Shape_k2" + ) + k_nodes.append(shape_k2) + shape_k3 = helper.make_node( + "Shape", + inputs=["present_value"], + outputs=["shape_k3_out"], + name="Shape_k3" + ) + k_nodes.append(shape_k3) + shape_k4 = helper.make_node( + "Shape", + inputs=["present_value"], + outputs=["shape_k4_out"], + name="Shape_k4" + ) + k_nodes.append(shape_k4) + gather_k_1 = helper.make_node( + "Gather", + inputs=["shape_k1_out", "one"], + outputs=["gather_k1_out"], + name="Gather_k_1", + axis=0, + ) + k_nodes.append(gather_k_1) + gather_k_2 = helper.make_node( + "Gather", + inputs=["shape_k2_out", "one"], + outputs=["gather_k2_out"], + name="Gather_k_2", + axis=0, + ) + k_nodes.append(gather_k_2) + gather_k_3 = helper.make_node( + "Gather", + inputs=["shape_k3_out", "one"], + outputs=["gather_k3_out"], + name="Gather_k_3", + axis=0, + ) + k_nodes.append(gather_k_3) + gather_k_4 = helper.make_node( + "Gather", + inputs=["shape_k4_out", "one"], + outputs=["gather_k4_out"], + name="Gather_k_4", + axis=0, + ) + k_nodes.append(gather_k_4) + unsqueeze_k_1 = helper.make_node( + "Unsqueeze", + inputs=["present_value", "zero"], + outputs=["unsqueeze_k1_out"], + name="Unsqueeze_k1", + ) + k_nodes.append(unsqueeze_k_1) + unsqueeze_k_2 = helper.make_node( + "Unsqueeze", + inputs=["gather_k1_out", "zero"], + outputs=["unsqueeze_k2_out"], + name="Unsqueeze_k2", + ) + k_nodes.append(unsqueeze_k_2) + unsqueeze_k_3 = helper.make_node( + "Unsqueeze", + inputs=["gather_k2_out", "zero"], + outputs=["unsqueeze_k3_out"], + name="Unsqueeze_k3", + ) + k_nodes.append(unsqueeze_k_3) + unsqueeze_k_4 = helper.make_node( + "Unsqueeze", + inputs=["gather_k3_out", "zero"], + outputs=["unsqueeze_k4_out"], + name="Unsqueeze_k4", + ) + k_nodes.append(unsqueeze_k_4) + unsqueeze_k_5 = helper.make_node( + "Unsqueeze", + inputs=["gather_k4_out", "zero"], + outputs=["unsqueeze_k5_out"], + name="Unsqueeze_k5", + ) + k_nodes.append(unsqueeze_k_5) + concat_k_2 = helper.make_node( "Concat", - inputs=["past_key", "k_rope"], - outputs=["present_key"], - axis=2, + inputs=["unsqueeze_k2_out", "unsqueeze_k3_out", "One", "unsqueeze_k4_out", "unsqueeze_k5_out"], + outputs=["concat_k2_ouot"], + name="Concat_k2", + axis=0, ) - k_nodes.append(concat_k_node) + k_nodes.append(concat_k_2) + reshape_k_2 = helper.make_node( + "Reshape", + inputs=["concat_k2_ouot", "One"], + outputs=["reshape_k2_out"], + name="Reshape_k_2", + ) + k_nodes.append(reshape_k_2) + shape_k5 = helper.make_node( + "Shape", + inputs=["reshape_k2_out"], + outputs=["shape_k5_out"], + name="Shape_k5" + ) + k_nodes.append(shape_k5) + constant_of_shape_k_1 = helper.make_node( + "ConstantOfShape", + inputs=["shape_k5_out"], + 
outputs=["constant_of_shape_k1_out"], + name="ConstantOfShape_k1" + ) + k_nodes.append(constant_of_shape_k_1) + mul_k_1 = helper.make_node( + "Mul", + inputs=["constant_of_shape_k1_out", "One"], + outputs=["mul_k1_out"], + name="mul_k1", + ) + k_nodes.append(mul_k_1) + equal_k_1 = helper.make_node( + "Equal", + inputs=["reshape_k2_out", "mul_k1_out"], + outputs=["equal_k_1_out"], + name="equal_k1", + ) + k_nodes.append(equal_k_1) + where_k_1 = helper.make_node( + "Where", + inputs=["equal_k_1_out", "constant_of_shape_k1_out", "reshape_k2_out"], + outputs=["where_k_1_out"], + name="where_k1", + ) + k_nodes.append(where_k_1) + unsqueeze_k_6 = helper.make_node( + "Unsqueeze", + inputs=["gather_k1_out", "zero"], + outputs=["unsqueeze_k6_out"], + name="Unsqueeze_k6", + ) + k_nodes.append(unsqueeze_k_6) + mul_k_2 = helper.make_node( + "Mul", + inputs=["gather_k2_out", "One"], + outputs=["mul_k2_out"], + name="mul_k2", + ) + k_nodes.append(mul_k_2) + unsqueeze_k_7 = helper.make_node( + "Unsqueeze", + inputs=["mul_k2_out", "zero"], + outputs=["unsqueeze_k7_out"], + name="Unsqueeze_k7", + ) + k_nodes.append(unsqueeze_k_7) + unsqueeze_k_8 = helper.make_node( + "Unsqueeze", + inputs=["gather_k3_out", "zero"], + outputs=["unsqueeze_k8_out"], + name="Unsqueeze_k8", + ) + k_nodes.append(unsqueeze_k_8) + unsqueeze_k_9 = helper.make_node( + "Unsqueeze", + inputs=["gather_k4_out", "zero"], + outputs=["unsqueeze_k9_out"], + name="Unsqueeze_k9", + ) + k_nodes.append(unsqueeze_k_9) + concat_k_3 = helper.make_node( + "Concat", + inputs=["unsqueeze_k6_out", "unsqueeze_k7_out", "unsqueeze_k8_out", "unsqueeze_k9_out"], + outputs=["concat_k3_out"], + name="Concat_k3", + axis=0, + ) + k_nodes.append(concat_k_3) + expand_k_1 = helper.make_node( + "Expand", + inputs=["unsqueeze_k1_out", "where_k_1_out"], + outputs=["expand_k1_out"], + name="expand_k1", + ) + k_nodes.append(expand_k_1) + reshape_k_3 = helper.make_node( + "Reshape", + inputs=["expand_k1_out", "concat_k3_out"], + outputs=["reshape_k3_out"], + name="Reshape_k_3", + ) + k_nodes.append(reshape_k_3) - transpose_k_2_node = helper.make_node( - "Transpose", - inputs=["present_key"], - outputs=["k"], - name="Transpose_k_2", - perm=[0, 1, 3, 2], - ) - return k_nodes + [transpose_k_2_node] # noqa: RUF005 + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["reshape_k3_out"], + outputs=["k"], + name="Transpose_k_2", + perm=[0, 1, 3, 2], + ) + return k_nodes + [transpose_k_2_node] # noqa: RUF005 + else: + if model_type in {"past", "merged"}: + concat_k_node = helper.make_node( + "Concat", + inputs=["past_key", "k_rope"], + outputs=["present_key"], + axis=2, + ) + k_nodes.append(concat_k_node) + + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["present_key"], + outputs=["k"], + name="Transpose_k_2", + perm=[0, 1, 3, 2], + ) + return k_nodes + [transpose_k_2_node] # noqa: RUF005 def create_k_path(self, model_type: str): if model_type == "llama2_msft": @@ -505,7 +726,7 @@ def create_v_path(self, model_type: str): if model_type == "no_past": return v_nodes - if model_type in {"past", "merged"}: + if model_type in {"past", "merged", "70b_distributed_merged"}: concat_v_node = helper.make_node( "Concat", inputs=["past_value", "transpose_v_1_out"], @@ -513,8 +734,216 @@ def create_v_path(self, model_type: str): name="Concat_v", axis=2, ) + + if model_type != "70b_distributed_merged": + return v_nodes + [concat_v_node] # noqa: RUF005 + + shape_v1 = helper.make_node( + "Shape", + inputs=["present_value"], + outputs=["shape_1_out"], + 
name="Shape_v1" + ) + v_nodes.append(shape_v1) + shape_v2 = helper.make_node( + "Shape", + inputs=["present_value"], + outputs=["shape_2_out"], + name="Shape_v2" + ) + v_nodes.append(shape_v2) + shape_v3 = helper.make_node( + "Shape", + inputs=["present_value"], + outputs=["shape_3_out"], + name="Shape_v3" + ) + v_nodes.append(shape_v3) + shape_v4 = helper.make_node( + "Shape", + inputs=["present_value"], + outputs=["shape_4_out"], + name="Shape_v4" + ) + v_nodes.append(shape_v4) + gather_v_1 = helper.make_node( + "Gather", + inputs=["shape_1_out", "one"], + outputs=["gather_1_out"], + name="Gather_v1", + axis=0, + ) + v_nodes.append(gather_v_1) + gather_v_2 = helper.make_node( + "Gather", + inputs=["shape_2_out", "one"], + outputs=["gather_2_out"], + name="Gather_v2", + axis=0, + ) + v_nodes.append(gather_v_2) + gather_v_3 = helper.make_node( + "Gather", + inputs=["shape_3_out", "one"], + outputs=["gather_3_out"], + name="Gather_v3", + axis=0, + ) + v_nodes.append(gather_v_3) + gather_v_4 = helper.make_node( + "Gather", + inputs=["shape_4_out", "one"], + outputs=["gather_4_out"], + name="Gather_v4", + axis=0, + ) + v_nodes.append(gather_v_4) + unsqueeze_v_1 = helper.make_node( + "Unsqueeze", + inputs=["present_value", "zero"], + outputs=["unsqueeze_v1_out"], + name="Unsqueeze_v1", + ) + v_nodes.append(unsqueeze_v_1) + unsqueeze_v_2 = helper.make_node( + "Unsqueeze", + inputs=["gather_1_out", "zero"], + outputs=["unsqueeze_v2_out"], + name="Unsqueeze_v2", + ) + v_nodes.append(unsqueeze_v_2) + unsqueeze_v_3 = helper.make_node( + "Unsqueeze", + inputs=["gather_2_out", "zero"], + outputs=["unsqueeze_v3_out"], + name="Unsqueeze_v3", + ) + v_nodes.append(unsqueeze_v_3) + unsqueeze_v_4 = helper.make_node( + "Unsqueeze", + inputs=["gather_3_out", "zero"], + outputs=["unsqueeze_v4_out"], + name="Unsqueeze_v4", + ) + v_nodes.append(unsqueeze_v_4) + unsqueeze_v_5 = helper.make_node( + "Unsqueeze", + inputs=["gather_4_out", "zero"], + outputs=["unsqueeze_v5_out"], + name="Unsqueeze_v5", + ) + v_nodes.append(unsqueeze_v_5) + concat_v_2 = helper.make_node( + "Concat", + inputs=["unsqueeze_v2_out", "unsqueeze_v3_out", "One", "unsqueeze_v4_out", "unsqueeze_v5_out"], + outputs=["concat_v2_ouot"], + name="Concat_v2", + axis=0, + ) + v_nodes.append(concat_v_2) + reshape_v_2 = helper.make_node( + "Reshape", + inputs=["concat_v2_ouot", "One"], + outputs=["reshape_v2_out"], + name="Reshape_v2", + ) + v_nodes.append(reshape_v_2) + shape_v5 = helper.make_node( + "Shape", + inputs=["reshape_v2_out"], + outputs=["shape_5_out"], + name="Shape_v5" + ) + v_nodes.append(shape_v5) + constant_of_shape_v_1 = helper.make_node( + "ConstantOfShape", + inputs=["shape_5_out"], + outputs=["constant_of_shape_v1_out"], + name="ConstantOfShape_v1" + ) + v_nodes.append(constant_of_shape_v_1) + mul_v_1 = helper.make_node( + "Mul", + inputs=["constant_of_shape_v1_out", "One"], + outputs=["mul_v1_out"], + name="mul_v1", + ) + v_nodes.append(mul_v_1) + equal_v_1 = helper.make_node( + "Equal", + inputs=["reshape_v2_out", "mul_v1_out"], + outputs=["equal_v_1_out"], + name="equal_v1", + ) + v_nodes.append(equal_v_1) + where_v_1 = helper.make_node( + "Where", + inputs=["equal_v_1_out", "constant_of_shape_v1_out", "reshape_v2_out"], + outputs=["where_v_1_out"], + name="where_v1", + ) + v_nodes.append(where_v_1) + unsqueeze_v_6 = helper.make_node( + "Unsqueeze", + inputs=["gather_1_out", "zero"], + outputs=["unsqueeze_v6_out"], + name="Unsqueeze_v6", + ) + v_nodes.append(unsqueeze_v_6) + mul_v_2 = helper.make_node( + "Mul", + 
inputs=["gather_2_out", "One"], + outputs=["mul_v2_out"], + name="mul_v2", + ) + v_nodes.append(mul_v_2) + unsqueeze_v_7 = helper.make_node( + "Unsqueeze", + inputs=["mul_v2_out", "zero"], + outputs=["unsqueeze_v7_out"], + name="Unsqueeze_v7", + ) + v_nodes.append(unsqueeze_v_7) + unsqueeze_v_8 = helper.make_node( + "Unsqueeze", + inputs=["gather_3_out", "zero"], + outputs=["unsqueeze_v8_out"], + name="Unsqueeze_v8", + ) + v_nodes.append(unsqueeze_v_8) + unsqueeze_v_9 = helper.make_node( + "Unsqueeze", + inputs=["gather_4_out", "zero"], + outputs=["unsqueeze_v9_out"], + name="Unsqueeze_v9", + ) + v_nodes.append(unsqueeze_v_9) + concat_v_3 = helper.make_node( + "Concat", + inputs=["unsqueeze_v6_out", "unsqueeze_v7_out", "unsqueeze_v8_out", "unsqueeze_v9_out"], + outputs=["concat_v3_out"], + name="Concat_v3", + axis=0, + ) + v_nodes.append(concat_v_3) + expand_v_1 = helper.make_node( + "Expand", + inputs=["unsqueeze_v1_out", "where_v_1_out"], + outputs=["expand_v1_out"], + name="expand_v1", + ) + v_nodes.append(expand_v_1) + reshape_v_3 = helper.make_node( + "Reshape", + inputs=["expand_v1_out", "concat_v3_out"], + outputs=["reshape_v3_out"], + name="Reshape_v3", + ) + v_nodes.append(reshape_v_3) + return v_nodes + [concat_v_node] # noqa: RUF005 + # Create extra nodes for `position_ids` unsqueeze_v_node = helper.make_node( "Unsqueeze", @@ -672,7 +1101,28 @@ def create_concat_unsqueeze_paths(self, model_type: str, reshape_nodes: List[Nod return extra_nodes - def create_end_nodes(self): + def create_end_nodes(self, model_type): + if model_type == "70b_distributed_merged": + matmul_o_node = helper.make_node( + "MatMul", + inputs=["attn_output", "o_weight"], + outputs=["output_proj"], + name="MatMul_o_proj", + ) + all_reduce = helper.make_node( + "AllReduce", + inputs=["output_proj"], + outputs=["allreduce_proj"], + name="allreduce_proj", + ) + end_node = helper.make_node( + "Add", + inputs=["zero", "allreduce_proj"], + outputs=["output_0"], + name="Add_normalize_node", + ) + return [matmul_o_node, all_reduce, end_node] + matmul_o_node = helper.make_node( "MatMul", inputs=["attn_output", "o_weight"], @@ -711,7 +1161,7 @@ def create_fused_model(self, model_type: str, interleaved: bool, initializers: L num_heads=self.num_heads, ) - end_nodes = self.create_end_nodes() + end_nodes = self.create_end_nodes(model_type) graph = helper.make_graph( nodes=matmul_nodes + rope_nodes + attn_mask_nodes + [mha_node] + end_nodes, @@ -740,7 +1190,7 @@ def create_test_model(self, model_type: str, interleaved: bool, initializers: Li reshape_nodes = list(filter(lambda node: node.op_type == "Reshape", q_nodes + k_nodes + v_nodes + qkv_nodes)) extra_nodes = self.create_concat_unsqueeze_paths(model_type, reshape_nodes) - end_nodes = self.create_end_nodes() + end_nodes = self.create_end_nodes(model_type) first_set_of_nodes = matmul_nodes + rope_nodes + q_nodes + k_nodes + attn_mask_nodes second_set_of_nodes = qk_nodes + v_nodes + qkv_nodes + extra_nodes + end_nodes @@ -790,6 +1240,11 @@ def test_hf_decoder_merged_model(self): interleaved = False self.check_models(model_type, interleaved) + def test_hf_70b_distributed_decoder_merged_model(self): + model_type = "70b_distributed_merged" + interleaved = False + self.check_models(model_type, interleaved) + if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file From dd41176cf0da64680ec0c20e13b2ebd6b7dff947 Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Mon, 30 Oct 2023 22:29:49 +0000 Subject: [PATCH 02/13] fix bug in llama input --- 
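Note: get_sample_inputs() builds the prompt-only inputs (no past KV cache) and was not given a world_size parameter in the previous patch; only the past-KV helpers in llama_inputs.py are sharding-aware, so the call corrected below passed a keyword the function does not accept. The snippet that follows is a hypothetical usage sketch, not part of the diff; it assumes llama_inputs.py is importable, that get_sample_inputs keeps the signature used in llama_parity.py, and it uses a default LlamaConfig and a CPU device purely for illustration.

    import torch
    from transformers import LlamaConfig

    from llama_inputs import get_sample_inputs, get_sample_with_past_kv_inputs

    config = LlamaConfig()        # placeholder config; a real run loads the target model's config
    device = torch.device("cpu")  # placeholder device; parity runs typically use cuda:<rank>

    # Prompt-only inputs: there is no KV cache to shard, so no world_size argument is passed.
    inputs = get_sample_inputs(config, device, 2, 8, return_dict=True)

    # Past-KV inputs: the KV cache is split across ranks, so world_size is passed through.
    inputs_with_past = get_sample_with_past_kv_inputs(
        config, device, 2, 8, use_fp16=False, return_dict=True, world_size=1
    )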
.../python/tools/transformers/models/llama/llama_parity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index fb9b7f584edca..0511aad1d8783 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -50,7 +50,7 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig): config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True, world_size = world_size ) else: - inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True, world_size = world_size) + inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True) return inputs From 920669364206d546a2296290956db52ce8a40ace Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Tue, 31 Oct 2023 03:10:48 +0000 Subject: [PATCH 03/13] lint style fix --- .../transformers/fusion_rotary_attention.py | 314 ++++++++++++++---- .../transformers/models/llama/benchmark.py | 18 +- .../models/llama/convert_to_onnx.py | 121 +++++-- .../models/llama/dist_settings.py | 5 +- .../transformers/models/llama/llama_inputs.py | 12 +- .../transformers/models/llama/llama_parity.py | 24 +- .../transformers/models/llama/llama_torch.py | 19 +- .../transformers/test_rotary_mha_fusion.py | 91 ++--- 8 files changed, 408 insertions(+), 196 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index 84eef46d692b3..b2744adf9f8e6 100644 --- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -338,7 +338,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qkv_nodes_3 = self.model.match_parent_path( normalize_node, ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"], - [1, 0, 0, 0, 0], + [1, 0, 0, 0, 0], ) if qkv_nodes_1 is not None: _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1 @@ -373,47 +373,117 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape", "MatMul"], [1, 0, 0], ) - _,v_nodes_4,_ = self.model.match_parent_paths_all( + _, v_nodes_4, _ = self.model.match_parent_paths_all( matmul_qkv, [ ( ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"], - [1, 0, 0, 0, 1, 0, 0], + [1, 0, 0, 0, 1, 0, 0], ), ( - ["Reshape", "Expand", "Where", "Equal", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], - [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [ + "Reshape", + "Expand", + "Where", + "Equal", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], ), ( - ["Reshape", "Expand", "Where", "Equal", "Mul", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], - [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0], + [ + "Reshape", + "Expand", + "Where", + "Equal", + "Mul", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0], ), ( - ["Reshape", "Expand", "Where", 
"ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], - [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0], + [ + "Reshape", + "Expand", + "Where", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0], ), ( - ["Reshape", "Expand", "Where", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], - [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0], + [ + "Reshape", + "Expand", + "Where", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0], ), ( ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], - [1, 1, 0, 0, 0, 0, 1, 0, 0], + [1, 1, 0, 0, 0, 0, 1, 0, 0], ), ( - ["Reshape", "Concat", "Unsqueeze", "Mul", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], - [1, 1, 1, 0, 0, 0, 0, 1, 0, 0], + [ + "Reshape", + "Concat", + "Unsqueeze", + "Mul", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 1, 1, 0, 0, 0, 0, 1, 0, 0], ), ( ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], - [1, 1, 2, 0, 0, 0, 1, 0, 0], + [1, 1, 2, 0, 0, 0, 1, 0, 0], ), ( ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], - [1, 1, 3, 0, 0, 0, 1, 0, 0], + [1, 1, 3, 0, 0, 0, 1, 0, 0], ), ], - output_name_to_node = None + output_name_to_node=None, ) if v_nodes_1 is not None: reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1 @@ -441,23 +511,23 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): v_nodes = v_nodes_3 present_v = transpose_v.output[0] elif v_nodes_4 is not None and len(v_nodes_4) == 9: - logger.debug('fuse_rotary_attention: v_nodes_4') - logger.debug('*' * 30) + logger.debug("fuse_rotary_attention: v_nodes_4") + logger.debug("*" * 30) for temp_path in v_nodes_4: - logger.debug('fuse_rotary_attention: path for v_nodes_4') + logger.debug("fuse_rotary_attention: path for v_nodes_4") for temp_node in temp_path: - logger.debug(f'temp_node: {temp_node}') - logger.debug('*' * 30) + logger.debug(f"temp_node: {temp_node}") + logger.debug("*" * 30) concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:] v_nodes = v_nodes_4 past_v = concat_v.input[0] present_v = concat_v.output[0] - logger.debug(f'transpose_v: {transpose_v}') - logger.debug(f'reshape_v: {reshape_v}') - logger.debug(f'matmul_v: {matmul_v}') - logger.debug(f'past_v: {past_v}') - logger.debug(f'present_v: {present_v}') + logger.debug(f"transpose_v: {transpose_v}") + logger.debug(f"reshape_v: {reshape_v}") + logger.debug(f"matmul_v: {matmul_v}") + logger.debug(f"past_v: {past_v}") + logger.debug(f"present_v: {present_v}") else: logger.debug("fuse_rotary_attention: failed to match v path") return @@ -537,43 +607,169 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): matmul_qk, [ ( - ["Transpose", "Reshape", "Expand", "Unsqueeze", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], - [1, 0, 0, 0, 0, 1, 0, 0, 0], + [ + "Transpose", + "Reshape", + "Expand", + "Unsqueeze", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 0, 0, 1, 0, 0, 0], ), ( - ["Transpose", "Reshape", "Expand", "Where", "Equal", "Reshape", "Concat", 
"Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], - [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Equal", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], ), ( - ["Transpose", "Reshape", "Expand", "Where", "Equal", "Mul", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], - [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Equal", + "Mul", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], ), ( - ["Transpose", "Reshape", "Expand", "Where", "ConstantOfShape", "Shape", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], - [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0], + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0], ), ( - ["Transpose", "Reshape", "Expand", "Where", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], - [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0], + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0], ), ( - ["Transpose", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], - [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0], + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0], ), ( - ["Transpose", "Reshape", "Concat", "Unsqueeze", "Mul", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], - [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0], + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Mul", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0], ), ( - ["Transpose", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], - [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0], + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0], ), ( - ["Transpose", "Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], - [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0], + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0], ), ], - output_name_to_node = None + output_name_to_node=None, ) if k_nodes_1 is 
not None: reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1 @@ -603,25 +799,25 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): past_k = concat_k.input[0] present_k = concat_k.output[0] elif k_nodes_4 is not None and len(k_nodes_4) == 9: - logger.debug('fuse_rotary_attention: k_nodes_4') - logger.debug('*' * 30) + logger.debug("fuse_rotary_attention: k_nodes_4") + logger.debug("*" * 30) for temp_path in k_nodes_4: - logger.debug('fuse_rotary_attention: path for k_nodes_4') + logger.debug("fuse_rotary_attention: path for k_nodes_4") for temp_node in temp_path: - logger.debug(f'temp_node: {temp_node}') - logger.debug('*' * 30) + logger.debug(f"temp_node: {temp_node}") + logger.debug("*" * 30) reshape_k, matmul_k = k_nodes_4[0][-2:] concat_k, rotary_k = k_nodes_4[0][-5:-3] k_nodes = k_nodes_4 past_k = concat_k.input[0] present_k = concat_k.output[0] - logger.debug(f'reshape_k: {reshape_k}') - logger.debug(f'matmul_k: {matmul_k}') - logger.debug(f'concat_k: {concat_k}') - logger.debug(f'rotary_k: {rotary_k}') - logger.debug(f'past_k: {past_k}') - logger.debug(f'present_k: {present_k}') + logger.debug(f"reshape_k: {reshape_k}") + logger.debug(f"matmul_k: {matmul_k}") + logger.debug(f"concat_k: {concat_k}") + logger.debug(f"rotary_k: {rotary_k}") + logger.debug(f"past_k: {past_k}") + logger.debug(f"present_k: {present_k}") else: logger.debug("fuse_rotary_attention: failed to match k nodes") return @@ -669,7 +865,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return root_output = reshape_qkv_2.output[0] - elif qkv_nodes == qkv_nodes_2 or qkv_nodes == qkv_nodes_3: + elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3): if not self.check_runtime_shape_paths_for_nodes( reshape_qkv, reshape_q, @@ -1196,4 +1392,4 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): self.increase_counter(self.base_name) self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name self.nodes_to_add.append(rotary_emb_node) - self.prune_graph = True \ No newline at end of file + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index fe6f990a31741..996ea8264dca8 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -11,7 +11,7 @@ import onnx import psutil import torch -from onnxruntime.transformers.benchmark_helper import setup_logger +from dist_settings import get_rank, get_size from llama_inputs import ( convert_inputs_for_ort, get_merged_sample_with_past_kv_inputs, @@ -25,9 +25,7 @@ from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer import onnxruntime as ort -from onnxruntime.transformers.benchmark_helper import measure_memory - -from dist_settings import get_rank, get_size +from onnxruntime.transformers.benchmark_helper import measure_memory, setup_logger logger = logging.getLogger(__name__) @@ -120,7 +118,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=0, use_fp16=args.use_fp16, return_dict=True, - world_size = args.world_size, + world_size=args.world_size, ) iter_inputs = get_merged_sample_with_past_kv_inputs( args.config, @@ -130,7 +128,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=args.sequence_length, use_fp16=args.use_fp16, return_dict=True, - world_size = args.world_size, + world_size=args.world_size, ) init_inputs = 
convert_inputs_for_ort( init_inputs, @@ -531,7 +529,7 @@ def run_inference(args, init_inputs, iter_inputs, model): raise Exception(f"Cannot recognize {args.benchmark_type}") -def get_args(rank = 0): +def get_args(rank=0): parser = argparse.ArgumentParser() parser.add_argument( "-bt", @@ -663,8 +661,8 @@ def main(): logger.info(args.__dict__) torch.backends.cudnn.benchmark = True - setattr(args, "rank", rank) - setattr(args, "world_size", world_size) + args.rank = rank + args.world_size = world_size tokenizer = LlamaTokenizer.from_pretrained(args.model_name) config = LlamaConfig.from_pretrained(args.model_name) target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device @@ -701,4 +699,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 41e743ff6045c..ab3b4e115ade0 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -2,16 +2,16 @@ import logging import os import shutil -import tempfile from itertools import chain from typing import List import onnx import torch -from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger from convert_generation import replace_mha_with_gqa +from dist_settings import barrier, get_rank, get_size, init_dist from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check +from llama_torch import setup_torch_model from onnx_model import OnnxModel from optimizer import optimize_model from packaging import version @@ -19,13 +19,12 @@ from onnxruntime import quantization as ort_quantization from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer - -from dist_settings import init_dist, get_rank, get_size, barrier -from llama_torch import setup_torch_model +from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger logger = logging.getLogger("") init_dist() + def get_model_dynamic_axes(input_names: List[str], output_names: List[str]): dynamic_axes = {} for name in input_names + output_names: @@ -133,7 +132,9 @@ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: st # del onnx_model # temp_dir.cleanup() # -def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int=0, world_size: int=1): +def run_dynamo_export( + args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 +): from torch._dynamo import config config.capture_scalar_outputs = True @@ -164,7 +165,7 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll # Export decoder_with_past_model.onnx input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs( - l_config, device, batch_size, sequence_length, world_size = world_size + l_config, device, batch_size, sequence_length, world_size=world_size ) temp_dir = args.output # tempfile.TemporaryDirectory() temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx") @@ -186,11 +187,15 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo 
exporter!") + def _prepare_dir(dir_path): if not os.path.exists(dir_path): os.makedirs(dir_path) -def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int=0, world_size: int=1): + +def run_torchscript_separate_export( + args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 +): # Dummy values for export batch_size, sequence_length = 2, 8 device = llama.device @@ -209,7 +214,7 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon # Avoid use system temp dir to avoid overflood on hard disk as 70b model is very large. # Use temp folder per rank to avoid race condition here. - temp_dir = f'./temp_{rank}' + temp_dir = f"./temp_{rank}" _prepare_dir(temp_dir) temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( @@ -240,7 +245,14 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon shutil.rmtree(temp_dir) # Export decoder_with_past_model.onnx - decoder_with_past_inputs = get_sample_with_past_kv_inputs(l_config, device, batch_size, sequence_length, use_fp16 = args.precision == Precision.FLOAT16, world_size = world_size) + decoder_with_past_inputs = get_sample_with_past_kv_inputs( + l_config, + device, + batch_size, + sequence_length, + use_fp16=args.precision == Precision.FLOAT16, + world_size=world_size, + ) input_names = [ "input_ids", "attention_mask", @@ -289,17 +301,27 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon del onnx_model shutil.rmtree(temp_dir) - logger.info(f"The {args.model_name} separate ONNX model has been successfully created with the TorchScript exporter!") + logger.info( + f"The {args.model_name} separate ONNX model has been successfully created with the TorchScript exporter!" + ) -def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int=0, world_size: int=1): +def run_torchscript_merged_export( + args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 +): # Dummy values for export batch_size, sequence_length, past_sequence_length = 2, 8, 0 device = llama.device # Export decoder_merged_model.onnx decoder_merged_inputs = get_merged_sample_with_past_kv_inputs( - l_config, device, batch_size, sequence_length, past_sequence_length, use_fp16 = args.precision == Precision.FLOAT16, world_size = world_size + l_config, + device, + batch_size, + sequence_length, + past_sequence_length, + use_fp16=args.precision == Precision.FLOAT16, + world_size=world_size, ) input_names = [ "input_ids", @@ -321,7 +343,7 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi # Avoid use system temp dir to avoid overflood on hard disk as 70b model is very large. # Use temp folder per rank to avoid race condition here. 
- temp_dir = f'./temp_{rank}' + temp_dir = f"./temp_{rank}" _prepare_dir(temp_dir) temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( @@ -374,12 +396,16 @@ def optimize_export(config: LlamaConfig, input_path: str, output_path: str): remove_existing_model(input_path) -def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: List[str], rank: int=0, world_size: int=1): +def convert_to_float16( + args: argparse.Namespace, config: LlamaConfig, old_paths: List[str], rank: int = 0, world_size: int = 1 +): decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx") decoder_with_past_model_fp16_path = os.path.join( args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp16.onnx" ) - decoder_merged_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp16.onnx") + decoder_merged_model_fp16_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp16.onnx" + ) new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path] logger.info("Converting to float16...") @@ -397,9 +423,11 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: return new_paths -def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int=1): +def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int = 1): # Replace MultiHeadAttention with GroupQueryAttention and remove attention mask nodes - fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "past_sequence_length", config.num_key_value_heads, world_size) + fp16_model_opt = replace_mha_with_gqa( + fp16_model_opt, "past_sequence_length", config.num_key_value_heads, world_size + ) fp16_model_opt.prune_graph() fp16_model_opt.update_graph(allow_remove_graph_inputs=True) return fp16_model_opt @@ -497,6 +525,7 @@ def remove_existing_files(output_path: str): os.remove(filepath) logger.warning(f"Removed {filepath}") + def get_args(): parser = argparse.ArgumentParser() @@ -714,11 +743,15 @@ def main(): for i in range(world_size): if i == rank: # Set model paths for FP32 model - decoder_model_fp32_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") + decoder_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx" + ) decoder_with_past_model_fp32_path = os.path.join( args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx" ) - decoder_merged_model_fp32_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx") + decoder_merged_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx" + ) old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] missing_separate_exports = ( @@ -747,14 +780,20 @@ def main(): run_torchscript_merged_export(args, l_config, llama, rank, world_size) # Set model paths to store FP32 optimized model - decoder_model_fp32_opt_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32_opt.onnx") + decoder_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32_opt.onnx" + ) decoder_with_past_model_fp32_opt_path = os.path.join( args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32_opt.onnx" ) 
decoder_merged_model_fp32_opt_path = os.path.join( args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32_opt.onnx" ) - new_paths = [decoder_model_fp32_opt_path, decoder_with_past_model_fp32_opt_path, decoder_merged_model_fp32_opt_path] + new_paths = [ + decoder_model_fp32_opt_path, + decoder_with_past_model_fp32_opt_path, + decoder_merged_model_fp32_opt_path, + ] # Run the optimizer script logger.info("Optimizing models...") @@ -777,18 +816,24 @@ def main(): _ = convert_to_float16(args, l_config, old_paths, rank, world_size) elif args.precision == Precision.INT8: - decoder_model_int8_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx") + decoder_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx" + ) decoder_with_past_model_int8_path = os.path.join( args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx" ) - decoder_merged_model_int8_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx") + decoder_merged_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx" + ) new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] if args.quantization_method == "smooth_quant": if not args.no_merged: logger.error("SmoothQuant must be used on separately exported models") else: - logger.info(f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8") + logger.info( + f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8" + ) smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) elif args.quantization_method == "quantize_dynamic": @@ -810,7 +855,9 @@ def main(): use_external_data_format=True, extra_options={"MatMulConstBOnly": True}, ) - logger.info(f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!") + logger.info( + f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!" 
+ ) remove_existing_model(decoder_model_fp32_path) logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") @@ -822,11 +869,15 @@ def main(): if args.execution_provider != "cpu": old_paths = convert_to_float16(args, l_config, old_paths, rank, world_size) - decoder_model_int4_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx") + decoder_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx" + ) decoder_with_past_model_int4_path = os.path.join( args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx" ) - decoder_merged_model_int4_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx") + decoder_merged_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx" + ) new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] for fp_path, int4_path in zip(old_paths, new_paths): @@ -844,7 +895,6 @@ def main(): del llama barrier() - logger.info("Verifying parity on all ONNX models created") # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models @@ -856,7 +906,12 @@ def main(): # Verify parity on all saved ONNX models for filename in os.listdir(args.output): - if ".data" in filename or ".onnx" not in filename or args.precision not in filename or f'rank_{rank}' not in filename: + if ( + ".data" in filename + or ".onnx" not in filename + or args.precision not in filename + or f"rank_{rank}" not in filename + ): continue parity_cmd = [ @@ -869,7 +924,7 @@ def main(): "-fp", args.precision, "--cache_dir", - args.cache_dir + args.cache_dir, ] if "with_past" in filename: parity_cmd.append("--use_past_kv") @@ -877,11 +932,11 @@ def main(): parity_cmd.append("--merged") try: - logger.debug(f'check parity with cmd: {parity_cmd}') + logger.debug(f"check parity with cmd: {parity_cmd}") parity_check(parity_cmd) except Exception as e: logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py index 3040482612644..e0e03c70e8618 100644 --- a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py +++ b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py @@ -1,7 +1,7 @@ import os + import torch import torch.distributed as dist - from mpi4py import MPI @@ -23,6 +23,7 @@ def init_dist(): device = torch.device(local_rank) return device + comm = MPI.COMM_WORLD @@ -40,4 +41,4 @@ def barrier(): def print_out(*args): if get_rank() == 0: - print(*args) \ No newline at end of file + print(*args) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index bcededd463272..766500dd25086 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -46,13 +46,13 @@ def get_sample_with_past_kv_inputs( past_seq_len: int, use_fp16: bool = False, return_dict: bool = False, - world_size: int=1 + world_size: int = 1, ): input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), device=device, dtype=torch.int64) attention_mask = 
torch.ones(batch_size, past_seq_len + 1, device=device, dtype=torch.int64) # position_ids is of shape (batch_size, 1) position_ids = get_position_ids(attention_mask, use_past_kv=True) - past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16, world_size = world_size) + past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16, world_size=world_size) if not return_dict: return (input_ids, attention_mask, position_ids, past_kv) @@ -75,7 +75,7 @@ def get_merged_sample_with_past_kv_inputs( past_seq_len: int, use_fp16: bool = False, return_dict: bool = False, - world_size: int=1 + world_size: int = 1, ): input_ids = torch.randint( low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 @@ -83,7 +83,7 @@ def get_merged_sample_with_past_kv_inputs( attention_mask = torch.ones(batch_size, past_seq_len + seq_len, device=device, dtype=torch.int64) # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) - past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16, world_size = world_size) + past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16, world_size=world_size) if not return_dict: return (input_ids, attention_mask, position_ids, past_kv) @@ -99,7 +99,7 @@ def get_merged_sample_with_past_kv_inputs( # Create past_key_values def get_sample_past_kv_inputs( - config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int=1 + config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1 ): num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads torch_dtype = torch.float16 if use_fp16 else torch.float32 @@ -203,4 +203,4 @@ def get_msft_sample_inputs( } ) - return ort_inputs \ No newline at end of file + return ort_inputs diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 0511aad1d8783..fd4e4024f0891 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -7,6 +7,7 @@ import numpy as np import torch from benchmark_helper import setup_logger +from dist_settings import get_rank, get_size from llama_inputs import ( convert_inputs_for_ort, get_merged_sample_with_past_kv_inputs, @@ -14,7 +15,6 @@ get_sample_with_past_kv_inputs, ) from llama_torch import setup_torch_model -from dist_settings import get_rank, get_size from transformers import LlamaConfig, LlamaForCausalLM import onnxruntime as ort @@ -43,11 +43,17 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig): past_sequence_length, use_fp16=args.use_fp16, return_dict=True, - world_size = world_size + world_size=world_size, ) elif args.use_past_kv: inputs = get_sample_with_past_kv_inputs( - config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True, world_size = world_size + config, + args.device, + batch_size, + sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + world_size=world_size, ) else: inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True) @@ -82,7 +88,7 @@ def add_io_bindings(args: argparse.Namespace, model: 
ort.InferenceSession, input def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM): inputs = get_inputs(args, config) - logger.debug(f'torch input: {inputs}') + logger.debug(f"torch input: {inputs}") # Run inference with PyTorch if args.execution_provider != "cpu": @@ -107,7 +113,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama device_id=int(args.rank), ) - logger.debug(f'ORT input: {inputs}') + logger.debug(f"ORT input: {inputs}") ep = f"{args.execution_provider.upper()}ExecutionProvider" if ep == "CUDAExecutionProvider": @@ -247,13 +253,15 @@ def main(argv: List[str] = []): # noqa: B006 # Load model and config setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010 - setattr(args, "rank", rank) + args.rank = rank setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010 setattr(args, "device", torch.device(args.device_name)) # noqa: B010 use_auth_token = args.torch_model_directory == os.path.join(".") location = args.model_name if use_auth_token else args.torch_model_directory - config, llama = setup_torch_model(args, location, use_auth_token, torch_dtype=(torch.float16 if args.use_fp16 else torch.float32)) + config, llama = setup_torch_model( + args, location, use_auth_token, torch_dtype=(torch.float16 if args.use_fp16 else torch.float32) + ) if not args.merged: verify_parity(args, config, llama) @@ -271,4 +279,4 @@ def main(argv: List[str] = []): # noqa: B006 seed = 2 np.random.seed(seed) torch.manual_seed(seed) - main() \ No newline at end of file + main() diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py index 312854de913ed..999f2f932d561 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py @@ -1,16 +1,16 @@ -import os import logging +import os import torch +from dist_settings import barrier, get_rank, get_size from transformers import LlamaConfig, LlamaForCausalLM -from dist_settings import get_rank, get_size, barrier - logger = logging.getLogger("") + def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, use_cuda=True): world_size = get_size() - logger.info(f'world_size: {world_size}') + logger.info(f"world_size: {world_size}") rank = get_rank() barrier() @@ -21,8 +21,13 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, if i == rank: l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir) l_config.use_cache = True - llama = LlamaForCausalLM.from_pretrained(location, use_auth_token=use_auth_token, config=l_config, - torch_dtype=torch_dtype, cache_dir=args.cache_dir) + llama = LlamaForCausalLM.from_pretrained( + location, + use_auth_token=use_auth_token, + config=l_config, + torch_dtype=torch_dtype, + cache_dir=args.cache_dir, + ) if world_size > 1: llama.parallel_model() if use_cuda: @@ -30,4 +35,4 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, llama.eval() llama.requires_grad_(False) barrier() - return l_config, llama \ No newline at end of file + return l_config, llama diff --git a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py index 90e43246b7555..58a088e5716c2 100644 --- a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py 
+++ b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py @@ -297,40 +297,20 @@ def create_k_path_hf(self, model_type: str): if model_type == "70b_distributed_merged": concat_k_node = helper.make_node( - "Concat", - inputs=["past_key", "k_rope"], - outputs=["present_key"], - axis=2, - ) + "Concat", + inputs=["past_key", "k_rope"], + outputs=["present_key"], + axis=2, + ) k_nodes.append(concat_k_node) - shape_k1 = helper.make_node( - "Shape", - inputs=["present_value"], - outputs=["shape_k1_out"], - name="Shape_k1" - ) + shape_k1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k1_out"], name="Shape_k1") k_nodes.append(shape_k1) - shape_k2 = helper.make_node( - "Shape", - inputs=["present_value"], - outputs=["shape_k2_out"], - name="Shape_k2" - ) + shape_k2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k2_out"], name="Shape_k2") k_nodes.append(shape_k2) - shape_k3 = helper.make_node( - "Shape", - inputs=["present_value"], - outputs=["shape_k3_out"], - name="Shape_k3" - ) + shape_k3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k3_out"], name="Shape_k3") k_nodes.append(shape_k3) - shape_k4 = helper.make_node( - "Shape", - inputs=["present_value"], - outputs=["shape_k4_out"], - name="Shape_k4" - ) + shape_k4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k4_out"], name="Shape_k4") k_nodes.append(shape_k4) gather_k_1 = helper.make_node( "Gather", @@ -414,18 +394,13 @@ def create_k_path_hf(self, model_type: str): name="Reshape_k_2", ) k_nodes.append(reshape_k_2) - shape_k5 = helper.make_node( - "Shape", - inputs=["reshape_k2_out"], - outputs=["shape_k5_out"], - name="Shape_k5" - ) + shape_k5 = helper.make_node("Shape", inputs=["reshape_k2_out"], outputs=["shape_k5_out"], name="Shape_k5") k_nodes.append(shape_k5) constant_of_shape_k_1 = helper.make_node( "ConstantOfShape", inputs=["shape_k5_out"], outputs=["constant_of_shape_k1_out"], - name="ConstantOfShape_k1" + name="ConstantOfShape_k1", ) k_nodes.append(constant_of_shape_k_1) mul_k_1 = helper.make_node( @@ -737,34 +712,14 @@ def create_v_path(self, model_type: str): if model_type != "70b_distributed_merged": return v_nodes + [concat_v_node] # noqa: RUF005 - - shape_v1 = helper.make_node( - "Shape", - inputs=["present_value"], - outputs=["shape_1_out"], - name="Shape_v1" - ) + + shape_v1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_1_out"], name="Shape_v1") v_nodes.append(shape_v1) - shape_v2 = helper.make_node( - "Shape", - inputs=["present_value"], - outputs=["shape_2_out"], - name="Shape_v2" - ) + shape_v2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_2_out"], name="Shape_v2") v_nodes.append(shape_v2) - shape_v3 = helper.make_node( - "Shape", - inputs=["present_value"], - outputs=["shape_3_out"], - name="Shape_v3" - ) + shape_v3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_3_out"], name="Shape_v3") v_nodes.append(shape_v3) - shape_v4 = helper.make_node( - "Shape", - inputs=["present_value"], - outputs=["shape_4_out"], - name="Shape_v4" - ) + shape_v4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_4_out"], name="Shape_v4") v_nodes.append(shape_v4) gather_v_1 = helper.make_node( "Gather", @@ -848,18 +803,13 @@ def create_v_path(self, model_type: str): name="Reshape_v2", ) v_nodes.append(reshape_v_2) - shape_v5 = helper.make_node( - "Shape", - inputs=["reshape_v2_out"], - outputs=["shape_5_out"], - name="Shape_v5" - ) + shape_v5 = 
helper.make_node("Shape", inputs=["reshape_v2_out"], outputs=["shape_5_out"], name="Shape_v5") v_nodes.append(shape_v5) constant_of_shape_v_1 = helper.make_node( "ConstantOfShape", inputs=["shape_5_out"], outputs=["constant_of_shape_v1_out"], - name="ConstantOfShape_v1" + name="ConstantOfShape_v1", ) v_nodes.append(constant_of_shape_v_1) mul_v_1 = helper.make_node( @@ -940,9 +890,8 @@ def create_v_path(self, model_type: str): name="Reshape_v3", ) v_nodes.append(reshape_v_3) - - return v_nodes + [concat_v_node] # noqa: RUF005 + return v_nodes + [concat_v_node] # noqa: RUF005 # Create extra nodes for `position_ids` unsqueeze_v_node = helper.make_node( @@ -1247,4 +1196,4 @@ def test_hf_70b_distributed_decoder_merged_model(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From b30ff30b0148f363f012e442fe65405d6915a078 Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Tue, 31 Oct 2023 06:33:30 +0000 Subject: [PATCH 04/13] take comments --- .../tools/transformers/convert_generation.py | 8 +- .../python/tools/transformers/fusion_base.py | 5 + .../transformers/fusion_rotary_attention.py | 45 +----- .../tools/transformers/models/llama/README.md | 11 +- .../transformers/models/llama/benchmark.py | 39 +++-- .../models/llama/convert_to_onnx.py | 3 +- .../transformers/models/llama/llama_parity.py | 8 +- .../models/llama/requirements-70b-model.txt | 4 + .../models/llama/requirements.txt | 4 +- .../models/llama/{run.sh => run_70b_model.sh} | 2 +- ...{single_run.sh => single_run_70b_model.sh} | 0 .../transformers/test_rotary_mha_fusion.py | 133 ++++++++++-------- 12 files changed, 133 insertions(+), 129 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt rename onnxruntime/python/tools/transformers/models/llama/{run.sh => run_70b_model.sh} (84%) rename onnxruntime/python/tools/transformers/models/llama/{single_run.sh => single_run_70b_model.sh} (100%) diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index be8eb6a4c42e8..7aca5e8526a23 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1282,6 +1282,10 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads # Replace MultiHeadAttention with GroupQueryAttention for node in model.model.graph.node: if node.op_type == "MultiHeadAttention": + num_heads_mha = 0 + for att in node.attribute: + if att.name == "num_heads": + num_heads_mha = att.i gqa_node = onnx.helper.make_node( "GroupQueryAttention", inputs=[ @@ -1295,8 +1299,8 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads outputs=node.output, name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"), domain="com.microsoft", - num_heads=node.attribute[0].i // world_size, - kv_num_heads=node.attribute[0].i // world_size if kv_num_heads == 0 else kv_num_heads // world_size, + num_heads=num_heads_mha // world_size, + kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size, is_past_bsnh=0, ) model.model.graph.node.remove(node) diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py index c5d7bc16d64f7..67f4f0b55cff8 100644 --- a/onnxruntime/python/tools/transformers/fusion_base.py +++ b/onnxruntime/python/tools/transformers/fusion_base.py @@ -130,3 +130,8 @@ def 
add_nodes_to_remove(self, nodes: List[NodeProto]): for node in nodes: if node not in self.nodes_to_remove: self.nodes_to_remove.append(node) + + def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]): + for node in nodes: + if node not in self.nodes_to_remove and node not in nodes_to_keep: + self.nodes_to_remove.append(node) diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index b2744adf9f8e6..060fa0fba1615 100644 --- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -511,23 +511,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): v_nodes = v_nodes_3 present_v = transpose_v.output[0] elif v_nodes_4 is not None and len(v_nodes_4) == 9: - logger.debug("fuse_rotary_attention: v_nodes_4") - logger.debug("*" * 30) - for temp_path in v_nodes_4: - logger.debug("fuse_rotary_attention: path for v_nodes_4") - for temp_node in temp_path: - logger.debug(f"temp_node: {temp_node}") - logger.debug("*" * 30) - concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:] v_nodes = v_nodes_4 past_v = concat_v.input[0] present_v = concat_v.output[0] - logger.debug(f"transpose_v: {transpose_v}") - logger.debug(f"reshape_v: {reshape_v}") - logger.debug(f"matmul_v: {matmul_v}") - logger.debug(f"past_v: {past_v}") - logger.debug(f"present_v: {present_v}") else: logger.debug("fuse_rotary_attention: failed to match v path") return @@ -585,7 +572,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # k_nodes_1 is for LLaMA-2 Microsoft # k_nodes_2 is for LLaMA-2 Hugging Face - # k_nodes_4 is for distributed LLaMA-2 Hugging Face + # k_nodes_4 is for LLaMA-2 70B Hugging Face past_k, present_k = "", "" k_nodes = None k_nodes_1 = self.model.match_parent_path( @@ -799,25 +786,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): past_k = concat_k.input[0] present_k = concat_k.output[0] elif k_nodes_4 is not None and len(k_nodes_4) == 9: - logger.debug("fuse_rotary_attention: k_nodes_4") - logger.debug("*" * 30) - for temp_path in k_nodes_4: - logger.debug("fuse_rotary_attention: path for k_nodes_4") - for temp_node in temp_path: - logger.debug(f"temp_node: {temp_node}") - logger.debug("*" * 30) - reshape_k, matmul_k = k_nodes_4[0][-2:] concat_k, rotary_k = k_nodes_4[0][-5:-3] k_nodes = k_nodes_4 past_k = concat_k.input[0] present_k = concat_k.output[0] - logger.debug(f"reshape_k: {reshape_k}") - logger.debug(f"matmul_k: {matmul_k}") - logger.debug(f"concat_k: {concat_k}") - logger.debug(f"rotary_k: {rotary_k}") - logger.debug(f"past_k: {past_k}") - logger.debug(f"present_k: {present_k}") else: logger.debug("fuse_rotary_attention: failed to match k nodes") return @@ -914,13 +887,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if v_nodes != v_nodes_4: self.nodes_to_remove.extend(v_nodes[:-1]) else: - remove_dic = {} - node_keep_name = v_nodes[0][-1].name + nodes_to_keep = [v_nodes[0][-1]] for temp_path in v_nodes: - for temp_node in temp_path: - if temp_node.name not in remove_dic and temp_node.name != node_keep_name: - remove_dic[temp_node.name] = temp_node - self.nodes_to_remove.extend(list(remove_dic.values())) + self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep) self.nodes_to_remove.extend(qk_nodes) @@ -936,13 +905,9 @@ def fuse(self, normalize_node, 
input_name_to_nodes, output_name_to_node): self.nodes_to_remove.append(k_nodes[3]) self.nodes_to_remove.append(k_nodes[4]) elif k_nodes == k_nodes_4: - remove_dic = {} - node_keep_names = [k_nodes[0][-1].name, k_nodes[0][-4].name] + nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]] for temp_path in k_nodes: - for temp_node in temp_path: - if temp_node.name not in remove_dic and temp_node.name not in node_keep_names: - remove_dic[temp_node.name] = temp_node - self.nodes_to_remove.extend(list(remove_dic.values())) + self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep) if q_nodes == q_nodes_1: self.nodes_to_remove.extend(q_nodes[:-2]) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 7f484d659cb7b..98d5f419ef26c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -10,6 +10,8 @@ Please note the package versions needed for using LLaMA-2 in the `requirements.t - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. - `requirements-quant.txt` - For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor) +- `requirements-70b-model.txt` + - For run LLaMA-2 70B model in multiple GPUs - `requirements.txt` - Package versions needed in each of the above files @@ -153,13 +155,12 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu ``` -Export Sharded model, llama-70b into 4 partitions +Export LLaMA-2 70B sharded model into 4 partitions ``` # From source: -$ 1. Get OnnxRuntime code from https://github.com/frankdongms/transformers/tree/frdong/shard_llama or -$ wait until PR: https://github.com/huggingface/transformers/pull/27119 got merged into HF transformers -$ 2. Build OnnxRuntime from source with NCCL enabled, sample command: ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ -$ 3. Shard and export llama-70b model: CUDA_VISIBLE_DEVICES=0,1,2,3 bash run.sh 4 -m meta-llama/Llama-2-7b-hf --output llama2-7b-dis2 --precision fp16 --execution_provider cuda +# 1. Install necessary packages from requirements-70b-model.txt +# 2. Build Onnx Runtime from source with NCCL enabled. Here is an sample command: ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ +# 3. Shard and export the LLaMA-2 70B model. 
Here is an example command: CUDA_VISIBLE_DEVICES=0,1,2,3 bash run_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda ``` ## Benchmark LLaMA-2 diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index 996ea8264dca8..d2a2c8bccfbc0 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -288,38 +288,57 @@ def time_fn(args, fn, inputs): outputs = fn(inputs) logger.info(outputs) + input_sync = ( # noqa: E731 + lambda *kwargs: args.io_binding.synchronize_inputs() + if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize + else lambda *kwargs: torch.cuda.synchronize() + if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize + else lambda *kwargs: None # no-op function + ) + + output_sync = ( # noqa: E731 + lambda *kwargs: args.io_binding.synchronize_outputs() + if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize + else lambda *kwargs: torch.cuda.synchronize() + if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize + else lambda *kwargs: None # no-op function + ) + for _ in warmup_range: + input_sync() fn(inputs) + output_sync() # Benchmark - if args.device != "cpu": - torch.cuda.synchronize() - start_time = time.time() - + total_time = 0 bench_range = ( range(args.num_runs) if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} else trange(args.num_runs, file=sys.stdout, desc="Benchmark") ) for _ in bench_range: + input_sync() + start_time = time.time() + fn(inputs) - if args.device != "cpu": - torch.cuda.synchronize() - end_time = time.time() + output_sync() + end_time = time.time() + + total_time += end_time - start_time # Newline print after trange in order to print metrics on new lines without progress bar on same line if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}: logger.info("") - latency = (end_time - start_time) / args.num_runs + latency = total_time / args.num_runs throughput = args.batch_size / latency if args.rank == 0: logger.info(f"Batch Size: {args.batch_size}") logger.info(f"Sequence Length: {args.sequence_length}") - logger.info(f"Latency: {latency:.4f} s") - logger.info(f"Throughput: {throughput:.4f} tps") + logger.info(f"Latency: {latency} s") + logger.info(f"Throughput: {throughput} tps") return diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index ab3b4e115ade0..0c19ddee3ac06 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -7,7 +7,7 @@ import onnx import torch -from convert_generation import replace_mha_with_gqa +from onnxruntime.transformers.convert_generation import replace_mha_with_gqa from dist_settings import barrier, get_rank, get_size, init_dist from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check @@ -273,6 +273,7 @@ def run_torchscript_separate_export( # Avoid use system temp dir to avoid overflood on hard disk as 70b model is very large. # Use temp folder per rank to avoid race condition here. 
+ temp_dir = f"./temp_past_{rank}" temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index fd4e4024f0891..83a3facc705f6 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -88,8 +88,6 @@ def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, input def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM): inputs = get_inputs(args, config) - logger.debug(f"torch input: {inputs}") - # Run inference with PyTorch if args.execution_provider != "cpu": torch.cuda.synchronize() @@ -113,8 +111,6 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama device_id=int(args.rank), ) - logger.debug(f"ORT input: {inputs}") - ep = f"{args.execution_provider.upper()}ExecutionProvider" if ep == "CUDAExecutionProvider": ep = (ep, {"device_id": args.rank}) @@ -128,10 +124,10 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama if args.execution_provider != "cpu": io_binding = add_io_bindings(args, ort_model, inputs) - torch.cuda.synchronize() + io_binding.synchronize_inputs() start_time = time.time() ort_model.run_with_iobinding(io_binding) - torch.cuda.synchronize() + io_binding.synchronize_outputs() end_time = time.time() ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt new file mode 100644 index 0000000000000..572cfdb71be4a --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt @@ -0,0 +1,4 @@ +-r requirements.txt +git+https://github.com/frankdongms/transformers.git@frdong/shard_llama +mpi4py +psutil \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index 5500bc0bdf2c1..4210f36982aef 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -3,6 +3,4 @@ transformers>=4.33.2 torch>=2.2.0.dev20230920 onnx>=1.14.0 datasets>=2.8.0 -protobuf==3.20.2 -mpi4py -psutil \ No newline at end of file +protobuf==3.20.2 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/run.sh b/onnxruntime/python/tools/transformers/models/llama/run_70b_model.sh similarity index 84% rename from onnxruntime/python/tools/transformers/models/llama/run.sh rename to onnxruntime/python/tools/transformers/models/llama/run_70b_model.sh index b980601c02252..1dfee5427ba56 100644 --- a/onnxruntime/python/tools/transformers/models/llama/run.sh +++ b/onnxruntime/python/tools/transformers/models/llama/run_70b_model.sh @@ -7,7 +7,7 @@ MPI="mpirun --allow-run-as-root --tag-output --npernode $NUM_GPUS --bind-to numa -x MIOPEN_FIND_MODE=1" -CMD="$MPI bash single_run.sh ${@:2}" +CMD="$MPI bash single_run_70b_model.sh ${@:2}" set -x $CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/single_run.sh b/onnxruntime/python/tools/transformers/models/llama/single_run_70b_model.sh similarity index 100% rename from 
onnxruntime/python/tools/transformers/models/llama/single_run.sh rename to onnxruntime/python/tools/transformers/models/llama/single_run_70b_model.sh diff --git a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py index 58a088e5716c2..373ad86ced1a7 100644 --- a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py +++ b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py @@ -302,16 +302,11 @@ def create_k_path_hf(self, model_type: str): outputs=["present_key"], axis=2, ) - k_nodes.append(concat_k_node) - shape_k1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k1_out"], name="Shape_k1") - k_nodes.append(shape_k1) shape_k2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k2_out"], name="Shape_k2") - k_nodes.append(shape_k2) shape_k3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k3_out"], name="Shape_k3") - k_nodes.append(shape_k3) shape_k4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k4_out"], name="Shape_k4") - k_nodes.append(shape_k4) + gather_k_1 = helper.make_node( "Gather", inputs=["shape_k1_out", "one"], @@ -319,7 +314,6 @@ def create_k_path_hf(self, model_type: str): name="Gather_k_1", axis=0, ) - k_nodes.append(gather_k_1) gather_k_2 = helper.make_node( "Gather", inputs=["shape_k2_out", "one"], @@ -327,7 +321,6 @@ def create_k_path_hf(self, model_type: str): name="Gather_k_2", axis=0, ) - k_nodes.append(gather_k_2) gather_k_3 = helper.make_node( "Gather", inputs=["shape_k3_out", "one"], @@ -335,7 +328,6 @@ def create_k_path_hf(self, model_type: str): name="Gather_k_3", axis=0, ) - k_nodes.append(gather_k_3) gather_k_4 = helper.make_node( "Gather", inputs=["shape_k4_out", "one"], @@ -343,42 +335,38 @@ def create_k_path_hf(self, model_type: str): name="Gather_k_4", axis=0, ) - k_nodes.append(gather_k_4) + unsqueeze_k_1 = helper.make_node( "Unsqueeze", inputs=["present_value", "zero"], outputs=["unsqueeze_k1_out"], name="Unsqueeze_k1", ) - k_nodes.append(unsqueeze_k_1) unsqueeze_k_2 = helper.make_node( "Unsqueeze", inputs=["gather_k1_out", "zero"], outputs=["unsqueeze_k2_out"], name="Unsqueeze_k2", ) - k_nodes.append(unsqueeze_k_2) unsqueeze_k_3 = helper.make_node( "Unsqueeze", inputs=["gather_k2_out", "zero"], outputs=["unsqueeze_k3_out"], name="Unsqueeze_k3", ) - k_nodes.append(unsqueeze_k_3) unsqueeze_k_4 = helper.make_node( "Unsqueeze", inputs=["gather_k3_out", "zero"], outputs=["unsqueeze_k4_out"], name="Unsqueeze_k4", ) - k_nodes.append(unsqueeze_k_4) unsqueeze_k_5 = helper.make_node( "Unsqueeze", inputs=["gather_k4_out", "zero"], outputs=["unsqueeze_k5_out"], name="Unsqueeze_k5", ) - k_nodes.append(unsqueeze_k_5) + concat_k_2 = helper.make_node( "Concat", inputs=["unsqueeze_k2_out", "unsqueeze_k3_out", "One", "unsqueeze_k4_out", "unsqueeze_k5_out"], @@ -386,79 +374,67 @@ def create_k_path_hf(self, model_type: str): name="Concat_k2", axis=0, ) - k_nodes.append(concat_k_2) reshape_k_2 = helper.make_node( "Reshape", inputs=["concat_k2_ouot", "One"], outputs=["reshape_k2_out"], name="Reshape_k_2", ) - k_nodes.append(reshape_k_2) shape_k5 = helper.make_node("Shape", inputs=["reshape_k2_out"], outputs=["shape_k5_out"], name="Shape_k5") - k_nodes.append(shape_k5) constant_of_shape_k_1 = helper.make_node( "ConstantOfShape", inputs=["shape_k5_out"], outputs=["constant_of_shape_k1_out"], name="ConstantOfShape_k1", ) - k_nodes.append(constant_of_shape_k_1) mul_k_1 = helper.make_node( "Mul", 
inputs=["constant_of_shape_k1_out", "One"], outputs=["mul_k1_out"], name="mul_k1", ) - k_nodes.append(mul_k_1) equal_k_1 = helper.make_node( "Equal", inputs=["reshape_k2_out", "mul_k1_out"], outputs=["equal_k_1_out"], name="equal_k1", ) - k_nodes.append(equal_k_1) where_k_1 = helper.make_node( "Where", inputs=["equal_k_1_out", "constant_of_shape_k1_out", "reshape_k2_out"], outputs=["where_k_1_out"], name="where_k1", ) - k_nodes.append(where_k_1) unsqueeze_k_6 = helper.make_node( "Unsqueeze", inputs=["gather_k1_out", "zero"], outputs=["unsqueeze_k6_out"], name="Unsqueeze_k6", ) - k_nodes.append(unsqueeze_k_6) mul_k_2 = helper.make_node( "Mul", inputs=["gather_k2_out", "One"], outputs=["mul_k2_out"], name="mul_k2", ) - k_nodes.append(mul_k_2) unsqueeze_k_7 = helper.make_node( "Unsqueeze", inputs=["mul_k2_out", "zero"], outputs=["unsqueeze_k7_out"], name="Unsqueeze_k7", ) - k_nodes.append(unsqueeze_k_7) unsqueeze_k_8 = helper.make_node( "Unsqueeze", inputs=["gather_k3_out", "zero"], outputs=["unsqueeze_k8_out"], name="Unsqueeze_k8", ) - k_nodes.append(unsqueeze_k_8) unsqueeze_k_9 = helper.make_node( "Unsqueeze", inputs=["gather_k4_out", "zero"], outputs=["unsqueeze_k9_out"], name="Unsqueeze_k9", ) - k_nodes.append(unsqueeze_k_9) concat_k_3 = helper.make_node( "Concat", inputs=["unsqueeze_k6_out", "unsqueeze_k7_out", "unsqueeze_k8_out", "unsqueeze_k9_out"], @@ -466,22 +442,18 @@ def create_k_path_hf(self, model_type: str): name="Concat_k3", axis=0, ) - k_nodes.append(concat_k_3) expand_k_1 = helper.make_node( "Expand", inputs=["unsqueeze_k1_out", "where_k_1_out"], outputs=["expand_k1_out"], name="expand_k1", ) - k_nodes.append(expand_k_1) reshape_k_3 = helper.make_node( "Reshape", inputs=["expand_k1_out", "concat_k3_out"], outputs=["reshape_k3_out"], name="Reshape_k_3", ) - k_nodes.append(reshape_k_3) - transpose_k_2_node = helper.make_node( "Transpose", inputs=["reshape_k3_out"], @@ -489,7 +461,41 @@ def create_k_path_hf(self, model_type: str): name="Transpose_k_2", perm=[0, 1, 3, 2], ) - return k_nodes + [transpose_k_2_node] # noqa: RUF005 + + k_nodes_for_70b_model = [ + concat_k_node, + shape_k1, + shape_k2, + shape_k3, + shape_k4, + gather_k_1, + gather_k_2, + gather_k_3, + gather_k_4, + unsqueeze_k_1, + unsqueeze_k_2, + unsqueeze_k_3, + unsqueeze_k_4, + unsqueeze_k_5, + concat_k_2, + reshape_k_2, + shape_k5, + constant_of_shape_k_1, + mul_k_1, + equal_k_1, + where_k_1, + unsqueeze_k_6, + mul_k_2, + unsqueeze_k_7, + unsqueeze_k_8, + unsqueeze_k_9, + concat_k_3, + expand_k_1, + reshape_k_3, + transpose_k_2_node, + ] + k_nodes.extend(k_nodes_for_70b_model) + return k_nodes else: if model_type in {"past", "merged"}: concat_k_node = helper.make_node( @@ -714,13 +720,9 @@ def create_v_path(self, model_type: str): return v_nodes + [concat_v_node] # noqa: RUF005 shape_v1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_1_out"], name="Shape_v1") - v_nodes.append(shape_v1) shape_v2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_2_out"], name="Shape_v2") - v_nodes.append(shape_v2) shape_v3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_3_out"], name="Shape_v3") - v_nodes.append(shape_v3) shape_v4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_4_out"], name="Shape_v4") - v_nodes.append(shape_v4) gather_v_1 = helper.make_node( "Gather", inputs=["shape_1_out", "one"], @@ -728,7 +730,6 @@ def create_v_path(self, model_type: str): name="Gather_v1", axis=0, ) - v_nodes.append(gather_v_1) gather_v_2 = 
helper.make_node( "Gather", inputs=["shape_2_out", "one"], @@ -736,7 +737,6 @@ def create_v_path(self, model_type: str): name="Gather_v2", axis=0, ) - v_nodes.append(gather_v_2) gather_v_3 = helper.make_node( "Gather", inputs=["shape_3_out", "one"], @@ -744,7 +744,6 @@ def create_v_path(self, model_type: str): name="Gather_v3", axis=0, ) - v_nodes.append(gather_v_3) gather_v_4 = helper.make_node( "Gather", inputs=["shape_4_out", "one"], @@ -752,42 +751,36 @@ def create_v_path(self, model_type: str): name="Gather_v4", axis=0, ) - v_nodes.append(gather_v_4) unsqueeze_v_1 = helper.make_node( "Unsqueeze", inputs=["present_value", "zero"], outputs=["unsqueeze_v1_out"], name="Unsqueeze_v1", ) - v_nodes.append(unsqueeze_v_1) unsqueeze_v_2 = helper.make_node( "Unsqueeze", inputs=["gather_1_out", "zero"], outputs=["unsqueeze_v2_out"], name="Unsqueeze_v2", ) - v_nodes.append(unsqueeze_v_2) unsqueeze_v_3 = helper.make_node( "Unsqueeze", inputs=["gather_2_out", "zero"], outputs=["unsqueeze_v3_out"], name="Unsqueeze_v3", ) - v_nodes.append(unsqueeze_v_3) unsqueeze_v_4 = helper.make_node( "Unsqueeze", inputs=["gather_3_out", "zero"], outputs=["unsqueeze_v4_out"], name="Unsqueeze_v4", ) - v_nodes.append(unsqueeze_v_4) unsqueeze_v_5 = helper.make_node( "Unsqueeze", inputs=["gather_4_out", "zero"], outputs=["unsqueeze_v5_out"], name="Unsqueeze_v5", ) - v_nodes.append(unsqueeze_v_5) concat_v_2 = helper.make_node( "Concat", inputs=["unsqueeze_v2_out", "unsqueeze_v3_out", "One", "unsqueeze_v4_out", "unsqueeze_v5_out"], @@ -795,79 +788,67 @@ def create_v_path(self, model_type: str): name="Concat_v2", axis=0, ) - v_nodes.append(concat_v_2) reshape_v_2 = helper.make_node( "Reshape", inputs=["concat_v2_ouot", "One"], outputs=["reshape_v2_out"], name="Reshape_v2", ) - v_nodes.append(reshape_v_2) shape_v5 = helper.make_node("Shape", inputs=["reshape_v2_out"], outputs=["shape_5_out"], name="Shape_v5") - v_nodes.append(shape_v5) constant_of_shape_v_1 = helper.make_node( "ConstantOfShape", inputs=["shape_5_out"], outputs=["constant_of_shape_v1_out"], name="ConstantOfShape_v1", ) - v_nodes.append(constant_of_shape_v_1) mul_v_1 = helper.make_node( "Mul", inputs=["constant_of_shape_v1_out", "One"], outputs=["mul_v1_out"], name="mul_v1", ) - v_nodes.append(mul_v_1) equal_v_1 = helper.make_node( "Equal", inputs=["reshape_v2_out", "mul_v1_out"], outputs=["equal_v_1_out"], name="equal_v1", ) - v_nodes.append(equal_v_1) where_v_1 = helper.make_node( "Where", inputs=["equal_v_1_out", "constant_of_shape_v1_out", "reshape_v2_out"], outputs=["where_v_1_out"], name="where_v1", ) - v_nodes.append(where_v_1) unsqueeze_v_6 = helper.make_node( "Unsqueeze", inputs=["gather_1_out", "zero"], outputs=["unsqueeze_v6_out"], name="Unsqueeze_v6", ) - v_nodes.append(unsqueeze_v_6) mul_v_2 = helper.make_node( "Mul", inputs=["gather_2_out", "One"], outputs=["mul_v2_out"], name="mul_v2", ) - v_nodes.append(mul_v_2) unsqueeze_v_7 = helper.make_node( "Unsqueeze", inputs=["mul_v2_out", "zero"], outputs=["unsqueeze_v7_out"], name="Unsqueeze_v7", ) - v_nodes.append(unsqueeze_v_7) unsqueeze_v_8 = helper.make_node( "Unsqueeze", inputs=["gather_3_out", "zero"], outputs=["unsqueeze_v8_out"], name="Unsqueeze_v8", ) - v_nodes.append(unsqueeze_v_8) unsqueeze_v_9 = helper.make_node( "Unsqueeze", inputs=["gather_4_out", "zero"], outputs=["unsqueeze_v9_out"], name="Unsqueeze_v9", ) - v_nodes.append(unsqueeze_v_9) concat_v_3 = helper.make_node( "Concat", inputs=["unsqueeze_v6_out", "unsqueeze_v7_out", "unsqueeze_v8_out", "unsqueeze_v9_out"], @@ -875,23 +856,53 
@@ def create_v_path(self, model_type: str): name="Concat_v3", axis=0, ) - v_nodes.append(concat_v_3) expand_v_1 = helper.make_node( "Expand", inputs=["unsqueeze_v1_out", "where_v_1_out"], outputs=["expand_v1_out"], name="expand_v1", ) - v_nodes.append(expand_v_1) reshape_v_3 = helper.make_node( "Reshape", inputs=["expand_v1_out", "concat_v3_out"], outputs=["reshape_v3_out"], name="Reshape_v3", ) - v_nodes.append(reshape_v_3) - return v_nodes + [concat_v_node] # noqa: RUF005 + v_nodes_for_70b_model = [ + concat_v_node, + shape_v1, + shape_v2, + shape_v3, + shape_v4, + gather_v_1, + gather_v_2, + gather_v_3, + gather_v_4, + unsqueeze_v_1, + unsqueeze_v_2, + unsqueeze_v_3, + unsqueeze_v_4, + unsqueeze_v_5, + concat_v_2, + reshape_v_2, + shape_v5, + constant_of_shape_v_1, + mul_v_1, + equal_v_1, + where_v_1, + unsqueeze_v_6, + mul_v_2, + unsqueeze_v_7, + unsqueeze_v_8, + unsqueeze_v_9, + concat_v_3, + expand_v_1, + reshape_v_3, + ] + v_nodes.extend(v_nodes_for_70b_model) + + return v_nodes # Create extra nodes for `position_ids` unsqueeze_v_node = helper.make_node( From c4649e8a0ec002008b7f48569496f19e105da359 Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Tue, 31 Oct 2023 06:51:06 +0000 Subject: [PATCH 05/13] handle export device separately for different model --- .../python/tools/transformers/models/llama/README.md | 8 +++++++- .../transformers/models/llama/convert_to_onnx.py | 12 ++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 98d5f419ef26c..c52abae1fdd99 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -86,8 +86,14 @@ Export Saved Model on Disk # From source: $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b +# From source using first gpu: +$ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b + # From wheel: $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b + +# From wheel using second gpu: +$ CUDA_VISIBLE_DEVICES=1 python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b ``` Export for FP32 CUDA @@ -160,7 +166,7 @@ Export LLaMA-2 70B sharded model into 4 partitions # From source: # 1. Install necessary packages from requirements-70b-model.txt # 2. Build Onnx Runtime from source with NCCL enabled. Here is an sample command: ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ -# 3. Shard and export the LLaMA-2 70B model. Here is an example command: CUDA_VISIBLE_DEVICES=0,1,2,3 bash run_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda +# 3. Shard and export the LLaMA-2 70B model, we will need at least 4 A100 GPUs to do shard Pytorch model and export each shardding to ONNX. 
Here is an example command: CUDA_VISIBLE_DEVICES=0,1,2,3 bash run_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda ``` ## Benchmark LLaMA-2 diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 0c19ddee3ac06..443ede00e46ef 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -198,7 +198,11 @@ def run_torchscript_separate_export( ): # Dummy values for export batch_size, sequence_length = 2, 8 - device = llama.device + + # set device used to export model + # for llama-70b we will use current gpus to speed up export process (we need at least 4 A100 GPUs) + # for other models, we will use CPU to make sure we have enough memory to do export + device = llama.device if args.model_name == "meta-llama/Llama-2-70b-hf" else torch.device("cpu") # Export decoder_model.onnx decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length) @@ -312,7 +316,11 @@ def run_torchscript_merged_export( ): # Dummy values for export batch_size, sequence_length, past_sequence_length = 2, 8, 0 - device = llama.device + + # set device used to export model + # for llama-70b we will use current gpus to speed up export process (we need at least 4 A100 GPUs) + # for other models, we will use CPU to make sure we have enough memory to do export + device = llama.device if args.model_name == "meta-llama/Llama-2-70b-hf" else torch.device("cpu") # Export decoder_merged_model.onnx decoder_merged_inputs = get_merged_sample_with_past_kv_inputs( From f393cb2a8611a066b3a0bd0cc6e084c18685d3da Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Tue, 31 Oct 2023 07:17:09 +0000 Subject: [PATCH 06/13] lint fix --- .../python/tools/transformers/models/llama/convert_to_onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 443ede00e46ef..597935d17e113 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -7,7 +7,6 @@ import onnx import torch -from onnxruntime.transformers.convert_generation import replace_mha_with_gqa from dist_settings import barrier, get_rank, get_size, init_dist from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check @@ -20,6 +19,7 @@ from onnxruntime import quantization as ort_quantization from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger +from onnxruntime.transformers.convert_generation import replace_mha_with_gqa logger = logging.getLogger("") init_dist() @@ -316,7 +316,7 @@ def run_torchscript_merged_export( ): # Dummy values for export batch_size, sequence_length, past_sequence_length = 2, 8, 0 - + # set device used to export model # for llama-70b we will use current gpus to speed up export process (we need at least 4 A100 GPUs) # for other models, we will use CPU to make sure we have enough memory to do export From fdd4c5fd60c533f1d4f1843159b33f14845b92cf Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Tue, 31 Oct 2023 
08:23:45 +0000 Subject: [PATCH 07/13] take comments --- .../tools/transformers/models/llama/README.md | 26 ++++++++++++------- .../transformers/models/llama/benchmark.py | 1 + .../models/llama/convert_to_onnx.py | 5 ++-- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index c52abae1fdd99..8470e2801ed8c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -11,7 +11,7 @@ Please note the package versions needed for using LLaMA-2 in the `requirements.t - `requirements-quant.txt` - For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor) - `requirements-70b-model.txt` - - For run LLaMA-2 70B model in multiple GPUs + - For running the LLaMA-2 70B model on multiple GPUs - `requirements.txt` - Package versions needed in each of the above files @@ -81,19 +81,20 @@ model.save_pretrained(name.split("/")[-1] + "-onnx") Here are some additional examples for exporting LLaMA-2. -Export Saved Model on Disk +Export Model with Different GPU Device Ids ``` +# From source using first GPU: +$ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b + +# From wheel using second GPU: +$ CUDA_VISIBLE_DEVICES=1 python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b +Export Saved Model on Disk + # From source: $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b -# From source using first gpu: -$ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b - # From wheel: $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b - -# From wheel using second gpu: -$ CUDA_VISIBLE_DEVICES=1 python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b ``` Export for FP32 CUDA @@ -165,8 +166,13 @@ Export LLaMA-2 70B sharded model into 4 partitions ``` # From source: # 1. Install necessary packages from requirements-70b-model.txt -# 2. Build Onnx Runtime from source with NCCL enabled. Here is an sample command: ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ -# 3. Shard and export the LLaMA-2 70B model, we will need at least 4 A100 GPUs to do shard Pytorch model and export each shardding to ONNX. Here is an example command: CUDA_VISIBLE_DEVICES=0,1,2,3 bash run_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda + +# 2. Build ONNX Runtime from source with NCCL enabled. Here is a sample command: +$ ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ + +# 3. Shard and export the LLaMA-2 70B model. 
You will need at least 4 A100 GPUs to shard the PyTorch model and export each shard to ONNX. Here is an example command: +$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash run_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda + ``` ## Benchmark LLaMA-2 diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index d2a2c8bccfbc0..c0e9a2955efd1 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -490,6 +490,7 @@ def prepare_ort_inputs(inputs): else: io_binding.bind_output(name, device_type=args.device, device_id=args.rank) + setattr(args, "io_binding", io_binding) # noqa: B010 return io_binding return inputs diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 597935d17e113..febf0c20115a1 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -822,7 +822,7 @@ def main(): # Change precision of exported models from FP32 if args.precision == Precision.FLOAT16: - _ = convert_to_float16(args, l_config, old_paths, rank, world_size) + new_paths = convert_to_float16(args, l_config, old_paths, rank, world_size) elif args.precision == Precision.INT8: decoder_model_int8_path = os.path.join( @@ -909,7 +909,8 @@ def main(): # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models args.precision = ( "fp32" - if args.precision in {"int8", "fp32"} or (args.precision == Precision.INT4 and args.execution_provider == "cpu") + if args.precision in {Precision.INT8, Precision.FLOAT32} + or (args.precision == Precision.INT4 and args.execution_provider == "cpu") else "fp16" ) From 79e185dc7fc738e5996de4b2cc96ae71dee9a91e Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Tue, 31 Oct 2023 18:53:38 +0000 Subject: [PATCH 08/13] take comments --- .../tools/transformers/models/llama/README.md | 2 +- .../transformers/models/llama/convert_to_onnx.py | 14 +++++++------- .../transformers/models/llama/llama_parity.py | 2 +- .../tools/transformers/models/llama/llama_torch.py | 4 ++-- .../transformers/models/llama/run_70b_model.sh | 3 +-- .../models/llama/single_run_70b_model.sh | 5 ----- 6 files changed, 12 insertions(+), 18 deletions(-) delete mode 100644 onnxruntime/python/tools/transformers/models/llama/single_run_70b_model.sh diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 8470e2801ed8c..5e42ae9b5c890 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -170,7 +170,7 @@ Export LLaMA-2 70B sharded model into 4 partitions # 2. Build ONNX Runtime from source with NCCL enabled. Here is a sample command: $ ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ -# 3. Shard and export the LLaMA-2 70B model. You will need at least 4 A100 GPUs to shard the PyTorch model and export each shard to ONNX. Here is an example command: +# 3. 
Shard and export the LLaMA-2 70B model. LLaMA-2 70B has 70 billion parameters, with fp16 you will need at least 140GB GPU memory in total to load model weight, so you will need at least 4 40GB A100 GPUs or 2 80GB A100 GPUs to shard the PyTorch model and export each shard to ONNX. Here is an example command: $ CUDA_VISIBLE_DEVICES=0,1,2,3 bash run_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda ``` diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index febf0c20115a1..cdb0ec69c7b9c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -200,9 +200,9 @@ def run_torchscript_separate_export( batch_size, sequence_length = 2, 8 # set device used to export model - # for llama-70b we will use current gpus to speed up export process (we need at least 4 A100 GPUs) + # for llama-2-70b we will use current gpus to speed up export process (we need at least 4 A100 GPUs) # for other models, we will use CPU to make sure we have enough memory to do export - device = llama.device if args.model_name == "meta-llama/Llama-2-70b-hf" else torch.device("cpu") + device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") # Export decoder_model.onnx decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length) @@ -216,7 +216,7 @@ def run_torchscript_separate_export( ] dynamic_axes = get_model_dynamic_axes(input_names, output_names) - # Avoid use system temp dir to avoid overflood on hard disk as 70b model is very large. + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. # Use temp folder per rank to avoid race condition here. temp_dir = f"./temp_{rank}" _prepare_dir(temp_dir) @@ -275,7 +275,7 @@ def run_torchscript_separate_export( ] dynamic_axes = get_model_with_past_kv_dynamic_axes(input_names, output_names) - # Avoid use system temp dir to avoid overflood on hard disk as 70b model is very large. + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. # Use temp folder per rank to avoid race condition here. temp_dir = f"./temp_past_{rank}" temp_path = os.path.join(temp_dir, "temp.onnx") @@ -318,9 +318,9 @@ def run_torchscript_merged_export( batch_size, sequence_length, past_sequence_length = 2, 8, 0 # set device used to export model - # for llama-70b we will use current gpus to speed up export process (we need at least 4 A100 GPUs) + # for llama-2-70b we will use current gpus to speed up export process (we need at least 4 A100 GPUs) # for other models, we will use CPU to make sure we have enough memory to do export - device = llama.device if args.model_name == "meta-llama/Llama-2-70b-hf" else torch.device("cpu") + device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") # Export decoder_merged_model.onnx decoder_merged_inputs = get_merged_sample_with_past_kv_inputs( @@ -350,7 +350,7 @@ def run_torchscript_merged_export( ] dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names) - # Avoid use system temp dir to avoid overflood on hard disk as 70b model is very large. + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. # Use temp folder per rank to avoid race condition here. 
temp_dir = f"./temp_{rank}" _prepare_dir(temp_dir) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 83a3facc705f6..e49453019f95e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -6,7 +6,7 @@ import numpy as np import torch -from benchmark_helper import setup_logger +from onnxruntime.transformers.benchmark_helper import setup_logger from dist_settings import get_rank, get_size from llama_inputs import ( convert_inputs_for_ort, diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py index 999f2f932d561..ae0a1b02c2b5d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py @@ -17,8 +17,8 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, if not os.path.exists(args.cache_dir): os.makedirs(args.cache_dir) - for i in range(world_size): - if i == rank: + for i in range(world_size // 2): + if i == rank % (world_size // 2): l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir) l_config.use_cache = True llama = LlamaForCausalLM.from_pretrained( diff --git a/onnxruntime/python/tools/transformers/models/llama/run_70b_model.sh b/onnxruntime/python/tools/transformers/models/llama/run_70b_model.sh index 1dfee5427ba56..637d15c10e0c7 100644 --- a/onnxruntime/python/tools/transformers/models/llama/run_70b_model.sh +++ b/onnxruntime/python/tools/transformers/models/llama/run_70b_model.sh @@ -7,7 +7,6 @@ MPI="mpirun --allow-run-as-root --tag-output --npernode $NUM_GPUS --bind-to numa -x MIOPEN_FIND_MODE=1" -CMD="$MPI bash single_run_70b_model.sh ${@:2}" +CMD="$MPI python convert_to_onnx.py ${@:2}" -set -x $CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/single_run_70b_model.sh b/onnxruntime/python/tools/transformers/models/llama/single_run_70b_model.sh deleted file mode 100644 index 82f34f79b7bf4..0000000000000 --- a/onnxruntime/python/tools/transformers/models/llama/single_run_70b_model.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -CMD="python convert_to_onnx.py ${@}" - -$CMD \ No newline at end of file From 0093725a3cff2c67dd2ed701f00d7b4da14b443e Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Tue, 31 Oct 2023 19:04:09 +0000 Subject: [PATCH 09/13] fix comments --- .../python/tools/transformers/models/llama/llama_parity.py | 2 +- .../python/tools/transformers/models/llama/llama_torch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index e49453019f95e..a579aecd6eb45 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -6,7 +6,6 @@ import numpy as np import torch -from onnxruntime.transformers.benchmark_helper import setup_logger from dist_settings import get_rank, get_size from llama_inputs import ( convert_inputs_for_ort, @@ -18,6 +17,7 @@ from transformers import LlamaConfig, LlamaForCausalLM import onnxruntime as ort +from onnxruntime.transformers.benchmark_helper import setup_logger logger = logging.getLogger("") diff 
--git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py index ae0a1b02c2b5d..78b7b9b1339dd 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py @@ -15,7 +15,7 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, barrier() if not os.path.exists(args.cache_dir): - os.makedirs(args.cache_dir) + os.makedirs(args.cache_dir, exist_ok=True) for i in range(world_size // 2): if i == rank % (world_size // 2): From dcddb283b071580d174f2801e7e371ecd8b0c15c Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Thu, 2 Nov 2023 06:00:34 +0000 Subject: [PATCH 10/13] fix benchmark_all and change dist_settings mpi is only required for multi-gpus --- .../models/llama/benchmark_all.py | 11 +------ .../models/llama/dist_settings.py | 30 +++++++++---------- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py index 951b2549368f7..b35a5e27f9ea3 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py @@ -247,6 +247,7 @@ def main(): torch.backends.cudnn.benchmark = True all_results = [] + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id) # Benchmark PyTorch without torch.compile if args.hf_pt_eager: @@ -266,8 +267,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -298,8 +297,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -332,8 +329,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -366,8 +361,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -399,8 +392,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", diff --git a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py index e0e03c70e8618..13591cdc6538d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py +++ b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py @@ -2,41 +2,41 @@ import torch import torch.distributed as dist -from mpi4py import MPI +comm = None + def init_dist(): if "LOCAL_RANK" in os.environ: local_rank = int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) + + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank) elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: + from mpi4py import MPI + comm = MPI.COMM_WORLD + local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)) rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) - else: - local_rank = 0 - rank = 0 - world_size = 1 - - dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", 
world_size=world_size, rank=rank) - device = torch.device(local_rank) - return device - - -comm = MPI.COMM_WORLD + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank) + else: + # don't need to do init for single process + pass def get_rank(): - return comm.Get_rank() + return comm.Get_rank() if comm is not None else 0 def get_size(): - return comm.Get_size() + return comm.Get_size() if comm is not None else 1 def barrier(): - comm.Barrier() + if comm is not None: + comm.Barrier() def print_out(*args): From aa8ffb68a2d41de4fa16ba2703dc3989ca340d82 Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Thu, 2 Nov 2023 06:08:09 +0000 Subject: [PATCH 11/13] fix lint error --- .../tools/transformers/models/llama/dist_settings.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py index 13591cdc6538d..50b0669d6d83a 100644 --- a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py +++ b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py @@ -1,23 +1,23 @@ import os -import torch import torch.distributed as dist - comm = None + def init_dist(): if "LOCAL_RANK" in os.environ: - local_rank = int(os.environ["LOCAL_RANK"]) + int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank) elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: from mpi4py import MPI - comm = MPI.COMM_WORLD - local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)) + comm = MPI.COMM_WORLD # noqa: F841 + + int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)) rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) @@ -26,6 +26,7 @@ def init_dist(): # don't need to do init for single process pass + def get_rank(): return comm.Get_rank() if comm is not None else 0 From 4c567ca01826adb2dd3b1e677d875dc1ee060c31 Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Thu, 2 Nov 2023 06:37:16 +0000 Subject: [PATCH 12/13] fix readme --- onnxruntime/python/tools/transformers/models/llama/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 3e107a2636f07..07094201b574a 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -91,6 +91,7 @@ $ CUDA_VISIBLE_DEVICES=1 python3 -m onnxruntime.transformers.models.llama.conver ``` Export Saved Model on Disk +``` # From source: $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b From 31c011a11eae6d562ab5f9b8bf15efa9dfe88120 Mon Sep 17 00:00:00 2001 From: Frank Dong Date: Thu, 2 Nov 2023 06:41:35 +0000 Subject: [PATCH 13/13] remove extra line --- onnxruntime/python/tools/transformers/models/llama/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 07094201b574a..1bb6940d1cd74 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -92,7 +92,6 @@ $ 
CUDA_VISIBLE_DEVICES=1 python3 -m onnxruntime.transformers.models.llama.conver Export Saved Model on Disk ``` - # From source: $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b
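
The `replace_mha_with_gqa` change in the patches above divides both `num_heads` and `kv_num_heads` by `world_size`, so each rank's GroupQueryAttention node only carries its local share of attention heads. A minimal sketch of that per-rank arithmetic, assuming the published LLaMA-2 70B head counts (64 query heads, 8 key/value heads) and the 4-way shard used in the README example — the concrete numbers are illustrative and not taken from the patch itself:

```python
# Per-rank head split performed by replace_mha_with_gqa (sketch).
# Assumed head counts: LLaMA-2 70B uses 64 query heads and 8 KV heads.
num_heads = 64        # query heads in the full model
kv_num_heads = 8      # key/value heads in the full model (grouped-query attention)
world_size = 4        # number of GPU shards, as in the run_70b_model.sh example

# Head counts must divide evenly across ranks for the integer division to be lossless.
assert num_heads % world_size == 0 and kv_num_heads % world_size == 0

per_rank_num_heads = num_heads // world_size         # 16 query heads per rank
per_rank_kv_num_heads = kv_num_heads // world_size   # 2 KV heads per rank
print(per_rank_num_heads, per_rank_kv_num_heads)     # -> 16 2
```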
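
The `benchmark.py` change replaces a single timer wrapped around the whole benchmark loop with per-run timing: the device is synchronized before the clock starts and again after the run completes, and the per-run durations are accumulated into an average latency. A small, backend-agnostic sketch of that pattern; the `sync_inputs`/`sync_outputs` callables here stand in for `io_binding.synchronize_inputs()`/`synchronize_outputs()` or `torch.cuda.synchronize()`, depending on the execution path:

```python
import time


def measure_latency(run_once, sync_inputs, sync_outputs, num_runs, warmup_runs=5):
    """Average per-run latency with explicit synchronization around each timed run (sketch)."""
    for _ in range(warmup_runs):
        sync_inputs()
        run_once()
        sync_outputs()

    total_time = 0.0
    for _ in range(num_runs):
        sync_inputs()      # inputs are on-device before the clock starts
        start_time = time.time()
        run_once()
        sync_outputs()     # device work is finished before the clock stops
        end_time = time.time()
        total_time += end_time - start_time

    return total_time / num_runs


# CPU-only usage, where synchronization is a no-op:
latency = measure_latency(
    run_once=lambda: sum(range(100_000)),
    sync_inputs=lambda: None,
    sync_outputs=lambda: None,
    num_runs=10,
)
print(f"Latency: {latency} s")
```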
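
Patch 10 makes `mpi4py` an optional dependency: `init_dist` only imports it when an MPI launcher is detected, and a plain single-process run skips process-group initialization entirely. A sketch of the launcher detection it relies on — the environment variable names are the ones checked in `init_dist`, while the function name and return shape here are illustrative:

```python
import os


def detect_launcher():
    """Return (launcher, rank, world_size) from launcher-provided environment variables (sketch)."""
    if "LOCAL_RANK" in os.environ:
        # torchrun / torch.distributed sets LOCAL_RANK, RANK and WORLD_SIZE
        return "torchrun", int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"])
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        # Open MPI's mpirun (used by run_70b_model.sh) sets OMPI_COMM_WORLD_* variables
        return "mpirun", int(os.environ["OMPI_COMM_WORLD_RANK"]), int(os.environ["OMPI_COMM_WORLD_SIZE"])
    # No launcher detected: single-process conversion, no MPI or NCCL needed
    return "none", 0, 1


print(detect_launcher())
```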