From 3b4de7a56d4e992d69b6e425739379e43f7472b0 Mon Sep 17 00:00:00 2001 From: Your Date: Fri, 23 Feb 2024 23:52:21 +0000 Subject: [PATCH] update --- .../models/phi2/convert_to_onnx.py | 2 +- .../models/phi2/inference_example.py | 66 ++++++++++++------- 2 files changed, 45 insertions(+), 23 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py index 8e25cec724b68..85caafd1c59cd 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -519,7 +519,7 @@ def run_optimize_phi2_onnx( [p.start() for p in processes] [p.join() for p in processes] - if args.run_example: + if args.run_example or args.run_benchmark: from inference_example import run_phi2 if args.fp16_gpu_sm8x: diff --git a/onnxruntime/python/tools/transformers/models/phi2/inference_example.py b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py index 013911ec40872..173c4e37dd7e7 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/inference_example.py +++ b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py @@ -83,16 +83,20 @@ def get_initial_inputs_and_outputs(self, encodings_dict): if not self.use_traced_inputs else self.static_inputs_map[batch_size]["step"] ) + seqlens_k = ( torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32) if not self.use_traced_inputs else self.static_inputs_map[batch_size]["seqlens_k"] ) + cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32)) + total_seq_length = ( - torch.tensor([sequence_length], device=torch.device("cpu"), dtype=torch.int32) + torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32) if not self.use_traced_inputs else self.static_inputs_map[batch_size]["total_sequence_length"] ) + total_seq_length[0] = sequence_length inputs = { "input_ids": input_ids.contiguous(), @@ -191,8 +195,8 @@ def create_session( ): self.device_id = device_id sess_options = ort.SessionOptions() - # sess_options.log_verbosity_level = 0 - # sess_options.log_severity_level = 0 + sess_options.log_verbosity_level = 4 + sess_options.log_severity_level = 4 self.use_cuda_graph = use_cuda_graph ep = ( ("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph}) @@ -211,21 +215,27 @@ def create_session( self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) self.tokenizer.pad_token = "[PAD]" - def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation): + def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False): inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict) - print(inputs["input_ids"]) - all_token_ids = inputs["input_ids"].clone() batch_size, sequence_length = all_token_ids.shape current_length = sequence_length has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool) + if benchmark: + import time + + latency = [] + prompt_run = True while current_length < max_length: io_binding = self.apply_io_binding(self.sess, inputs, outputs) + if benchmark: + start = time.time() + io_binding.synchronize_inputs() if prompt_run: if self.use_cuda_graph: @@ -242,6 +252,10 @@ def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation): self.sess.run_with_iobinding(io_binding, self.ro) io_binding.synchronize_outputs() + if benchmark: + end = time.time() + latency.append(end - start) + # Sample with argmax (greedy search) next_token_logits = outputs["logits"][:, -1, :] next_tokens = torch.argmax(next_token_logits, dim=-1) @@ -273,9 +287,8 @@ def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation): inputs["step"] = self.static_inputs_map[batch_size]["step"] if self.use_cuda_graph: - inputs["seqlens_k"] = torch.tensor( - batch_size * [current_length - 1], device=self.device, dtype=torch.int32 - ) + previous_seqlens_k = inputs["seqlens_k"] + inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32) inputs["total_sequence_length"][0] = current_length if self.use_traced_inputs: cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"]) @@ -314,8 +327,14 @@ def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation): {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.clone().contiguous()} ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) - texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True) - return texts + if benchmark: + print( + f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}" + ) + print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms") + else: + texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True) + return texts def generate(self, prompt, max_length, cuda_graph_annotation): encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True) @@ -330,7 +349,11 @@ def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation): encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist() encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist() - return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation) + # Warm up run + self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False) + + # Benchmark run + self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True) def run_phi2( @@ -350,7 +373,7 @@ def simple_run(prompt): example_batch_size = len(prompt) if use_cuda_graph: generator.append_static_inputs(batch_size=example_batch_size) - texts = generator.generate(prompt, max_length=100, cuda_graph_annotation=example_batch_size) + texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size) for i in range(len(texts)): print("Prompt: ", prompt[i]) @@ -361,18 +384,17 @@ def simple_run(prompt): def print_prime(n): """ Print all primes between 1 and n - """''', "Give an example of using ONNX Runtime to run a model.", + """''' ] - simple_run([prompt[0]]) - # bugbug: batch 2 has different result - simple_run(prompt) - simple_run([prompt[1]]) + if not run_benchmark: + simple_run(prompt) + # Run simple benchmark. Time the decoder only. if run_benchmark: - token_num = 256 - for batch_size in [1, 2, 4, 8, 16]: + token_num = 32 + for batch_size in [2]: generator.append_static_inputs(batch_size) - for sequence_length in [16, 64, 256, 1024]: + for sequence_length in [16]: prompt_shape = (batch_size, sequence_length) - texts = generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size) + generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)