Skip to content

Commit

Permalink
Refactor the LLM benchmark script (#427)
Browse files Browse the repository at this point in the history
* Refactor load json file

* Extract main profile function

* Refactor + allow users to specify parameter through JSON when generating
prompt

* Minor fix

* Fix sample output

* Add postfix to filenames and update doc

* Fix sample output
  • Loading branch information
nv-hwoo authored Oct 28, 2023
1 parent 85dd6f6 commit ab8ac0c
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 92 deletions.
261 changes: 182 additions & 79 deletions src/c++/perf_analyzer/docs/examples/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,35 +27,78 @@
import argparse
import json
import subprocess
from dataclasses import dataclass
from itertools import pairwise
from pathlib import Path
from typing import Optional

import numpy as np

TEMP_INPUT_FILE = "temp_input_data.json"
INPUT_FILENAME = "generated_input_data.json"

TITLE = "\n[ BENCHMARK SUMMARY ]\n"
PROMPT_SIZE = " Prompt size: {}"
FIRST_TOKEN_LATENCY = "Average first-token latency: {:.4f} sec"
T2T_LATENCY = "Average total token-to-token latency: {:.4f} sec"

def load_profile_data():
with open("profile_export.json") as f:

@dataclass
class ProfileResults:
prompt_size: int
avg_first_token_latency: int
avg_total_t2t_latency: int
avg_periodic_t2t_latencies: Optional[list[int]] = None


def load_json_data(filename):
with open(filename) as f:
return json.load(f)


def print_benchmark_summary(results):
output = "\n[ Benchmark Summary ]"
for prompt_size, avg_first_token_latency, avg_token_to_token_latency in results:
output += (
f"\n Prompt size: {prompt_size}, "
f"Average first-token latency: {avg_first_token_latency:.4f} sec"
)
output += (
f", Average token-to-token latency: {avg_token_to_token_latency:.4f} sec"
if avg_token_to_token_latency
else ""
)
print(output)
def save_json_data(data, filename):
with open(filename, "w") as f:
json.dump(data, f)


def get_postfix(args, prompt_size):
"""Generate postfix for profile export filename and plot.
e.g.
- trtllm-prompt100-maxtokens256
- trtllm-prompt100-periodic1_100_1-period32-maxtokens1024
"""
postfix = f"{args.model}-prompt{prompt_size}-"
if args.periodic_concurrency_range:
start, end, step = args.periodic_concurrency_range
postfix += f"periodic{start}_{end}_{step}-period{args.request_period}-"
postfix += f"maxtokens{args.max_tokens}"
return postfix


def get_export_filename(args, prompt_size):
postfix = get_postfix(args, prompt_size)
filename = f"profile_export-{postfix}.json"
return filename


def get_plot_filename(args, prompt_size):
postfix = get_postfix(args, prompt_size)
filename = f"inflight_batching_benchmark-{postfix}.png"
return filename


def plot_results(latencies):
def print_benchmark_summary(profile_results):
output = [TITLE]
for pr in profile_results:
line = [PROMPT_SIZE.format(pr.prompt_size)]
line += [FIRST_TOKEN_LATENCY.format(pr.avg_first_token_latency)]
if pr.avg_total_t2t_latency:
line += [T2T_LATENCY.format(pr.avg_total_t2t_latency)]
output += [", ".join(line) + "\n"]
print("".join(output))


def plot_results(latencies, filename="inflight_batching_benchmark.png"):
"""Plot in-flight batching LLM bencharmark results."""
import matplotlib.pyplot as plt # Lazy import

Expand All @@ -70,8 +113,7 @@ def plot_results(latencies):
ax.set_title("In-Flight Batching Benchmark Summary", fontsize=14)
ax.set_ylim(bottom=0.0)

fig.savefig("inflight_batching_benchmark.png", dpi=300)
print("Saved benchmark result @ 'inflight_batching_benchmark.png'.")
fig.savefig(filename, dpi=300)


def add_latencies_to_bins(bins, pos, responses, request_period):
Expand Down Expand Up @@ -101,7 +143,7 @@ def update_start_position(request_id, start_pos, initial_requests, step):
return start_pos


def collect_periodic_latencies(args):
def collect_periodic_latencies(args, filename):
"""Split the entire benchmark results into segments with size
of request period and collect latencies for each segment.
"""
Expand All @@ -114,7 +156,7 @@ def collect_periodic_latencies(args):
bins = [[] for _ in range(num_bins)]
bin_start_position = 0

data = load_profile_data()
data = load_json_data(filename)
requests = data["experiments"][0]["requests"]

for i, r in enumerate(requests):
Expand All @@ -133,11 +175,9 @@ def collect_periodic_latencies(args):
return bins


def calculate_avg_periodic_latencies(args):
"""Calculate average token-to-token latency for each
request period.
"""
bins = collect_periodic_latencies(args)
def calculate_avg_periodic_latencies(args, filename):
"""Calculate average token-to-token latency for each request period."""
bins = collect_periodic_latencies(args, filename)

latencies = []
for bin in bins:
Expand All @@ -161,8 +201,9 @@ def collect_latencies(requests):
return first_token_latencies, token_to_token_latencies


def calculate_avg_latencies():
requests = load_profile_data()
def calculate_avg_latencies(filename):
"""Calculate avg first-token and avg total token-to-token latencies."""
requests = load_json_data(filename)
first_token_latencies, token_to_token_latencies = collect_latencies(requests)

# Compute mean and convert from nanosec to sec
Expand All @@ -174,15 +215,45 @@ def calculate_avg_latencies():
return avg_first_token_latency, avg_token_to_token_latency


def profile(args, input_data_file):
# Clean up
export_file = Path("profile_export.json")
export_file.unlink(missing_ok=True)
def summarize_profile_results(args, prompts):
results = []
for prompt in prompts:
prompt_size = len(prompt.split())
export_file = get_export_filename(args, prompt_size)
avg_first_token_latency, avg_total_t2t_latency = calculate_avg_latencies(
filename=export_file
)

profile_result = ProfileResults(
prompt_size=prompt_size,
avg_first_token_latency=avg_first_token_latency,
avg_total_t2t_latency=avg_total_t2t_latency,
)

if args.periodic_concurrency_range:
periodic_latencies = calculate_avg_periodic_latencies(args, export_file)
profile_result.avg_periodic_t2t_latencies = periodic_latencies
plot_results(
latencies=periodic_latencies,
filename=get_plot_filename(args, prompt_size),
)

results.append(profile_result)

print_benchmark_summary(results)

if args.periodic_concurrency_range:
print(
"Saved in-flight batching benchmark plots "
"@ 'inflight_batching_benchmark-*.png'."
)


def profile(args, export_file):
command = (
f"perf_analyzer -m {args.model} -i grpc --async --streaming "
f"--input-data={input_data_file} "
"--profile-export-file=profile_export.json "
f"--input-data={INPUT_FILENAME} "
f"--profile-export-file={export_file} "
)
if args.periodic_concurrency_range:
start, end, step = args.periodic_concurrency_range
Expand All @@ -196,23 +267,87 @@ def profile(args, input_data_file):
"--measurement-request-count=10 "
"--stability-percentage=999"
)

