Capture LLM metrics from genai-perf in MA (#844)
* Successfully reading from LLM CSV

* General cleanup

* All unit tests passing

* Fixing metric table typos

* Fixing typos
nv-braf committed Apr 8, 2024
1 parent e2418f7 commit 6284f52
Showing 11 changed files with 759 additions and 110 deletions.
3 changes: 3 additions & 0 deletions model_analyzer/constants.py
@@ -70,3 +70,6 @@

# Model analyzer package name
PACKAGE_NAME = "triton-model-analyzer"

# GENAI-PERF CSV
GENAI_PERF_CSV = "profile_export_genai_perf.csv"
146 changes: 143 additions & 3 deletions model_analyzer/perf_analyzer/perf_analyzer.py
@@ -21,12 +21,15 @@
import re
import signal
import tempfile
from csv import DictReader
from subprocess import STDOUT, Popen
from typing import Dict, List
from typing import Dict, List, Optional

import psutil

from model_analyzer.config.input.config_defaults import DEFAULT_MODEL_TYPE
from model_analyzer.constants import (
GENAI_PERF_CSV,
INTERVAL_SLEEP_TIME,
LOGGER_NAME,
MEASUREMENT_REQUEST_COUNT_STEP,
@@ -40,6 +43,16 @@
from model_analyzer.record.types.gpu_power_usage import GPUPowerUsage
from model_analyzer.record.types.gpu_used_memory import GPUUsedMemory
from model_analyzer.record.types.gpu_utilization import GPUUtilization
from model_analyzer.record.types.inter_token_latency_avg import InterTokenLatencyAvg
from model_analyzer.record.types.inter_token_latency_max import InterTokenLatencyMax
from model_analyzer.record.types.inter_token_latency_min import InterTokenLatencyMin
from model_analyzer.record.types.inter_token_latency_p25 import InterTokenLatencyP25
from model_analyzer.record.types.inter_token_latency_p50 import InterTokenLatencyP50
from model_analyzer.record.types.inter_token_latency_p75 import InterTokenLatencyP75
from model_analyzer.record.types.inter_token_latency_p90 import InterTokenLatencyP90
from model_analyzer.record.types.inter_token_latency_p95 import InterTokenLatencyP95
from model_analyzer.record.types.inter_token_latency_p99 import InterTokenLatencyP99
from model_analyzer.record.types.output_token_throughput import OutputTokenThroughput
from model_analyzer.record.types.perf_client_response_wait import PerfClientResponseWait
from model_analyzer.record.types.perf_client_send_recv import PerfClientSendRecv
from model_analyzer.record.types.perf_latency_avg import PerfLatencyAvg
@@ -53,6 +66,15 @@
)
from model_analyzer.record.types.perf_server_queue import PerfServerQueue
from model_analyzer.record.types.perf_throughput import PerfThroughput
from model_analyzer.record.types.time_to_first_token_avg import TimeToFirstTokenAvg
from model_analyzer.record.types.time_to_first_token_max import TimeToFirstTokenMax
from model_analyzer.record.types.time_to_first_token_min import TimeToFirstTokenMin
from model_analyzer.record.types.time_to_first_token_p25 import TimeToFirstTokenP25
from model_analyzer.record.types.time_to_first_token_p50 import TimeToFirstTokenP50
from model_analyzer.record.types.time_to_first_token_p75 import TimeToFirstTokenP75
from model_analyzer.record.types.time_to_first_token_p90 import TimeToFirstTokenP90
from model_analyzer.record.types.time_to_first_token_p95 import TimeToFirstTokenP95
from model_analyzer.record.types.time_to_first_token_p99 import TimeToFirstTokenP99

logger = logging.getLogger(LOGGER_NAME)

@@ -91,6 +113,28 @@ class PerfAnalyzer:
["gpu_used_memory", "Max GPU Memory Usage", GPUUsedMemory, "1000000"],
["gpu_free_memory", "Total GPU Memory", GPUFreeMemory, "1000000"]
]

llm_metric_table = [
["time_to_first_token_avg", "Time to First Token (ns) avg", TimeToFirstTokenAvg, "1000"],
["time_to_first_token_min", "Time to First Token (ns) min", TimeToFirstTokenMin, "1000"],
["time_to_first_token_max", "Time to First Token (ns) max", TimeToFirstTokenMax, "1000"],
["time_to_first_token_p99", "Time to First Token (ns) p99", TimeToFirstTokenP99, "1000"],
["time_to_first_token_p95", "Time to First Token (ns) p95", TimeToFirstTokenP95, "1000"],
["time_to_first_token_p90", "Time to First Token (ns) p90", TimeToFirstTokenP90, "1000"],
["time_to_first_token_p75", "Time to First Token (ns) p75", TimeToFirstTokenP75, "1000"],
["time_to_first_token_p50", "Time to First Token (ns) p50", TimeToFirstTokenP50, "1000"],
["time_to_first_token_p25", "Time to First Token (ns) p25", TimeToFirstTokenP25, "1000"],
["inter_token_latency_avg", "Inter Token Latency (ns) avg", InterTokenLatencyAvg, "1000"],
["inter_token_latency_min", "Inter Token Latency (ns) min", InterTokenLatencyMin, "1000"],
["inter_token_latency_max", "Inter Token Latency (ns) max", InterTokenLatencyMax, "1000"],
["inter_token_latency_p99", "Inter Token Latency (ns) p99", InterTokenLatencyP99, "1000"],
["inter_token_latency_p95", "Inter Token Latency (ns) p95", InterTokenLatencyP95, "1000"],
["inter_token_latency_p90", "Inter Token Latency (ns) p90", InterTokenLatencyP90, "1000"],
["inter_token_latency_p75", "Inter Token Latency (ns) p75", InterTokenLatencyP75, "1000"],
["inter_token_latency_p50", "Inter Token Latency (ns) p50", InterTokenLatencyP50, "1000"],
["inter_token_latency_p25", "Inter Token Latency (ns) p25", InterTokenLatencyP25, "1000"],
["output_token_throughput", "Output Token Throughput (per sec) avg", OutputTokenThroughput, "1"]
]
# yapf: enable
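
Each row in these tables follows the same four-column layout as gpu_metric_table above: the Model Analyzer metric tag, the exact string genai-perf writes to its CSV, the record class to instantiate, and the factor the raw CSV value is divided by. A minimal sketch of how one row is consumed, assuming the class-level index constants METRIC_TAG, CSV_STRING, RECORD_CLASS, and REDUCTION_FACTOR (defined outside this diff) are 0 through 3 in that order:

    # Illustrative sketch only: index values assumed, raw_value invented
    METRIC_TAG, CSV_STRING, RECORD_CLASS, REDUCTION_FACTOR = 0, 1, 2, 3

    row = ["time_to_first_token_avg", "Time to First Token (ns) avg", TimeToFirstTokenAvg, "1000"]
    raw_value = 42000.0  # value as read from the genai-perf CSV
    record = row[RECORD_CLASS](raw_value / float(row[REDUCTION_FACTOR]))
    print(record.tag)  # "time_to_first_token_avg"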

@staticmethod
@@ -109,7 +153,23 @@ def get_gpu_metrics():
]
return gpu_metrics

def __init__(self, path, config, max_retries, timeout, max_cpu_util):
@staticmethod
def get_llm_metrics():
llm_metrics = [
llm_metric[PerfAnalyzer.RECORD_CLASS]
for llm_metric in PerfAnalyzer.llm_metric_table
]
return llm_metrics

def __init__(
self,
path,
config,
max_retries,
timeout,
max_cpu_util,
model_type=DEFAULT_MODEL_TYPE,
):
"""
Parameters
----------
@@ -133,8 +193,10 @@ def __init__(self, path, config, max_retries, timeout, max_cpu_util):
self._timeout = timeout
self._output = ""
self._perf_records = {}
self._llm_records = {}
self._gpu_records = []
self._max_cpu_util = max_cpu_util
self._model_type = model_type

def run(self, metrics, env=None):
"""
@@ -195,7 +257,20 @@ def get_perf_records(self):
if self._perf_records:
return self._perf_records
raise TritonModelAnalyzerException(
"Attempted to get perf_analyzer results" "without calling run first."
"Attempted to get perf_analyzer results without calling run first."
)

def get_llm_records(self):
"""
Returns
-------
The LLM records from the last perf_analyzer run
"""

if self._llm_records:
return self._llm_records
raise TritonModelAnalyzerException(
"Attempted to get perf_analyzer results without calling run first."
)
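
A sketch of the accessor contract, assuming a configured PerfAnalyzer instance named analyzer (hypothetical):

    analyzer.run(PerfAnalyzer.get_llm_metrics())
    llm_records = analyzer.get_llm_records()  # raises TritonModelAnalyzerException if run() never completed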

def get_gpu_records(self):
@@ -438,6 +513,12 @@ def _is_multi_model(self):
return len(self._config.model_run_configs()) > 1

def _parse_outputs(self, metrics):
self._parse_generic_outputs(metrics)

if self._model_type == "LLM":
self._parse_llm_outputs(metrics)

def _parse_generic_outputs(self, metrics):
"""
Extract records from the Perf Analyzer run for each model
"""
@@ -464,6 +545,24 @@ def _parse_outputs(self, metrics):
for f in glob.glob(f"*{perf_config['latency-report-file']}"):
os.remove(f)

def _parse_llm_outputs(self, metrics):
"""
Extract LLM records from the genai-perf CSV export (single-model runs only)
"""

perf_config = self._config.model_run_configs()[0].perf_config()

logger.debug(f"Reading PA results from {GENAI_PERF_CSV}")
with open(GENAI_PERF_CSV, mode="r") as f:
csv_reader = DictReader(f, delimiter=",")

# See test_perf_analyzer::test_pa_llm_csv_output() for CSV output example
self._llm_records[perf_config["model-name"]] = self._extract_llm_records(
metrics, csv_reader
)

os.remove(GENAI_PERF_CSV)
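
The DictReader built above is consumed by _extract_llm_records below. The authoritative fixture is in test_perf_analyzer::test_pa_llm_csv_output(); an invented sketch of the expected shape, a Metric name column plus one column per statistic, would be:

    Metric,avg,min,max,p99,p95,p90,p75,p50,p25
    Time to First Token (ns),42000,38000,52000,51000,50000,49000,47000,43000,40000
    Inter Token Latency (ns),16000,12000,22000,21000,20000,19000,18000,17000,15000
    Output Token Throughput (per sec),123,110,140,138,136,134,130,125,118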

def _extract_perf_records_from_row(
self, requested_metrics: List[Record], row_metrics: Dict[str, str]
) -> List[Record]:
@@ -526,6 +625,47 @@ def _extract_gpu_records_from_row(
self._cleanup_gpu_records(gpu_records)
return gpu_records

def _extract_llm_records(
self, requested_metrics: List[Record], csv_reader: DictReader
) -> List[Record]:
llm_records: List[Record] = []

for requested_metric in requested_metrics:
new_llm_record = self._get_llm_record_from_csv(requested_metric, csv_reader)
llm_records.append(new_llm_record)

return llm_records

def _get_llm_record_from_csv(
self, requested_metric: Record, csv_reader: DictReader
) -> Record:
for row in csv_reader:
for key, value in row.items():
metric_string = f"{row['Metric']} {key}"
llm_metric = self._find_corresponding_llm_metric_row(metric_string)

if (
llm_metric
and llm_metric[PerfAnalyzer.METRIC_TAG] == requested_metric.tag
):
adjusted_value = float(value) / float(
llm_metric[PerfAnalyzer.REDUCTION_FACTOR]
)

llm_record = llm_metric[PerfAnalyzer.RECORD_CLASS](adjusted_value) # type: ignore
return llm_record

raise TritonModelAnalyzerException(
f"Did not find {requested_metric.tag} in genai-perf CSV file"
)

def _find_corresponding_llm_metric_row(self, metric_string: str) -> Optional[List]:
for row in PerfAnalyzer.llm_metric_table:
if metric_string == row[PerfAnalyzer.CSV_STRING]:
return row

return None
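
Tracing one lookup end to end with invented numbers that match the sketch CSV above:

    # requested_metric.tag == "time_to_first_token_p99"
    # DictReader yields row["Metric"] == "Time to First Token (ns)", key == "p99", value == "51000"
    # metric_string == "Time to First Token (ns) p99", which matches the llm_metric_table
    # row tagged "time_to_first_token_p99" with reduction factor "1000", so the result
    # is TimeToFirstTokenP99(51000.0 / 1000.0), a record holding 51.0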

def _cleanup_gpu_records(self, gpu_records):
# Recalculate GPUFreeMemory by removing the value of the associated GPUUsedMemory
# Remove any GPUFreeMemory records that don't have a matching GPUUsedMemory
60 changes: 60 additions & 0 deletions model_analyzer/record/types/inter_token_latency_p25.py
@@ -0,0 +1,60 @@
#!/usr/bin/env python3

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import total_ordering

from model_analyzer.record.types.inter_token_latency_base import InterTokenLatencyBase


@total_ordering
class InterTokenLatencyP25(InterTokenLatencyBase):
"""
A record for perf_analyzer Inter token latency metric
"""

tag = "inter_token_latency_p25"

def __init__(self, value, timestamp=0):
"""
Parameters
----------
value : float
the latency extracted from the perf analyzer output
timestamp : float
Elapsed time from start of program
"""

super().__init__(value, timestamp)

@classmethod
def header(cls, aggregation_tag=False):
"""
Parameters
----------
aggregation_tag: bool
An optional tag that may be displayed
as part of the header indicating that
this record has been aggregated using
max, min, average, etc.
Returns
-------
str
The full name of the
metric.
"""

return "p25 Inter Token Latency (ms)"
60 changes: 60 additions & 0 deletions model_analyzer/record/types/inter_token_latency_p50.py
@@ -0,0 +1,60 @@
#!/usr/bin/env python3

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import total_ordering

from model_analyzer.record.types.inter_token_latency_base import InterTokenLatencyBase


@total_ordering
class InterTokenLatencyP50(InterTokenLatencyBase):
"""
A record for perf_analyzer Inter token latency metric
"""

tag = "inter_token_latency_p50"

def __init__(self, value, timestamp=0):
"""
Parameters
----------
value : float
the latency extracted from the perf analyzer output
timestamp : float
Elapsed time from start of program
"""

super().__init__(value, timestamp)

@classmethod
def header(cls, aggregation_tag=False):
"""
Parameters
----------
aggregation_tag: bool
An optional tag that may be displayed
as part of the header indicating that
this record has been aggregated using
max, min, average, etc.
Returns
-------
str
The full name of the
metric.
"""

return "p50 Inter Token Latency (ms)"
2 changes: 1 addition & 1 deletion model_analyzer/record/types/inter_token_latency_p90.py
@@ -20,7 +20,7 @@


@total_ordering
class InterTokenLatencyP99(InterTokenLatencyBase):
class InterTokenLatencyP90(InterTokenLatencyBase):
"""
A record for perf_analyzer Inter token latency metric
"""
(Diffs for the remaining 6 changed files are not shown.)
