
Commit

update
wangyems committed Feb 23, 2024
1 parent 78cac04 commit 3b4de7a
Showing 2 changed files with 45 additions and 23 deletions.
@@ -519,7 +519,7 @@ def run_optimize_phi2_onnx(
[p.start() for p in processes]
[p.join() for p in processes]

if args.run_example:
if args.run_example or args.run_benchmark:
from inference_example import run_phi2

if args.fp16_gpu_sm8x:
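For orientation, a minimal sketch of the gate this hunk changes. The flag names below are assumptions inferred from the attributes used above (args.run_example, args.run_benchmark, args.fp16_gpu_sm8x), not taken from the commit:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--run_example", action="store_true")    # assumed spelling
parser.add_argument("--run_benchmark", action="store_true")  # assumed spelling
parser.add_argument("--fp16_gpu_sm8x", action="store_true")  # assumed spelling
args = parser.parse_args()

if args.run_example or args.run_benchmark:
    # Both the example path and the new benchmark path now reach run_phi2.
    from inference_example import run_phi2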
@@ -83,16 +83,20 @@ def get_initial_inputs_and_outputs(self, encodings_dict):
if not self.use_traced_inputs
else self.static_inputs_map[batch_size]["step"]
)

seqlens_k = (
torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32)
if not self.use_traced_inputs
else self.static_inputs_map[batch_size]["seqlens_k"]
)
cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32))

total_seq_length = (
torch.tensor([sequence_length], device=torch.device("cpu"), dtype=torch.int32)
torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32)
if not self.use_traced_inputs
else self.static_inputs_map[batch_size]["total_sequence_length"]
)
total_seq_length[0] = sequence_length

inputs = {
"input_ids": input_ids.contiguous(),
@@ -191,8 +195,8 @@ def create_session(
):
self.device_id = device_id
sess_options = ort.SessionOptions()
# sess_options.log_verbosity_level = 0
# sess_options.log_severity_level = 0
sess_options.log_verbosity_level = 4
sess_options.log_severity_level = 4
self.use_cuda_graph = use_cuda_graph
ep = (
("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph})
@@ -211,21 +215,27 @@ def create_session(
self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
self.tokenizer.pad_token = "[PAD]"

def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation):
def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False):
inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict)

print(inputs["input_ids"])

all_token_ids = inputs["input_ids"].clone()
batch_size, sequence_length = all_token_ids.shape

current_length = sequence_length
has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool)

if benchmark:
import time

latency = []

prompt_run = True
while current_length < max_length:
io_binding = self.apply_io_binding(self.sess, inputs, outputs)

if benchmark:
start = time.time()

io_binding.synchronize_inputs()
if prompt_run:
if self.use_cuda_graph:
@@ -242,6 +252,10 @@ def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation):
self.sess.run_with_iobinding(io_binding, self.ro)
io_binding.synchronize_outputs()

if benchmark:
end = time.time()
latency.append(end - start)

# Sample with argmax (greedy search)
next_token_logits = outputs["logits"][:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
@@ -273,9 +287,8 @@ def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation):
inputs["step"] = self.static_inputs_map[batch_size]["step"]

if self.use_cuda_graph:
inputs["seqlens_k"] = torch.tensor(
batch_size * [current_length - 1], device=self.device, dtype=torch.int32
)
previous_seqlens_k = inputs["seqlens_k"]
inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32)
inputs["total_sequence_length"][0] = current_length
if self.use_traced_inputs:
cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"])
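A toy 1-D illustration (made-up values) of the intent behind the new seqlens_k update: rows that already produced EOS keep their key length, while active rows advance by one per decode step. The committed code does the same thing, but reshapes the EOS mask so it broadcasts against its own seqlens_k buffer.

import torch

has_eos = torch.tensor([False, True, False])
seqlens_k = torch.tensor([7, 5, 7], dtype=torch.int32)
seqlens_k = seqlens_k + (~has_eos).to(torch.int32)
print(seqlens_k)  # tensor([8, 5, 8], dtype=torch.int32)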
@@ -314,8 +327,14 @@ def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation):
{f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.clone().contiguous()}
) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()})

texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
return texts
if benchmark:
print(
f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}"
)
print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms")
else:
texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
return texts
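A self-contained sketch of how the recorded per-step times translate into the numbers printed above: the first entry is the prompt (prefill) run, the rest are single-token decode steps. The helper name and sample values are illustrative, not part of the commit.

import numpy as np

def summarize_latency(latency):
    # latency[0] is the prompt run; latency[1:] are per-token decode steps, in seconds.
    prompt_ms = 1000 * latency[0]
    token_ms = 1000 * float(np.mean(latency[1:])) if len(latency) > 1 else float("nan")
    return prompt_ms, token_ms

print(summarize_latency([0.050, 0.008, 0.009, 0.008]))  # (50.0, ~8.33)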

def generate(self, prompt, max_length, cuda_graph_annotation):
encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True)
@@ -330,7 +349,11 @@ def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation):
encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist()
encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist()

return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation)
# Warm up run
self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False)

# Benchmark run
self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True)
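A hedged usage sketch of the new warm-up-then-measure flow, mirroring the loop at the bottom of this file; it assumes `generator` is the generator object constructed inside run_phi2 below, and the shape values are illustrative:

batch_size, prompt_length, token_num = 2, 16, 32
generator.append_static_inputs(batch_size=batch_size)  # register static CUDA-graph inputs for this batch size
generator.generate_benchmark(
    prompt_shape=(batch_size, prompt_length),
    token_num=token_num,
    cuda_graph_annotation=batch_size,
)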


def run_phi2(
@@ -350,7 +373,7 @@ def simple_run(prompt):
example_batch_size = len(prompt)
if use_cuda_graph:
generator.append_static_inputs(batch_size=example_batch_size)
texts = generator.generate(prompt, max_length=100, cuda_graph_annotation=example_batch_size)
texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size)

for i in range(len(texts)):
print("Prompt: ", prompt[i])
@@ -361,18 +384,17 @@ def simple_run(prompt):
def print_prime(n):
"""
Print all primes between 1 and n
"""''', "Give an example of using ONNX Runtime to run a model.",
"""'''
]

simple_run([prompt[0]])
# bugbug: batch 2 has different result
simple_run(prompt)
simple_run([prompt[1]])
if not run_benchmark:
simple_run(prompt)

# Run simple benchmark. Time the decoder only.
if run_benchmark:
token_num = 256
for batch_size in [1, 2, 4, 8, 16]:
token_num = 32
for batch_size in [2]:
generator.append_static_inputs(batch_size)
for sequence_length in [16, 64, 256, 1024]:
for sequence_length in [16]:
prompt_shape = (batch_size, sequence_length)
texts = generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)
generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)
