Introduce LLMProfileData to parse and aggregate LLM performance statistics (#480)

* Add LLMProfileData to parse and aggregate LLM performance statistics

* Address feedback

* Shift from unittest to pytest

* Fix copyright
nv-hwoo authored Mar 2, 2024
1 parent 1c56a46 commit 9e763d3
Showing 3 changed files with 356 additions and 0 deletions.
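At a glance, the new class is meant to be used roughly as follows. This is a minimal sketch mirroring the docstring in llm_profile.py below; the profile_export.json path and the gpt2 tokenizer are illustrative, not part of the commit.

from transformers import AutoTokenizer
from genai_pa.llm_profile import LLMProfileData

# Any Hugging Face tokenizer works; "gpt2" follows the docstring example.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Parse a Perf Analyzer profile export (hypothetical path) and compute
# per-experiment summary statistics.
pd = LLMProfileData(filename="profile_export.json", tokenizer=tokenizer)
stats = pd.get_statistics(infer_mode="concurrency", load_level=10)
print(stats.avg_time_to_first_token, stats.p99_inter_token_latency)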
187 changes: 187 additions & 0 deletions src/c++/perf_analyzer/genai-pa/genai_pa/llm_profile.py
@@ -0,0 +1,187 @@
#!/usr/bin/env python3

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from dataclasses import dataclass
from itertools import pairwise

import numpy as np
from genai_pa.utils import load_json
from transformers import AutoTokenizer


@dataclass
class LLMMetrics:
"""A simple dataclass that holds core LLM performance metrics."""

time_to_first_tokens: list[int]
inter_token_latencies: list[int]
output_token_throughputs: list[float]

def get_base_name(self, attr_name: str) -> str:
# We first tried to store this mapping as a dataclass member (a dict),
# but ran into two issues: (1) dataclass fields cannot have mutable
# default values and (2) if the dict were kept as a regular class
# member, it would be picked up by the Statistics class even though
# it is not one of the LLM metrics.
# Leaving it as conditional statements for now.
if attr_name == "time_to_first_tokens":
return "time_to_first_token"
elif attr_name == "inter_token_latencies":
return "inter_token_latency"
elif attr_name == "output_token_throughputs":
return "output_token_throughput"
else:
raise ValueError(f"No attribute named '{attr_name}' exists.")


class Statistics:
"""A class that aggregates various statistics from given metrics class.
The Statistics class goes through each metric in the metrics class and
calculates several statistics such as:
- average (arithmetic mean)
- percentiles (p25, p50, p75, p90, p95, p99)
- minimum & maximum
- standard deviation
The class will store each calculated statistics as part of its attribute.
Example:
>>> metrics = LLMMetrics([3, 4, 5], [], [])
>>> stats = Statistics(metrics)
>>> print(stats.avg_time_to_first_token) # output: 4
"""

def __init__(self, metrics: LLMMetrics):
# iterate through LLMMetrics to calculate statistics and set attributes
for attr, data in metrics.__dict__.items():
if data:
attr = metrics.get_base_name(attr)
self._calculate_mean(data, attr)
self._calculate_percentiles(data, attr)
self._calculate_minmax(data, attr)
self._calculate_std(data, attr)

def _calculate_mean(self, data: list[int], attr: str):
avg = np.mean(data)
setattr(self, "avg_" + attr, avg)

def _calculate_percentiles(self, data: list[int], attr: str):
p25, p50, p75 = np.percentile(data, [25, 50, 75])
p90, p95, p99 = np.percentile(data, [90, 95, 99])
setattr(self, "p25_" + attr, p25)
setattr(self, "p50_" + attr, p50)
setattr(self, "p75_" + attr, p75)
setattr(self, "p90_" + attr, p90)
setattr(self, "p95_" + attr, p95)
setattr(self, "p99_" + attr, p99)

def _calculate_minmax(self, data: list[int], attr: str):
min_value, max_value = np.min(data), np.max(data)
setattr(self, "min_" + attr, min_value)
setattr(self, "max_" + attr, max_value)

def _calculate_std(self, data: list[int], attr: str):
std = np.std(data)
setattr(self, "std_" + attr, std)

def __repr__(self):
attr_strs = ",".join([f"{k}={v}" for k, v in self.__dict__.items()])
return f"Statistics({attr_strs})"


class LLMProfileData:
"""A class that calculates and aggregates all the LLM performance statistics
across the Perf Analyzer profile results.
The LLMProfileData class parses profile export JSON file, collects the core
LLM performance metrics, and calculates summary statistics for each different
Perf Analyzer runs/experiments.
Example:
>>> ... # run Perf Analyzer with concurrency level 10
>>>
>>> from transformers import AutoTokenizer
>>>
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> pd = LLMProfileData(filename="profile_export.json", tokenizer=tokenizer)
>>> stats = pd.get_statistics(infer_mode="concurrency", load_level=10)
>>>
>>> print(stats) # output: Statistics(avg_time_to_first_token=...)
"""

def __init__(self, filename: str, tokenizer: AutoTokenizer) -> None:
data = load_json(filename)
self._profile_results = {}

for experiment in data["experiments"]:
infer_mode = experiment["experiment"]["mode"]
load_level = experiment["experiment"]["value"]
requests = experiment["requests"]

metrics = self._collect_llm_metrics(requests, tokenizer)

# aggregate and calculate statistics
statistics = Statistics(metrics)
self._profile_results[(infer_mode, load_level)] = statistics

def _collect_llm_metrics(
self, requests: list, tokenizer: AutoTokenizer
) -> LLMMetrics:
time_to_first_tokens = []
inter_token_latencies = []
output_token_throughputs = []
for request in requests:
req_timestamp = request["timestamp"]
res_timestamps = request["response_timestamps"]
res_outputs = request["response_outputs"]

# time to first token
time_to_first_tokens.append(res_timestamps[0] - req_timestamp)

# output token throughput
output_tokens = tokenizer(res_outputs)["input_ids"]
total_output_tokens = np.sum(list(map(len, output_tokens)))
req_latency = res_timestamps[-1] - req_timestamp
output_token_throughputs.append(total_output_tokens / req_latency)

# inter token latency
for t1, t2 in pairwise(res_timestamps):
inter_token_latencies.append(t2 - t1)

return LLMMetrics(
time_to_first_tokens,
inter_token_latencies,
output_token_throughputs,
)

def get_statistics(self, infer_mode: str, load_level: int | float) -> Statistics:
if (infer_mode, load_level) not in self._profile_results:
raise KeyError(f"Profile with {infer_mode}={load_level} does not exist.")
return self._profile_results[(infer_mode, load_level)]
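Not part of the commit, but for reference: a small sketch of how Statistics derives its attribute names from the LLMMetrics fields, following the docstring example above.

from genai_pa.llm_profile import LLMMetrics, Statistics

metrics = LLMMetrics(
    time_to_first_tokens=[3, 4, 5],
    inter_token_latencies=[],
    output_token_throughputs=[],
)
stats = Statistics(metrics)
# Only non-empty metrics are aggregated; names come from get_base_name().
print(stats.avg_time_to_first_token)  # 4.0 (arithmetic mean)
print(stats.p50_time_to_first_token)  # 4.0 (median)
print(stats.min_time_to_first_token)  # 3
print(stats.max_time_to_first_token)  # 5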
6 changes: 6 additions & 0 deletions src/c++/perf_analyzer/genai-pa/genai_pa/utils.py
@@ -24,9 +24,15 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
from pathlib import Path


def load_json(filename: str):
with open(filename) as f:
return json.load(f)


def remove_file(file: Path):
if file.is_file():
file.unlink()
163 changes: 163 additions & 0 deletions src/c++/perf_analyzer/genai-pa/tests/test_llm_profile.py
@@ -0,0 +1,163 @@
#!/usr/bin/env python3

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
from pathlib import Path

import numpy as np
import pytest
from genai_pa.llm_profile import LLMMetrics, LLMProfileData
from genai_pa.utils import remove_file
from transformers import AutoTokenizer


class TestLLMProfileData:
@pytest.fixture
def prepare_profile_data(self) -> None:
self.path = Path("temp_profile_export.json")
self.profile_data = {
"experiments": [
{
"experiment": {
"mode": "concurrency",
"value": 10,
},
"requests": [
{
"timestamp": 1,
"response_timestamps": [3, 5, 8],
"response_outputs": ["dogs", "are", "cool"],
},
{
"timestamp": 2,
"response_timestamps": [4, 7, 11],
"response_outputs": ["I", "don't", "cook food"],
},
],
},
{
"experiment": {
"mode": "request_rate",
"value": 2.0,
},
"requests": [
{
"timestamp": 5,
"response_timestamps": [7, 8, 13, 18],
"response_outputs": ["cats", "are", "cool", "too"],
},
{
"timestamp": 3,
"response_timestamps": [6, 8, 11],
"response_outputs": ["it's", "very", "simple work"],
},
],
},
],
}

with open(self.path, "w") as f:
json.dump(self.profile_data, f)

yield None

# clean up
remove_file(self.path)

def test_llm_profile_data(self, prepare_profile_data) -> None:
"""Collect LLM metrics from profile export data and check values.
Metrics
* time to first tokens
- experiment 1: [3 - 1, 4 - 2] = [2, 2]
- experiment 2: [7 - 5, 6 - 3] = [2, 3]
* inter token latencies
- experiment 1: [5 - 3, 8 - 5, 7 - 4, 11 - 7] = [2, 3, 3, 4]
- experiment 2: [8 - 7, 13 - 8, 18 - 13, 8 - 6, 11 - 8] = [1, 5, 5, 2, 3]
* output token throughputs
- experiment 1: [3/(8 - 1), 5/(11 - 2)] = [3/7, 5/9]
- experiment 2: [4/(18 - 5), 5/(11 - 3)] = [4/13, 5/8]
"""
tokenizer = AutoTokenizer.from_pretrained("gpt2")
pd = LLMProfileData("temp_profile_export.json", tokenizer)

# experiment 1 statistics
stat = pd.get_statistics(infer_mode="concurrency", load_level=10)
assert stat.avg_time_to_first_token == 2
assert stat.avg_inter_token_latency == 3
assert stat.avg_output_token_throughput == pytest.approx(31 / 63)
assert stat.p50_time_to_first_token == 2
assert stat.p50_inter_token_latency == 3
assert stat.p50_output_token_throughput == pytest.approx(31 / 63)
assert stat.min_time_to_first_token == 2
assert stat.min_inter_token_latency == 2
assert stat.min_output_token_throughput == pytest.approx(3 / 7)
assert stat.max_time_to_first_token == 2
assert stat.max_inter_token_latency == 4
assert stat.max_output_token_throughput == pytest.approx(5 / 9)
assert stat.std_time_to_first_token == np.std([2, 2])
assert stat.std_inter_token_latency == np.std([2, 3, 3, 4])
assert stat.std_output_token_throughput == np.std([3 / 7, 5 / 9])

# experiment 2 statistics
stat = pd.get_statistics(infer_mode="request_rate", load_level=2.0)
assert stat.avg_time_to_first_token == 2.5
assert stat.avg_inter_token_latency == 3.2
assert stat.avg_output_token_throughput == pytest.approx(97 / 208)
assert stat.p50_time_to_first_token == 2.5
assert stat.p50_inter_token_latency == 3
assert stat.p50_output_token_throughput == pytest.approx(97 / 208)
assert stat.min_time_to_first_token == 2
assert stat.min_inter_token_latency == 1
assert stat.min_output_token_throughput == pytest.approx(4 / 13)
assert stat.max_time_to_first_token == 3
assert stat.max_inter_token_latency == 5
assert stat.max_output_token_throughput == pytest.approx(5 / 8)
assert stat.std_time_to_first_token == np.std([2, 3])
assert stat.std_inter_token_latency == np.std([1, 5, 5, 2, 3])
assert stat.std_output_token_throughput == np.std([4 / 13, 5 / 8])

# check non-existing profile data
with pytest.raises(KeyError):
pd.get_statistics(infer_mode="concurrency", load_level=30)

def test_llm_metrics_get_base_name(self) -> None:
"""Test get_base_name method in LLMMetrics class."""
metrics = LLMMetrics(
time_to_first_tokens=[1, 2, 3],
inter_token_latencies=[4, 5],
output_token_throughputs=[7, 8, 9],
)
assert metrics.get_base_name("time_to_first_tokens") == "time_to_first_token"
assert metrics.get_base_name("inter_token_latencies") == "inter_token_latency"
assert (
metrics.get_base_name("output_token_throughputs")
== "output_token_throughput"
)
with pytest.raises(ValueError):
metrics.get_base_name("hello1234")
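As a quick cross-check of the expected values asserted for the request_rate=2.0 experiment, here is a small arithmetic sketch (not part of the commit) reproducing the numbers from the docstring above.

import numpy as np

time_to_first_tokens = [7 - 5, 6 - 3]        # [2, 3]
inter_token_latencies = [1, 5, 5, 2, 3]
output_token_throughputs = [4 / 13, 5 / 8]   # [4/(18 - 5), 5/(11 - 3)]

assert np.mean(time_to_first_tokens) == 2.5
assert np.mean(inter_token_latencies) == 3.2
assert np.isclose(np.mean(output_token_throughputs), 97 / 208)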