print("Running Perf Analyzer...")
subprocess.run(args=[command], shell=True)


def generate_input_data(args, prompt_size, filename):
request_parameters = f"""
{{
"max_tokens": {args.max_tokens},
"ignore_eos": {"true" if args.ignore_eos else "false"}
}}
def prepare_export_file(args, prompt):
prompt_size = len(prompt.split())
filename = get_export_filename(args, prompt_size)

# If exists, clean up
export_file = Path(filename)
export_file.unlink(missing_ok=True)
return export_file


def prepare_input_data(input_data, prompt):
"""Insert the prompt to send into input JSON data."""
input_data["data"][0]["PROMPT"] = [prompt]
save_json_data(input_data, INPUT_FILENAME)


def generate_prompts(args, input_data):
"""Generate dummy prompts if not specified by input JSON file."""
prompt = input_data["data"][0]["PROMPT"][0]

if not prompt: # Generate dummy prompt
assert args.prompt_size_range, "Must specify --prompt-size-range."
start, end, step = args.prompt_size_range
return [" ".join(["hi"] * size) for size in range(start, end + 1, step)]
return [prompt]


def construct_input_data(args):
"""Construct input data that contains input tensors and parameters.
Parse the input JSON file (if exists) to construct the input data.
When user sets parameters through command line, overwrite the
parameters set by input JSON file.
"""
input_data = {"data": [{"STREAM": [True]}]}
input_data["data"][0]["SAMPLING_PARAMETERS"] = [request_parameters]
prompt = ""
stream = True
sampling_params = {}

prompt = ["hi"] * prompt_size # Generate dummy prompt
input_data["data"][0]["PROMPT"] = [" ".join(prompt)]
with open(filename, "w") as f:
json.dump(input_data, f)
if args.input_data:
data = load_json_data(filename=args.input_data)["data"][0]
stream = data["STREAM"][0] if "STREAM" in data else stream
prompt = data["PROMPT"][0] if "PROMPT" in data else prompt
if "SAMPLING_PARAMETERS" in data:
sampling_params = json.loads(data["SAMPLING_PARAMETERS"][0])

# If specified, overwrite max_tokens
if args.max_tokens:
sampling_params["max_tokens"] = args.max_tokens
else:
args.max_tokens = sampling_params["max_tokens"]

# If specified, overwrite ignore_eos
if "ignore_eos" not in sampling_params:
sampling_params["ignore_eos"] = args.ignore_eos
elif args.ignore_eos:
sampling_params["ignore_eos"] = True

input_data = {"data": [{}]}
input_data["data"][0]["PROMPT"] = [prompt]
input_data["data"][0]["STREAM"] = [stream]
input_data["data"][0]["SAMPLING_PARAMETERS"] = [json.dumps(sampling_params)]
return input_data


def main(args):
input_data = construct_input_data(args)
prompts = generate_prompts(args, input_data)

for prompt in prompts:
prepare_input_data(input_data, prompt)
export_file = prepare_export_file(args, prompt)

# Run Perf Analyzer
profile(args, export_file)

summarize_profile_results(args, prompts)


if __name__ == "__main__":
Expand All @@ -229,7 +364,6 @@ def generate_input_data(args, prompt_size, filename):
type=int,
nargs=3,
metavar=("START", "END", "STEP"),
default=[10, 10, 1],
help="The range of prompt sizes '<[START, END], STEP>' where END is inclusive.",
)
parser.add_argument(
Expand All @@ -248,7 +382,6 @@ def generate_input_data(args, prompt_size, filename):
parser.add_argument(
"--max-tokens",
type=int,
default=256,
help="The maximum number of tokens to generate.",
)
parser.add_argument(
Expand All @@ -262,34 +395,4 @@ def generate_input_data(args, prompt_size, filename):
help="The input data file to be used for inference request.",
)
args = parser.parse_args()

results = []

if args.input_data:
print(f"Using input data file '{args.input_data}' for inference request.\n")
with open(args.input_data) as f:
input_data = json.load(f)
prompt_size = len(input_data["data"][0]["PROMPT"][0].split())
args.prompt_size_range = [prompt_size, prompt_size, 1]

start, end, step = args.prompt_size_range
for prompt_size in range(start, end + 1, step):
if not args.input_data:
generate_input_data(args, prompt_size, TEMP_INPUT_FILE)

profile(args, args.input_data if args.input_data else TEMP_INPUT_FILE)

if not args.periodic_concurrency_range:
(
avg_first_token_latency,
avg_token_to_token_latency,
) = calculate_avg_latencies()
results.append(
(prompt_size, avg_first_token_latency, avg_token_to_token_latency)
)

if args.periodic_concurrency_range:
avg_latencies = calculate_avg_periodic_latencies(args)
plot_results(avg_latencies)
else:
print_benchmark_summary(results)
main(args)
Loading

0 comments on commit ab8ac0c

Please sign in to comment.