From 5779fb718cce91c43e662c9a36cc000be9f0694f Mon Sep 17 00:00:00 2001 From: Hyunjae Woo <107147848+nv-hwoo@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:22:29 -0700 Subject: [PATCH] Add profile data parser for image retriever models (#43) * add profile data parser for image retriever * add checks for metrics --- .../export_data/console_exporter.py | 2 + genai-perf/genai_perf/main.py | 8 +- genai-perf/genai_perf/metrics/__init__.py | 1 + .../metrics/image_retrieval_metrics.py | 60 + genai-perf/genai_perf/metrics/statistics.py | 3 + .../profile_data_parser/__init__.py | 3 + .../image_retrieval_profile_data_parser.py | 84 ++ .../profile_data_parser.py | 5 + ...est_image_retrieval_profile_data_parser.py | 140 ++ .../tests/test_llm_profile_data_parser.py | 1206 ++++++++--------- genai-perf/tests/test_utils.py | 44 + 11 files changed, 901 insertions(+), 655 deletions(-) create mode 100755 genai-perf/genai_perf/metrics/image_retrieval_metrics.py create mode 100755 genai-perf/genai_perf/profile_data_parser/image_retrieval_profile_data_parser.py create mode 100644 genai-perf/tests/test_image_retrieval_profile_data_parser.py create mode 100644 genai-perf/tests/test_utils.py diff --git a/genai-perf/genai_perf/export_data/console_exporter.py b/genai-perf/genai_perf/export_data/console_exporter.py index 460fe597..56e3ba08 100644 --- a/genai-perf/genai_perf/export_data/console_exporter.py +++ b/genai-perf/genai_perf/export_data/console_exporter.py @@ -47,6 +47,8 @@ def _get_title(self): return "Embeddings Metrics" elif self._args.endpoint_type == "rankings": return "Rankings Metrics" + elif self._args.endpoint_type == "image_retrieval": + return "Image Retrieval Metrics" else: return "LLM Metrics" diff --git a/genai-perf/genai_perf/main.py b/genai-perf/genai_perf/main.py index 27eb6182..f7af76ea 100755 --- a/genai-perf/genai_perf/main.py +++ b/genai-perf/genai_perf/main.py @@ -38,7 +38,11 @@ from genai_perf.llm_inputs.llm_inputs import LlmInputs from genai_perf.plots.plot_config_parser import PlotConfigParser from genai_perf.plots.plot_manager import PlotManager -from genai_perf.profile_data_parser import LLMProfileDataParser, ProfileDataParser +from genai_perf.profile_data_parser import ( + ImageRetrievalProfileDataParser, + LLMProfileDataParser, + ProfileDataParser, +) from genai_perf.tokenizer import Tokenizer, get_tokenizer @@ -95,6 +99,8 @@ def generate_inputs(args: Namespace, tokenizer: Tokenizer) -> None: def calculate_metrics(args: Namespace, tokenizer: Tokenizer) -> ProfileDataParser: if args.endpoint_type in ["embeddings", "rankings"]: return ProfileDataParser(args.profile_export_file) + elif args.endpoint_type == "image_retrieval": + return ImageRetrievalProfileDataParser(args.profile_export_file) else: return LLMProfileDataParser( filename=args.profile_export_file, diff --git a/genai-perf/genai_perf/metrics/__init__.py b/genai-perf/genai_perf/metrics/__init__.py index b3cdd6dc..3812d1ed 100644 --- a/genai-perf/genai_perf/metrics/__init__.py +++ b/genai-perf/genai_perf/metrics/__init__.py @@ -24,6 +24,7 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from genai_perf.metrics.image_retrieval_metrics import ImageRetrievalMetrics from genai_perf.metrics.llm_metrics import LLMMetrics from genai_perf.metrics.metrics import MetricMetadata, Metrics from genai_perf.metrics.statistics import Statistics diff --git a/genai-perf/genai_perf/metrics/image_retrieval_metrics.py b/genai-perf/genai_perf/metrics/image_retrieval_metrics.py new file mode 100755 index 00000000..a7ff3547 --- /dev/null +++ b/genai-perf/genai_perf/metrics/image_retrieval_metrics.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from typing import List + +from genai_perf.metrics.metrics import MetricMetadata, Metrics + + +class ImageRetrievalMetrics(Metrics): + """A simple dataclass that holds core Image Retrieval performance metrics.""" + + IMAGE_RETRIEVAL_REQUEST_METRICS = [ + MetricMetadata("image_throughput", "images/sec"), + MetricMetadata("image_latency", "ms/image"), + ] + + def __init__( + self, + request_throughputs: List[float] = [], + request_latencies: List[int] = [], + image_throughputs: List[int] = [], + image_latencies: List[int] = [], + ) -> None: + super().__init__(request_throughputs, request_latencies) + self.image_throughputs = image_throughputs + self.image_latencies = image_latencies + + # add base name mapping + self._base_names["image_throughputs"] = "image_throughput" + self._base_names["image_latencies"] = "image_latency" + + @property + def request_metrics(self) -> List[MetricMetadata]: + base_metrics = super().request_metrics # base metrics + return base_metrics + self.IMAGE_RETRIEVAL_REQUEST_METRICS diff --git a/genai-perf/genai_perf/metrics/statistics.py b/genai-perf/genai_perf/metrics/statistics.py index f0d12cef..df4371df 100755 --- a/genai-perf/genai_perf/metrics/statistics.py +++ b/genai-perf/genai_perf/metrics/statistics.py @@ -131,6 +131,8 @@ def _add_units(self, key) -> None: self._stats_dict[key]["unit"] = "ms" elif key == "request_throughput": self._stats_dict[key]["unit"] = "requests/sec" + elif key == "image_throughput": + self._stats_dict[key]["unit"] = "images/sec" elif key.startswith("output_token_throughput"): self._stats_dict[key]["unit"] = "tokens/sec" elif "sequence_length" in key: @@ -168,6 +170,7 @@ def _is_time_metric(self, field: str) -> bool: "inter_token_latency", "time_to_first_token", "request_latency", + "image_latency", ] return field in time_metrics diff --git a/genai-perf/genai_perf/profile_data_parser/__init__.py b/genai-perf/genai_perf/profile_data_parser/__init__.py index 2e7798c4..a05ce57a 100644 --- a/genai-perf/genai_perf/profile_data_parser/__init__.py +++ b/genai-perf/genai_perf/profile_data_parser/__init__.py @@ -24,6 +24,9 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from genai_perf.profile_data_parser.image_retrieval_profile_data_parser import ( + ImageRetrievalProfileDataParser, +) from genai_perf.profile_data_parser.llm_profile_data_parser import LLMProfileDataParser from genai_perf.profile_data_parser.profile_data_parser import ( ProfileDataParser, diff --git a/genai-perf/genai_perf/profile_data_parser/image_retrieval_profile_data_parser.py b/genai-perf/genai_perf/profile_data_parser/image_retrieval_profile_data_parser.py new file mode 100755 index 00000000..7c1f7486 --- /dev/null +++ b/genai-perf/genai_perf/profile_data_parser/image_retrieval_profile_data_parser.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 + +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution.
+# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from pathlib import Path + +from genai_perf.metrics import ImageRetrievalMetrics +from genai_perf.profile_data_parser.profile_data_parser import ProfileDataParser +from genai_perf.utils import load_json_str + + +class ImageRetrievalProfileDataParser(ProfileDataParser): + """Calculate and aggregate all the Image Retrieval performance statistics + across the Perf Analyzer profile results. + """ + + def __init__(self, filename: Path) -> None: + super().__init__(filename) + + def _parse_requests(self, requests: dict) -> ImageRetrievalMetrics: + """Parse each request in profile data to extract core metrics.""" + min_req_timestamp, max_res_timestamp = float("inf"), 0 + request_latencies = [] + image_throughputs = [] + image_latencies = [] + + for request in requests: + req_timestamp = request["timestamp"] + res_timestamps = request["response_timestamps"] + req_inputs = request["request_inputs"] + + # track entire benchmark duration + min_req_timestamp = min(min_req_timestamp, req_timestamp) + max_res_timestamp = max(max_res_timestamp, res_timestamps[-1]) + + # request latencies + req_latency_ns = res_timestamps[-1] - req_timestamp + request_latencies.append(req_latency_ns) + + payload = load_json_str(req_inputs["payload"]) + contents = payload["messages"][0]["content"] + num_images = len([c for c in contents if c["type"] == "image_url"]) + + # image throughput + req_latency_s = req_latency_ns / 1e9 # to seconds + image_throughputs.append(num_images / req_latency_s) + + # image latencies + image_latencies.append(req_latency_ns / num_images) + + # request throughput + benchmark_duration = (max_res_timestamp - min_req_timestamp) / 1e9 # to seconds + request_throughputs = [len(requests) / benchmark_duration] + + return ImageRetrievalMetrics( + request_throughputs, + request_latencies, + image_throughputs, + image_latencies, + ) diff --git a/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py b/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py index 245afb2c..0ae036b5 100755 --- a/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py +++ b/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py @@ -41,6 +41,7 @@ class ResponseFormat(Enum): OPENAI_EMBEDDINGS = auto() OPENAI_VISION = auto() RANKINGS = auto() + IMAGE_RETRIEVAL = auto() TRITON = auto() @@ -75,6 +76,8 @@ def _get_profile_metadata(self, data: dict) -> None: self._response_format = ResponseFormat.OPENAI_EMBEDDINGS elif data["endpoint"] == "v1/ranking": self._response_format = ResponseFormat.RANKINGS + elif data["endpoint"] == 
"v1/infer": + self._response_format = ResponseFormat.IMAGE_RETRIEVAL else: # (TPA-66) add PA metadata to handle this case # When endpoint field is either empty or custom endpoint, fall @@ -93,6 +96,8 @@ def _get_profile_metadata(self, data: dict) -> None: self._response_format = ResponseFormat.OPENAI_EMBEDDINGS elif "ranking" in response: self._response_format = ResponseFormat.RANKINGS + elif "image_retrieval" in response: + self._response_format = ResponseFormat.IMAGE_RETRIEVAL else: raise RuntimeError("Unknown OpenAI response format.") diff --git a/genai-perf/tests/test_image_retrieval_profile_data_parser.py b/genai-perf/tests/test_image_retrieval_profile_data_parser.py new file mode 100644 index 00000000..5dd5dd31 --- /dev/null +++ b/genai-perf/tests/test_image_retrieval_profile_data_parser.py @@ -0,0 +1,140 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import json +from pathlib import Path +from typing import cast +from unittest.mock import mock_open, patch + +import pytest +from genai_perf.metrics import ImageRetrievalMetrics, Statistics +from genai_perf.profile_data_parser import ImageRetrievalProfileDataParser + +from .test_utils import check_statistics, ns_to_sec + + +def check_image_retrieval_metrics( + m1: ImageRetrievalMetrics, m2: ImageRetrievalMetrics +) -> None: + assert m1.request_latencies == m2.request_latencies + assert m1.request_throughputs == pytest.approx(m2.request_throughputs) + assert m1.image_latencies == m2.image_latencies + assert m1.image_throughputs == pytest.approx(m2.image_throughputs) + + +class TestImageRetrievalProfileDataParser: + + image_retrieval_profile_data = { + "experiments": [ + { + "experiment": {"mode": "concurrency", "value": 10}, + "requests": [ + { + "timestamp": 1, + "request_inputs": { + "payload": '{"messages":[{"role":"user","content":[{"type":"image_url","image_url":{"url":"image1"}},{"type":"image_url","image_url":{"url":"image2"}}]}],"model":"yolox"}' + }, + "response_timestamps": [3], + "response_outputs": [ + { + "response": '{"object":"list","data":[],"model":"yolox","usage":null}' + } + ], + }, + { + "timestamp": 3, + "request_inputs": { + "payload": '{"messages":[{"role":"user","content":[{"type":"image_url","image_url":{"url":"image1"}},{"type":"image_url","image_url":{"url":"image2"}},{"type":"image_url","image_url":{"url":"image3"}}]}],"model":"yolox"}' + }, + "response_timestamps": [7], + "response_outputs": [ + { + "response": '{"object":"list","data":[],"model":"yolox","usage":null}' + } + ], + }, + ], + } + ], + "version": "", + "service_kind": "openai", + "endpoint": "v1/infer", + } + + @patch("pathlib.Path.exists", return_value=True) + @patch( + "builtins.open", + new_callable=mock_open, + read_data=json.dumps(image_retrieval_profile_data), + ) + @pytest.mark.parametrize( + "infer_mode, load_level, expected_metrics", + [ + ( + "concurrency", + "10", + { + "request_throughputs": [1 / ns_to_sec(3)], + "request_latencies": [2, 4], + "image_throughputs": [1 / ns_to_sec(1), 3 / ns_to_sec(4)], + "image_latencies": [1, 4 / 3], + }, + ), + ], + ) + def test_image_retrieval_profile_data( + self, + mock_exists, + mock_file, + infer_mode, + load_level, + expected_metrics, + ) -> None: + """Collect image retrieval metrics from profile export data and check values. 
+ + Metrics + * request throughputs + - [2 / (7 - 1)] = [1 / ns_to_sec(3)] + * request latencies + - [3 - 1, 7 - 3] = [2, 4] + * image throughputs + - [2 / (3 - 1), 3 / (7 - 3)] = [1 / ns_to_sec(1), 3 / ns_to_sec(4)] + * image latencies + - [(3 - 1) / 2, (7 - 3) / 3] = [1, 4/3] + + """ + pd = ImageRetrievalProfileDataParser( + filename=Path("image_retrieval_profile_export.json") + ) + + statistics = pd.get_statistics(infer_mode=infer_mode, load_level=load_level) + metrics = cast(ImageRetrievalMetrics, statistics.metrics) + + expected_metrics = ImageRetrievalMetrics(**expected_metrics) + expected_statistics = Statistics(expected_metrics) + + check_image_retrieval_metrics(metrics, expected_metrics) + check_statistics(statistics, expected_statistics) diff --git a/genai-perf/tests/test_llm_profile_data_parser.py b/genai-perf/tests/test_llm_profile_data_parser.py index b8f9847c..d3f213cb 100644 --- a/genai-perf/tests/test_llm_profile_data_parser.py +++ b/genai-perf/tests/test_llm_profile_data_parser.py @@ -24,31 +24,17 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import json -from io import StringIO from pathlib import Path -from typing import Any, List, Union, cast +from typing import cast +from unittest.mock import patch -import numpy as np import pytest from genai_perf.metrics import LLMMetrics from genai_perf.metrics.statistics import Statistics from genai_perf.profile_data_parser import LLMProfileDataParser from genai_perf.tokenizer import DEFAULT_TOKENIZER, get_tokenizer - -def ns_to_sec(ns: int) -> Union[int, float]: - """Convert from nanosecond to second.""" - return ns / 1e9 - - -def check_statistics(s1: Statistics, s2: Statistics) -> None: - s1_dict = s1.stats_dict - s2_dict = s2.stats_dict - for metric in s1_dict.keys(): - for stat_name, value in s1_dict[metric].items(): - if stat_name != "unit": - assert s2_dict[metric][stat_name] == pytest.approx(value) +from .test_utils import check_statistics, ns_to_sec def check_llm_metrics(m1: LLMMetrics, m2: LLMMetrics) -> None: @@ -65,380 +51,78 @@ def check_llm_metrics(m1: LLMMetrics, m2: LLMMetrics) -> None: class TestLLMProfileDataParser: - @pytest.fixture - def mock_read_write(self, monkeypatch: pytest.MonkeyPatch) -> List[str]: - """ - This function will mock the open function for specific files: - - - For "triton_profile_export.json", it will read and return the - contents of self.triton_profile_data - - For "openai_profile_export.json", it will read and return the - contents of self.openai_profile_data - - For "profile_export.csv", it will capture all data written to - the file, and return it as the return value of this function - - For all other files, it will behave like the normal open function - """ - - written_data = [] - - original_open = open - - def custom_open(filename, *args, **kwargs): - def write(self: Any, content: str) -> int: - written_data.append(content) - return len(content) - - if filename == "triton_profile_export.json": - tmp_file = StringIO(json.dumps(self.triton_profile_data)) - return tmp_file - elif filename == "openai_profile_export.json": - tmp_file = StringIO(json.dumps(self.openai_profile_data)) - return tmp_file - elif filename == "openai_vlm_profile_export.json": - tmp_file = StringIO(json.dumps(self.openai_vlm_profile_data)) - return tmp_file - elif filename == "tensorrtllm_engine_profile_export.json": - tmp_file = StringIO(json.dumps(self.tensorrtllm_engine_profile_data)) - return tmp_file - elif filename
== "empty_profile_export.json": - tmp_file = StringIO(json.dumps(self.empty_profile_data)) - return tmp_file - elif filename == "unfinished_responses_profile_export.json": - tmp_file = StringIO(json.dumps(self.unfinished_responses_profile_data)) - return tmp_file - elif filename == "profile_export.csv": - tmp_file = StringIO() - tmp_file.write = write.__get__(tmp_file) - return tmp_file - else: - return original_open(filename, *args, **kwargs) - - monkeypatch.setattr("builtins.open", custom_open) - - return written_data - - def test_triton_llm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> None: - """Collect LLM metrics from profile export data and check values. - - Metrics - * time to first tokens - - experiment 1: [3 - 1, 4 - 2] = [2, 2] - - experiment 2: [7 - 5, 6 - 3] = [2, 3] - * inter token latencies - - experiment 1: [((8 - 1) - 2)/(3 - 1), ((11 - 2) - 2)/(6 - 1)] - : [2.5, 1.4] - : [2, 1] # rounded - - experiment 2: [((18 - 5) - 2)/(4 - 1), ((11 - 3) - 3)/(6 - 1)] - : [11/3, 1] - : [4, 1] # rounded - * output token throughputs per request - - experiment 1: [3/(8 - 1), 6/(11 - 2)] = [3/7, 6/9] - - experiment 2: [4/(18 - 5), 6/(11 - 3)] = [4/13, 6/8] - * output token throughputs - - experiment 1: [(3 + 6)/(11 - 1)] = [9/10] - - experiment 2: [(4 + 6)/(18 - 3)] = [2/3] - * output sequence lengths - - experiment 1: [3, 6] - - experiment 2: [4, 6] - * input sequence lengths - - experiment 1: [3, 4] - - experiment 2: [3, 4] - """ - tokenizer = get_tokenizer(DEFAULT_TOKENIZER) - pd = LLMProfileDataParser( - filename=Path("triton_profile_export.json"), - tokenizer=tokenizer, - ) - - # experiment 1 metrics & statistics - stat_obj = pd.get_statistics(infer_mode="concurrency", load_level="10") - metrics = stat_obj.metrics - stat = stat_obj.stats_dict - - assert isinstance(metrics, LLMMetrics) - - assert metrics.time_to_first_tokens == [2, 2] - assert metrics.inter_token_latencies == [2, 1] - ottpr = [3 / ns_to_sec(7), 6 / ns_to_sec(9)] - assert metrics.output_token_throughputs_per_request == pytest.approx(ottpr) - ott = [9 / ns_to_sec(10)] - assert metrics.output_token_throughputs == pytest.approx(ott) - assert metrics.output_sequence_lengths == [3, 6] - assert metrics.input_sequence_lengths == [3, 4] - - # Disable Pylance warnings for dynamically set attributes due to Statistics - # not having strict attributes listed. 
- assert stat["time_to_first_token"]["avg"] == 2 # type: ignore - assert stat["inter_token_latency"]["avg"] == 1.5 # type: ignore - assert stat["output_token_throughput_per_request"]["avg"] == pytest.approx( # type: ignore - np.mean(ottpr) - ) - assert stat["output_sequence_length"]["avg"] == 4.5 # type: ignore - assert stat["input_sequence_length"]["avg"] == 3.5 # type: ignore - - assert stat["time_to_first_token"]["p50"] == 2 # type: ignore - assert stat["inter_token_latency"]["p50"] == 1.5 # type: ignore - assert stat["output_token_throughput_per_request"]["p50"] == pytest.approx( # type: ignore - np.percentile(ottpr, 50) - ) - assert stat["output_sequence_length"]["p50"] == 4.5 # type: ignore - assert stat["input_sequence_length"]["p50"] == 3.5 # type: ignore - - assert stat["time_to_first_token"]["min"] == 2 # type: ignore - assert stat["inter_token_latency"]["min"] == 1 # type: ignore - min_ottpr = 3 / ns_to_sec(7) - assert stat["output_token_throughput_per_request"]["min"] == pytest.approx(min_ottpr) # type: ignore - assert stat["output_sequence_length"]["min"] == 3 # type: ignore - assert stat["input_sequence_length"]["min"] == 3 # type: ignore - - assert stat["time_to_first_token"]["max"] == 2 # type: ignore - assert stat["inter_token_latency"]["max"] == 2 # type: ignore - max_ottpr = 6 / ns_to_sec(9) - assert stat["output_token_throughput_per_request"]["max"] == pytest.approx(max_ottpr) # type: ignore - assert stat["output_sequence_length"]["max"] == 6 # type: ignore - assert stat["input_sequence_length"]["max"] == 4 # type: ignore - - assert stat["time_to_first_token"]["std"] == np.std([2, 2]) # type: ignore - assert stat["inter_token_latency"]["std"] == np.std([2, 1]) # type: ignore - assert stat["output_token_throughput_per_request"]["std"] == pytest.approx( # type: ignore - np.std(ottpr) - ) - assert stat["output_sequence_length"]["std"] == np.std([3, 6]) # type: ignore - assert stat["input_sequence_length"]["std"] == np.std([3, 4]) # type: ignore - - oott = 9 / ns_to_sec(10) - assert stat["output_token_throughput"]["avg"] == pytest.approx(oott) # type: ignore - - # experiment 2 statistics - stat_obj = pd.get_statistics(infer_mode="request_rate", load_level="2.0") - metrics = stat_obj.metrics - stat = stat_obj.stats_dict - assert isinstance(metrics, LLMMetrics) - - assert metrics.time_to_first_tokens == [2, 3] - assert metrics.inter_token_latencies == [4, 1] - ottpr = [4 / ns_to_sec(13), 6 / ns_to_sec(8)] - assert metrics.output_token_throughputs_per_request == pytest.approx(ottpr) - ott = [2 / ns_to_sec(3)] - assert metrics.output_token_throughputs == pytest.approx(ott) - assert metrics.output_sequence_lengths == [4, 6] - assert metrics.input_sequence_lengths == [3, 4] - - assert stat["time_to_first_token"]["avg"] == pytest.approx(2.5) # type: ignore - assert stat["inter_token_latency"]["avg"] == pytest.approx(2.5) # type: ignore - assert stat["output_token_throughput_per_request"]["avg"] == pytest.approx( # type: ignore - np.mean(ottpr) - ) - assert stat["output_sequence_length"]["avg"] == 5 # type: ignore - assert stat["input_sequence_length"]["avg"] == 3.5 # type: ignore - - assert stat["time_to_first_token"]["p50"] == pytest.approx(2.5) # type: ignore - assert stat["inter_token_latency"]["p50"] == pytest.approx(2.5) # type: ignore - assert stat["output_token_throughput_per_request"]["p50"] == pytest.approx( # type: ignore - np.percentile(ottpr, 50) - ) - assert stat["output_sequence_length"]["p50"] == 5 # type: ignore - assert stat["input_sequence_length"]["p50"] == 3.5 
# type: ignore - - assert stat["time_to_first_token"]["min"] == pytest.approx(2) # type: ignore - assert stat["inter_token_latency"]["min"] == pytest.approx(1) # type: ignore - min_ottpr = 4 / ns_to_sec(13) - assert stat["output_token_throughput_per_request"]["min"] == pytest.approx(min_ottpr) # type: ignore - assert stat["output_sequence_length"]["min"] == 4 # type: ignore - assert stat["input_sequence_length"]["min"] == 3 # type: ignore - - assert stat["time_to_first_token"]["max"] == pytest.approx(3) # type: ignore - assert stat["inter_token_latency"]["max"] == pytest.approx(4) # type: ignore - max_ottpr = 6 / ns_to_sec(8) - assert stat["output_token_throughput_per_request"]["max"] == pytest.approx(max_ottpr) # type: ignore - assert stat["output_sequence_length"]["max"] == 6 # type: ignore - assert stat["input_sequence_length"]["max"] == 4 # type: ignore - - assert stat["time_to_first_token"]["std"] == np.std([2, 3]) * (1) # type: ignore - assert stat["inter_token_latency"]["std"] == np.std([4, 1]) * (1) # type: ignore - assert stat["output_token_throughput_per_request"]["std"] == pytest.approx( # type: ignore - np.std(ottpr) - ) - assert stat["output_sequence_length"]["std"] == np.std([4, 6]) # type: ignore - assert stat["input_sequence_length"]["std"] == np.std([3, 4]) # type: ignore - - oott = 2 / ns_to_sec(3) - assert stat["output_token_throughput"]["avg"] == pytest.approx(oott) # type: ignore - - # check non-existing profile data - with pytest.raises(KeyError): - pd.get_statistics(infer_mode="concurrency", load_level="30") - - def test_openai_llm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> None: - """Collect LLM metrics from profile export data and check values. - - Metrics - * time to first tokens - - experiment 1: [5 - 1, 7 - 2] = [4, 5] - * inter token latencies - - experiment 1: [((12 - 1) - 4)/(3 - 1), ((15 - 2) - 5)/(6 - 1)] - : [3.5, 1.6] - : [4, 2] # rounded - * output token throughputs per request - - experiment 1: [3/(12 - 1), 6/(15 - 2)] = [3/11, 6/13] - * output token throughputs - - experiment 1: [(3 + 6)/(15 - 1)] = [9/14] - * output sequence lengths - - experiment 1: [3, 6] - * input sequence lengths - - experiment 1: [3, 4] - """ - tokenizer = get_tokenizer(DEFAULT_TOKENIZER) - pd = LLMProfileDataParser( - filename=Path("openai_profile_export.json"), - tokenizer=tokenizer, - ) - - # experiment 1 statistics - stat_obj = pd.get_statistics(infer_mode="concurrency", load_level="10") - metrics = stat_obj.metrics - stat = stat_obj.stats_dict - assert isinstance(metrics, LLMMetrics) - - assert metrics.time_to_first_tokens == [4, 5] - assert metrics.inter_token_latencies == [4, 2] - ottpr = [3 / ns_to_sec(11), 6 / ns_to_sec(13)] - assert metrics.output_token_throughputs_per_request == pytest.approx(ottpr) - ott = [9 / ns_to_sec(14)] - assert metrics.output_token_throughputs == pytest.approx(ott) - assert metrics.output_sequence_lengths == [3, 6] - assert metrics.input_sequence_lengths == [3, 4] - - assert stat["time_to_first_token"]["avg"] == pytest.approx(4.5) # type: ignore - assert stat["inter_token_latency"]["avg"] == pytest.approx(3) # type: ignore - assert stat["output_token_throughput_per_request"]["avg"] == pytest.approx( # type: ignore - np.mean(ottpr) - ) - assert stat["output_sequence_length"]["avg"] == 4.5 # type: ignore - assert stat["input_sequence_length"]["avg"] == 3.5 # type: ignore - - assert stat["time_to_first_token"]["p50"] == pytest.approx(4.5) # type: ignore - assert stat["inter_token_latency"]["p50"] == pytest.approx(3) # type: ignore 
- assert stat["output_token_throughput_per_request"]["p50"] == pytest.approx( # type: ignore - np.percentile(ottpr, 50) - ) - assert stat["output_sequence_length"]["p50"] == 4.5 # type: ignore - assert stat["input_sequence_length"]["p50"] == 3.5 # type: ignore - - assert stat["time_to_first_token"]["min"] == pytest.approx(4) # type: ignore - assert stat["inter_token_latency"]["min"] == pytest.approx(2) # type: ignore - min_ottpr = 3 / ns_to_sec(11) - assert stat["output_token_throughput_per_request"]["min"] == pytest.approx(min_ottpr) # type: ignore - assert stat["output_sequence_length"]["min"] == 3 # type: ignore - assert stat["input_sequence_length"]["min"] == 3 # type: ignore - - assert stat["time_to_first_token"]["max"] == pytest.approx(5) # type: ignore - assert stat["inter_token_latency"]["max"] == pytest.approx(4) # type: ignore - max_ottpr = 6 / ns_to_sec(13) - assert stat["output_token_throughput_per_request"]["max"] == pytest.approx(max_ottpr) # type: ignore - assert stat["output_sequence_length"]["max"] == 6 # type: ignore - assert stat["input_sequence_length"]["max"] == 4 # type: ignore - - assert stat["time_to_first_token"]["std"] == np.std([4, 5]) * (1) # type: ignore - assert stat["inter_token_latency"]["std"] == np.std([4, 2]) * (1) # type: ignore - assert stat["output_token_throughput_per_request"]["std"] == pytest.approx( # type: ignore - np.std(ottpr) - ) - assert stat["output_sequence_length"]["std"] == np.std([3, 6]) # type: ignore - assert stat["input_sequence_length"]["std"] == np.std([3, 4]) # type: ignore - - oott = 9 / ns_to_sec(14) - assert stat["output_token_throughput"]["avg"] == pytest.approx(oott) # type: ignore - - # check non-existing profile data - with pytest.raises(KeyError): - pd.get_statistics(infer_mode="concurrency", load_level="40") - - def test_openai_vlm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> None: - """Collect LLM metrics from profile export data and check values. 
- - Metrics - * time to first tokens - - experiment 1: [5 - 1, 7 - 2] = [4, 5] - * inter token latencies - - experiment 1: [((12 - 1) - 4)/(3 - 1), ((15 - 2) - 5)/(6 - 1)] - : [3.5, 1.6] - : [4, 2] # rounded - * output token throughputs per request - - experiment 1: [3/(12 - 1), 6/(15 - 2)] = [3/11, 6/13] - * output token throughputs - - experiment 1: [(3 + 6)/(15 - 1)] = [9/14] - * output sequence lengths - - experiment 1: [3, 6] - * input sequence lengths - - experiment 1: [3, 4] - """ - tokenizer = get_tokenizer(DEFAULT_TOKENIZER) - pd = LLMProfileDataParser( - filename=Path("openai_vlm_profile_export.json"), - tokenizer=tokenizer, - ) - - # experiment 1 statistics - stat_obj = pd.get_statistics(infer_mode="concurrency", load_level="10") - metrics = stat_obj.metrics - stat = stat_obj.stats_dict - assert isinstance(metrics, LLMMetrics) - - assert metrics.time_to_first_tokens == [4, 5] - assert metrics.inter_token_latencies == [4, 2] - ottpr = [3 / ns_to_sec(11), 6 / ns_to_sec(13)] - assert metrics.output_token_throughputs_per_request == pytest.approx(ottpr) - ott = [9 / ns_to_sec(14)] - assert metrics.output_token_throughputs == pytest.approx(ott) - assert metrics.output_sequence_lengths == [3, 6] - assert metrics.input_sequence_lengths == [3, 4] - - assert stat["time_to_first_token"]["avg"] == pytest.approx(4.5) # type: ignore - assert stat["inter_token_latency"]["avg"] == pytest.approx(3) # type: ignore - assert stat["output_token_throughput_per_request"]["avg"] == pytest.approx( # type: ignore - np.mean(ottpr) - ) - assert stat["output_sequence_length"]["avg"] == 4.5 # type: ignore - assert stat["input_sequence_length"]["avg"] == 3.5 # type: ignore - - assert stat["time_to_first_token"]["p50"] == pytest.approx(4.5) # type: ignore - assert stat["inter_token_latency"]["p50"] == pytest.approx(3) # type: ignore - assert stat["output_token_throughput_per_request"]["p50"] == pytest.approx( # type: ignore - np.percentile(ottpr, 50) - ) - assert stat["output_sequence_length"]["p50"] == 4.5 # type: ignore - assert stat["input_sequence_length"]["p50"] == 3.5 # type: ignore - - assert stat["time_to_first_token"]["min"] == pytest.approx(4) # type: ignore - assert stat["inter_token_latency"]["min"] == pytest.approx(2) # type: ignore - min_ottpr = 3 / ns_to_sec(11) - assert stat["output_token_throughput_per_request"]["min"] == pytest.approx(min_ottpr) # type: ignore - assert stat["output_sequence_length"]["min"] == 3 # type: ignore - assert stat["input_sequence_length"]["min"] == 3 # type: ignore - - assert stat["time_to_first_token"]["max"] == pytest.approx(5) # type: ignore - assert stat["inter_token_latency"]["max"] == pytest.approx(4) # type: ignore - max_ottpr = 6 / ns_to_sec(13) - assert stat["output_token_throughput_per_request"]["max"] == pytest.approx(max_ottpr) # type: ignore - assert stat["output_sequence_length"]["max"] == 6 # type: ignore - assert stat["input_sequence_length"]["max"] == 4 # type: ignore - - assert stat["time_to_first_token"]["std"] == np.std([4, 5]) * (1) # type: ignore - assert stat["inter_token_latency"]["std"] == np.std([4, 2]) * (1) # type: ignore - assert stat["output_token_throughput_per_request"]["std"] == pytest.approx( # type: ignore - np.std(ottpr) - ) - assert stat["output_sequence_length"]["std"] == np.std([3, 6]) # type: ignore - assert stat["input_sequence_length"]["std"] == np.std([3, 4]) # type: ignore - oott = 9 / ns_to_sec(14) - assert stat["output_token_throughput"]["avg"] == pytest.approx(oott) # type: ignore - - # check non-existing profile data - 
with pytest.raises(KeyError): - pd.get_statistics(infer_mode="concurrency", load_level="40") + ############################### + # TRITON + ############################### + triton_profile_data = { + "service_kind": "triton", + "endpoint": "", + "experiments": [ + { + "experiment": { + "mode": "concurrency", + "value": 10, + }, + "requests": [ + { + "timestamp": 1, + "request_inputs": {"text_input": "This is test"}, + "response_timestamps": [3, 5, 8], + "response_outputs": [ + {"text_output": "I"}, + {"text_output": " like"}, + {"text_output": " dogs"}, + ], + }, + { + "timestamp": 2, + "request_inputs": {"text_input": "This is test too"}, + "response_timestamps": [4, 7, 11], + "response_outputs": [ + {"text_output": "I"}, + {"text_output": " don't"}, + {"text_output": " cook food"}, + ], + }, + ], + }, + { + "experiment": { + "mode": "request_rate", + "value": 2.0, + }, + "requests": [ + { + "timestamp": 5, + "request_inputs": {"text_input": "This is test"}, + "response_timestamps": [7, 8, 13, 18], + "response_outputs": [ + {"text_output": "cat"}, + {"text_output": " is"}, + {"text_output": " cool"}, + {"text_output": " too"}, + ], + }, + { + "timestamp": 3, + "request_inputs": {"text_input": "This is test too"}, + "response_timestamps": [6, 8, 11], + "response_outputs": [ + {"text_output": "it's"}, + {"text_output": " very"}, + {"text_output": " simple work"}, + ], + }, + ], + }, + ], + } + @patch( + "genai_perf.profile_data_parser.profile_data_parser.load_json", + return_value=triton_profile_data, + ) @pytest.mark.parametrize( "infer_mode, load_level, expected_metrics", [ @@ -449,13 +133,13 @@ def test_openai_vlm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> N "request_latencies": [7, 9], "request_throughputs": [1 / ns_to_sec(5)], "time_to_first_tokens": [2, 2], - "inter_token_latencies": [2, 4], + "inter_token_latencies": [2, 1], "output_token_throughputs_per_request": [ 3 / ns_to_sec(7), - 1 / ns_to_sec(3), + 2 / ns_to_sec(3), ], - "output_token_throughputs": [3 / ns_to_sec(5)], - "output_sequence_lengths": [3, 3], + "output_token_throughputs": [9 / ns_to_sec(10)], + "output_sequence_lengths": [3, 6], "input_sequence_lengths": [3, 4], }, ), @@ -466,21 +150,21 @@ def test_openai_vlm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> N "request_latencies": [13, 8], "request_throughputs": [2 / ns_to_sec(15)], "time_to_first_tokens": [2, 3], - "inter_token_latencies": [4, 2], + "inter_token_latencies": [4, 1], "output_token_throughputs_per_request": [ 4 / ns_to_sec(13), - 3 / ns_to_sec(8), + 6 / ns_to_sec(8), ], - "output_token_throughputs": [7 / ns_to_sec(15)], - "output_sequence_lengths": [4, 3], + "output_token_throughputs": [2 / ns_to_sec(3)], + "output_sequence_lengths": [4, 6], "input_sequence_lengths": [3, 4], }, ), ], ) - def test_tensorrtllm_engine_llm_profile_data( + def test_triton_llm_profile_data( self, - mock_read_write: pytest.MonkeyPatch, + mock_json, infer_mode, load_level, expected_metrics, @@ -498,28 +182,28 @@ def test_tensorrtllm_engine_llm_profile_data( - experiment 1: [3 - 1, 4 - 2] = [2, 2] - experiment 2: [7 - 5, 6 - 3] = [2, 3] * inter token latencies - - experiment 1: [((8 - 1) - 2)/(3 - 1), ((11 - 2) - 2)/(3 - 1)] - : [2.5, 3.5] - : [2, 4] # rounded - - experiment 2: [((18 - 5) - 2)/(4 - 1), ((11 - 3) - 3)/(3 - 1)] - : [11/3, 2.5] - : [4, 2] # rounded + - experiment 1: [((8 - 1) - 2)/(3 - 1), ((11 - 2) - 2)/(6 - 1)] + : [2.5, 1.4] + : [2, 1] # rounded + - experiment 2: [((18 - 5) - 2)/(4 - 1), ((11 - 3) - 3)/(6 - 1)] + : [11/3, 1] + : [4, 
1] # rounded * output token throughputs per request - - experiment 1: [3/(8 - 1), 3/(11 - 2)] = [3/7, 1/3] - - experiment 2: [4/(18 - 5), 3/(11 - 3)] = [4/13, 3/8] + - experiment 1: [3/(8 - 1), 6/(11 - 2)] = [3/7, 6/9] + - experiment 2: [4/(18 - 5), 6/(11 - 3)] = [4/13, 6/8] * output token throughputs - - experiment 1: [(3 + 3)/(11 - 1)] = [3/5] - - experiment 2: [(4 + 3)/(18 - 3)] = [7/15] + - experiment 1: [(3 + 6)/(11 - 1)] = [9/10] + - experiment 2: [(4 + 6)/(18 - 3)] = [2/3] * output sequence lengths - - experiment 1: [3, 3] - - experiment 2: [4, 3] + - experiment 1: [3, 6] + - experiment 2: [4, 6] * input sequence lengths - experiment 1: [3, 4] - experiment 2: [3, 4] """ tokenizer = get_tokenizer(DEFAULT_TOKENIZER) pd = LLMProfileDataParser( - filename=Path("tensorrtllm_engine_profile_export.json"), + filename=Path("triton_profile_export.json"), tokenizer=tokenizer, ) @@ -536,187 +220,13 @@ def test_tensorrtllm_engine_llm_profile_data( with pytest.raises(KeyError): pd.get_statistics(infer_mode="concurrency", load_level="30") - def test_merged_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None: - """Test merging the multiple sse response.""" - res_timestamps = [0, 1, 2, 3] - res_outputs = [ - { - "response": 'data: {"choices":[{"delta":{"content":"aaa"}}],"object":"chat.completion.chunk"}\n\n' - }, - { - "response": ( - 'data: {"choices":[{"delta":{"content":"abc"}}],"object":"chat.completion.chunk"}\n\n' - 'data: {"choices":[{"delta":{"content":"1234"}}],"object":"chat.completion.chunk"}\n\n' - 'data: {"choices":[{"delta":{"content":"helloworld"}}],"object":"chat.completion.chunk"}\n\n' - ) - }, - {"response": "data: [DONE]\n\n"}, - ] - expected_response = '{"choices": [{"delta": {"content": "abc1234helloworld"}}], "object": "chat.completion.chunk"}' - - tokenizer = get_tokenizer(DEFAULT_TOKENIZER) - pd = LLMProfileDataParser( - filename=Path("openai_profile_export.json"), - tokenizer=tokenizer, - ) - - pd._preprocess_response(res_timestamps, res_outputs) - assert res_outputs[1]["response"] == expected_response - - def test_openai_output_token_counts( - self, mock_read_write: pytest.MonkeyPatch - ) -> None: - output_texts = [ - "Ad", - "idas", - " Orig", - "inals", - " are", - " now", - " available", - " in", - " more", - " than", - ] - res_outputs = [] - for text in output_texts: - response = f'data: {{"choices":[{{"delta":{{"content":"{text}"}}}}],"object":"chat.completion.chunk"}}\n\n' - res_outputs.append({"response": response}) - - tokenizer = get_tokenizer(DEFAULT_TOKENIZER) - pd = LLMProfileDataParser( - filename=Path("openai_profile_export.json"), - tokenizer=tokenizer, - ) - - output_token_counts, total_output_token = pd._get_output_token_counts( - res_outputs - ) - assert output_token_counts == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] # total 10 - assert total_output_token == 9 - assert total_output_token != sum(output_token_counts) - - def test_triton_output_token_counts( - self, mock_read_write: pytest.MonkeyPatch - ) -> None: - output_texts = [ - "Ad", - "idas", - " Orig", - "inals", - " are", - " now", - " available", - " in", - " more", - " than", - ] - res_outputs = [] - for text in output_texts: - res_outputs.append({"text_output": text}) - - tokenizer = get_tokenizer(DEFAULT_TOKENIZER) - pd = LLMProfileDataParser( - filename=Path("triton_profile_export.json"), - tokenizer=tokenizer, - ) - - output_token_counts, total_output_token = pd._get_output_token_counts( - res_outputs - ) - assert output_token_counts == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] # total 10 - assert 
total_output_token == 9 - assert total_output_token != sum(output_token_counts) - - def test_empty_response(self, mock_read_write: pytest.MonkeyPatch) -> None: - """Check if it handles all empty responses.""" - tokenizer = get_tokenizer(DEFAULT_TOKENIZER) - - # Should not throw error - _ = LLMProfileDataParser( - filename=Path("empty_profile_export.json"), - tokenizer=tokenizer, - ) - - def test_unfinished_responses(self, mock_read_write: pytest.MonkeyPatch) -> None: - """Check if it handles unfinished responses.""" - res_timestamps = [0, 1, 2] - res_outputs = [ - { - "response": 'data: {"id":"8ae835f2ecbb67f3-SJC","object":"chat.completion.chunk","created":1722875835,"choices":[{"index":0,"text"' - }, - { - "response": ':" writing","logprobs":null,"finish_reason":null,"seed":null,"delta":{"token_id":4477,"role":"assistant","content":" writing","tool_calls":null}}],"model":"meta-llama/Llama-3-8b-chat-hf","usage":null}' - }, - {"response": "data: [DONE]\n\n"}, - ] - expected_response = 'data: {"id":"8ae835f2ecbb67f3-SJC","object":"chat.completion.chunk","created":1722875835,"choices":[{"index":0,"text":" writing","logprobs":null,"finish_reason":null,"seed":null,"delta":{"token_id":4477,"role":"assistant","content":" writing","tool_calls":null}}],"model":"meta-llama/Llama-3-8b-chat-hf","usage":null}' - - tokenizer = get_tokenizer(DEFAULT_TOKENIZER) - pd = LLMProfileDataParser( - filename=Path("openai_profile_export.json"), - tokenizer=tokenizer, - ) - - pd._preprocess_response(res_timestamps, res_outputs) - assert res_outputs[0]["response"] == expected_response - - def test_non_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None: - """Check if it handles single responses.""" - res_timestamps = [ - 0, - ] - res_outputs = [ - { - "response": '{"id":"1","object":"chat.completion","created":2,"model":"gpt2","choices":[{"index":0,"message":{"role":"assistant","content":"A friend of mine, who is also a cook, writes a blog.","tool_calls":[]},"logprobs":null,"finish_reason":"length","stop_reason":null}],"usage":{"prompt_tokens":47,"total_tokens":1024,"completion_tokens":977}}' - }, - ] - expected_response = '{"id":"1","object":"chat.completion","created":2,"model":"gpt2","choices":[{"index":0,"message":{"role":"assistant","content":"A friend of mine, who is also a cook, writes a blog.","tool_calls":[]},"logprobs":null,"finish_reason":"length","stop_reason":null}],"usage":{"prompt_tokens":47,"total_tokens":1024,"completion_tokens":977}}' - - tokenizer = get_tokenizer(DEFAULT_TOKENIZER) - pd = LLMProfileDataParser( - filename=Path("openai_profile_export.json"), - tokenizer=tokenizer, - ) - - pd._preprocess_response(res_timestamps, res_outputs) - assert res_outputs[0]["response"] == expected_response - - empty_profile_data = { - "service_kind": "openai", - "endpoint": "v1/chat/completions", - "experiments": [ - { - "experiment": { - "mode": "concurrency", - "value": 10, - }, - "requests": [ - { - "timestamp": 1, - "request_inputs": { - "payload": '{"messages":[{"role":"user","content":[{"type":"text","text":"This is test"}]}],"model":"llama-2-7b","stream":true}', - }, - "response_timestamps": [3, 5, 8], - "response_outputs": [ - { - "response": 'data: {"id":"abc","object":"chat.completion.chunk","created":123,"model":"llama-2-7b","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' - }, - { - "response": 'data: {"id":"abc","object":"chat.completion.chunk","created":123,"model":"llama-2-7b","choices":[{"index":0,"delta":{"content":""},"finish_reason":null}]}\n\n' - 
}, - {"response": "data: [DONE]\n\n"}, - ], - }, - ], - }, - ], - } - - openai_profile_data = { - "service_kind": "openai", - "endpoint": "v1/chat/completions", - "experiments": [ + ############################### + # OPENAI CHAT COMPLETIONS + ############################### + openai_profile_data = { + "service_kind": "openai", + "endpoint": "v1/chat/completions", + "experiments": [ { "experiment": { "mode": "concurrency", @@ -780,6 +290,83 @@ def test_non_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None: ], } + @patch( + "genai_perf.profile_data_parser.profile_data_parser.load_json", + return_value=openai_profile_data, + ) + @pytest.mark.parametrize( + "infer_mode, load_level, expected_metrics", + [ + ( + "concurrency", + "10", + { + "request_latencies": [11, 13], + "request_throughputs": [1 / ns_to_sec(7)], + "time_to_first_tokens": [4, 5], + "inter_token_latencies": [4, 2], + "output_token_throughputs_per_request": [ + 3 / ns_to_sec(11), + 6 / ns_to_sec(13), + ], + "output_token_throughputs": [9 / ns_to_sec(14)], + "output_sequence_lengths": [3, 6], + "input_sequence_lengths": [3, 4], + }, + ), + ], + ) + def test_openai_llm_profile_data( + self, + mock_json, + infer_mode, + load_level, + expected_metrics, + ) -> None: + """Collect LLM metrics from profile export data and check values. + + Metrics + * request_latencies + - experiment 1: [12 - 1, 15 - 2] = [11, 13] + * request_throughputs + - experiment 1: [2/(15 - 1)] = [1/7] + * time to first tokens + - experiment 1: [5 - 1, 7 - 2] = [4, 5] + * inter token latencies + - experiment 1: [((12 - 1) - 4)/(3 - 1), ((15 - 2) - 5)/(6 - 1)] + : [3.5, 1.6] + : [4, 2] # rounded + * output token throughputs per request + - experiment 1: [3/(12 - 1), 6/(15 - 2)] = [3/11, 6/13] + * output token throughputs + - experiment 1: [(3 + 6)/(15 - 1)] = [9/14] + * output sequence lengths + - experiment 1: [3, 6] + * input sequence lengths + - experiment 1: [3, 4] + """ + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("openai_profile_export.json"), + tokenizer=tokenizer, + ) + + statistics = pd.get_statistics(infer_mode=infer_mode, load_level=load_level) + metrics = cast(LLMMetrics, statistics.metrics) + + expected_metrics = LLMMetrics(**expected_metrics) + expected_statistics = Statistics(expected_metrics) + + check_llm_metrics(metrics, expected_metrics) + check_statistics(statistics, expected_statistics) + + # check non-existing profile data + with pytest.raises(KeyError): + pd.get_statistics(infer_mode="concurrency", load_level="40") + + ############################### + # OPENAI VISION + ############################### openai_vlm_profile_data = { "service_kind": "openai", "endpoint": "v1/chat/completions", @@ -847,70 +434,83 @@ def test_non_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None: ], } - triton_profile_data = { - "service_kind": "triton", - "endpoint": "", - "experiments": [ - { - "experiment": { - "mode": "concurrency", - "value": 10, + @patch( + "genai_perf.profile_data_parser.profile_data_parser.load_json", + return_value=openai_vlm_profile_data, + ) + @pytest.mark.parametrize( + "infer_mode, load_level, expected_metrics", + [ + ( + "concurrency", + "10", + { + "request_latencies": [11, 13], + "request_throughputs": [1 / ns_to_sec(7)], + "time_to_first_tokens": [4, 5], + "inter_token_latencies": [4, 2], + "output_token_throughputs_per_request": [ + 3 / ns_to_sec(11), + 6 / ns_to_sec(13), + ], + "output_token_throughputs": [9 / ns_to_sec(14)], + "output_sequence_lengths": 
[3, 6], + "input_sequence_lengths": [3, 4], }, - "requests": [ - { - "timestamp": 1, - "request_inputs": {"text_input": "This is test"}, - "response_timestamps": [3, 5, 8], - "response_outputs": [ - {"text_output": "I"}, - {"text_output": " like"}, - {"text_output": " dogs"}, - ], - }, - { - "timestamp": 2, - "request_inputs": {"text_input": "This is test too"}, - "response_timestamps": [4, 7, 11], - "response_outputs": [ - {"text_output": "I"}, - {"text_output": " don't"}, - {"text_output": " cook food"}, - ], - }, - ], - }, - { - "experiment": { - "mode": "request_rate", - "value": 2.0, - }, - "requests": [ - { - "timestamp": 5, - "request_inputs": {"text_input": "This is test"}, - "response_timestamps": [7, 8, 13, 18], - "response_outputs": [ - {"text_output": "cat"}, - {"text_output": " is"}, - {"text_output": " cool"}, - {"text_output": " too"}, - ], - }, - { - "timestamp": 3, - "request_inputs": {"text_input": "This is test too"}, - "response_timestamps": [6, 8, 11], - "response_outputs": [ - {"text_output": "it's"}, - {"text_output": " very"}, - {"text_output": " simple work"}, - ], - }, - ], - }, + ), ], - } + ) + def test_openai_vlm_profile_data( + self, + mock_json, + infer_mode, + load_level, + expected_metrics, + ) -> None: + """Collect LLM metrics from profile export data and check values. + + Metrics + * request_latencies + - experiment 1: [12 - 1, 15 - 2] = [11, 13] + * request_throughputs + - experiment 1: [2/(15 - 1)] = [1/7] + * time to first tokens + - experiment 1: [5 - 1, 7 - 2] = [4, 5] + * inter token latencies + - experiment 1: [((12 - 1) - 4)/(3 - 1), ((15 - 2) - 5)/(6 - 1)] + : [3.5, 1.6] + : [4, 2] # rounded + * output token throughputs per request + - experiment 1: [3/(12 - 1), 6/(15 - 2)] = [3/11, 6/13] + * output token throughputs + - experiment 1: [(3 + 6)/(15 - 1)] = [9/14] + * output sequence lengths + - experiment 1: [3, 6] + * input sequence lengths + - experiment 1: [3, 4] + """ + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("openai_vlm_profile_export.json"), + tokenizer=tokenizer, + ) + + statistics = pd.get_statistics(infer_mode=infer_mode, load_level=load_level) + metrics = cast(LLMMetrics, statistics.metrics) + + expected_metrics = LLMMetrics(**expected_metrics) + expected_statistics = Statistics(expected_metrics) + + check_llm_metrics(metrics, expected_metrics) + check_statistics(statistics, expected_statistics) + + # check non-existing profile data + with pytest.raises(KeyError): + pd.get_statistics(infer_mode="concurrency", load_level="40") + ############################### + # TENSORRTLLM ENGINE + ############################### tensorrtllm_engine_profile_data = { "service_kind": "triton_c_api", "endpoint": "", @@ -1054,3 +654,301 @@ def test_non_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None: }, ], } + + @patch( + "genai_perf.profile_data_parser.profile_data_parser.load_json", + return_value=tensorrtllm_engine_profile_data, + ) + @pytest.mark.parametrize( + "infer_mode, load_level, expected_metrics", + [ + ( + "concurrency", + "10", + { + "request_latencies": [7, 9], + "request_throughputs": [1 / ns_to_sec(5)], + "time_to_first_tokens": [2, 2], + "inter_token_latencies": [2, 4], + "output_token_throughputs_per_request": [ + 3 / ns_to_sec(7), + 1 / ns_to_sec(3), + ], + "output_token_throughputs": [3 / ns_to_sec(5)], + "output_sequence_lengths": [3, 3], + "input_sequence_lengths": [3, 4], + }, + ), + ( + "request_rate", + "2.0", + { + "request_latencies": [13, 8], + 
"request_throughputs": [2 / ns_to_sec(15)], + "time_to_first_tokens": [2, 3], + "inter_token_latencies": [4, 2], + "output_token_throughputs_per_request": [ + 4 / ns_to_sec(13), + 3 / ns_to_sec(8), + ], + "output_token_throughputs": [7 / ns_to_sec(15)], + "output_sequence_lengths": [4, 3], + "input_sequence_lengths": [3, 4], + }, + ), + ], + ) + def test_tensorrtllm_engine_llm_profile_data( + self, + mock_json, + infer_mode, + load_level, + expected_metrics, + ) -> None: + """Collect LLM metrics from profile export data and check values. + + Metrics + * request_latencies + - experiment 1: [8 - 1, 11 - 2] = [7, 9] + - experiment 2: [18 - 5, 11 -3] = [13, 8] + * request_throughputs + - experiment 1: [2/(11 - 1)] = [1/5] + - experiment 2: [2/(18 - 3)] = [2/15] + * time to first tokens + - experiment 1: [3 - 1, 4 - 2] = [2, 2] + - experiment 2: [7 - 5, 6 - 3] = [2, 3] + * inter token latencies + - experiment 1: [((8 - 1) - 2)/(3 - 1), ((11 - 2) - 2)/(3 - 1)] + : [2.5, 3.5] + : [2, 4] # rounded + - experiment 2: [((18 - 5) - 2)/(4 - 1), ((11 - 3) - 3)/(3 - 1)] + : [11/3, 2.5] + : [4, 2] # rounded + * output token throughputs per request + - experiment 1: [3/(8 - 1), 3/(11 - 2)] = [3/7, 1/3] + - experiment 2: [4/(18 - 5), 3/(11 - 3)] = [4/13, 3/8] + * output token throughputs + - experiment 1: [(3 + 3)/(11 - 1)] = [3/5] + - experiment 2: [(4 + 3)/(18 - 3)] = [7/15] + * output sequence lengths + - experiment 1: [3, 3] + - experiment 2: [4, 3] + * input sequence lengths + - experiment 1: [3, 4] + - experiment 2: [3, 4] + """ + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("tensorrtllm_engine_profile_export.json"), + tokenizer=tokenizer, + ) + + statistics = pd.get_statistics(infer_mode=infer_mode, load_level=load_level) + metrics = cast(LLMMetrics, statistics.metrics) + + expected_metrics = LLMMetrics(**expected_metrics) + expected_statistics = Statistics(expected_metrics) + + check_llm_metrics(metrics, expected_metrics) + check_statistics(statistics, expected_statistics) + + # check non-existing profile data + with pytest.raises(KeyError): + pd.get_statistics(infer_mode="concurrency", load_level="30") + + @patch( + "genai_perf.profile_data_parser.profile_data_parser.load_json", + return_value=openai_profile_data, + ) + def test_merged_sse_response(self, mock_json) -> None: + """Test merging the multiple sse response.""" + res_timestamps = [0, 1, 2, 3] + res_outputs = [ + { + "response": 'data: {"choices":[{"delta":{"content":"aaa"}}],"object":"chat.completion.chunk"}\n\n' + }, + { + "response": ( + 'data: {"choices":[{"delta":{"content":"abc"}}],"object":"chat.completion.chunk"}\n\n' + 'data: {"choices":[{"delta":{"content":"1234"}}],"object":"chat.completion.chunk"}\n\n' + 'data: {"choices":[{"delta":{"content":"helloworld"}}],"object":"chat.completion.chunk"}\n\n' + ) + }, + {"response": "data: [DONE]\n\n"}, + ] + expected_response = '{"choices": [{"delta": {"content": "abc1234helloworld"}}], "object": "chat.completion.chunk"}' + + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("openai_profile_export.json"), + tokenizer=tokenizer, + ) + + pd._preprocess_response(res_timestamps, res_outputs) + assert res_outputs[1]["response"] == expected_response + + @patch( + "genai_perf.profile_data_parser.profile_data_parser.load_json", + return_value=openai_profile_data, + ) + def test_openai_output_token_counts(self, mock_json) -> None: + output_texts = [ + "Ad", + "idas", + " Orig", + "inals", + " are", + " now", + " 
available", + " in", + " more", + " than", + ] + res_outputs = [] + for text in output_texts: + response = f'data: {{"choices":[{{"delta":{{"content":"{text}"}}}}],"object":"chat.completion.chunk"}}\n\n' + res_outputs.append({"response": response}) + + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("openai_profile_export.json"), + tokenizer=tokenizer, + ) + + output_token_counts, total_output_token = pd._get_output_token_counts( + res_outputs + ) + assert output_token_counts == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] # total 10 + assert total_output_token == 9 + assert total_output_token != sum(output_token_counts) + + @patch( + "genai_perf.profile_data_parser.profile_data_parser.load_json", + return_value=triton_profile_data, + ) + def test_triton_output_token_counts(self, mock_json) -> None: + output_texts = [ + "Ad", + "idas", + " Orig", + "inals", + " are", + " now", + " available", + " in", + " more", + " than", + ] + res_outputs = [] + for text in output_texts: + res_outputs.append({"text_output": text}) + + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("triton_profile_export.json"), + tokenizer=tokenizer, + ) + + output_token_counts, total_output_token = pd._get_output_token_counts( + res_outputs + ) + assert output_token_counts == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] # total 10 + assert total_output_token == 9 + assert total_output_token != sum(output_token_counts) + + empty_profile_data = { + "service_kind": "openai", + "endpoint": "v1/chat/completions", + "experiments": [ + { + "experiment": { + "mode": "concurrency", + "value": 10, + }, + "requests": [ + { + "timestamp": 1, + "request_inputs": { + "payload": '{"messages":[{"role":"user","content":[{"type":"text","text":"This is test"}]}],"model":"llama-2-7b","stream":true}', + }, + "response_timestamps": [3, 5, 8], + "response_outputs": [ + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","created":123,"model":"llama-2-7b","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","created":123,"model":"llama-2-7b","choices":[{"index":0,"delta":{"content":""},"finish_reason":null}]}\n\n' + }, + {"response": "data: [DONE]\n\n"}, + ], + }, + ], + }, + ], + } + + @patch( + "genai_perf.profile_data_parser.profile_data_parser.load_json", + return_value=empty_profile_data, + ) + def test_empty_response(self, mock_json) -> None: + """Check if it handles all empty responses.""" + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + + # Should not throw error + _ = LLMProfileDataParser( + filename=Path("empty_profile_export.json"), + tokenizer=tokenizer, + ) + + @patch( + "genai_perf.profile_data_parser.profile_data_parser.load_json", + return_value=openai_profile_data, + ) + def test_unfinished_responses(self, mock_json) -> None: + """Check if it handles unfinished responses.""" + res_timestamps = [0, 1, 2] + res_outputs = [ + { + "response": 'data: {"id":"8ae835f2ecbb67f3-SJC","object":"chat.completion.chunk","created":1722875835,"choices":[{"index":0,"text"' + }, + { + "response": ':" writing","logprobs":null,"finish_reason":null,"seed":null,"delta":{"token_id":4477,"role":"assistant","content":" writing","tool_calls":null}}],"model":"meta-llama/Llama-3-8b-chat-hf","usage":null}' + }, + {"response": "data: [DONE]\n\n"}, + ] + expected_response = 'data: 
{"id":"8ae835f2ecbb67f3-SJC","object":"chat.completion.chunk","created":1722875835,"choices":[{"index":0,"text":" writing","logprobs":null,"finish_reason":null,"seed":null,"delta":{"token_id":4477,"role":"assistant","content":" writing","tool_calls":null}}],"model":"meta-llama/Llama-3-8b-chat-hf","usage":null}' + + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("openai_profile_export.json"), + tokenizer=tokenizer, + ) + + pd._preprocess_response(res_timestamps, res_outputs) + assert res_outputs[0]["response"] == expected_response + + @patch( + "genai_perf.profile_data_parser.profile_data_parser.load_json", + return_value=openai_profile_data, + ) + def test_non_sse_response(self, mock_json) -> None: + """Check if it handles single responses.""" + res_timestamps = [ + 0, + ] + res_outputs = [ + { + "response": '{"id":"1","object":"chat.completion","created":2,"model":"gpt2","choices":[{"index":0,"message":{"role":"assistant","content":"A friend of mine, who is also a cook, writes a blog.","tool_calls":[]},"logprobs":null,"finish_reason":"length","stop_reason":null}],"usage":{"prompt_tokens":47,"total_tokens":1024,"completion_tokens":977}}' + }, + ] + expected_response = '{"id":"1","object":"chat.completion","created":2,"model":"gpt2","choices":[{"index":0,"message":{"role":"assistant","content":"A friend of mine, who is also a cook, writes a blog.","tool_calls":[]},"logprobs":null,"finish_reason":"length","stop_reason":null}],"usage":{"prompt_tokens":47,"total_tokens":1024,"completion_tokens":977}}' + + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("openai_profile_export.json"), + tokenizer=tokenizer, + ) + + pd._preprocess_response(res_timestamps, res_outputs) + assert res_outputs[0]["response"] == expected_response diff --git a/genai-perf/tests/test_utils.py b/genai-perf/tests/test_utils.py new file mode 100644 index 00000000..29982387 --- /dev/null +++ b/genai-perf/tests/test_utils.py @@ -0,0 +1,44 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from typing import Union
+
+import pytest
+from genai_perf.metrics.statistics import Statistics
+
+
+def ns_to_sec(ns: int) -> Union[int, float]:
+    """Convert nanoseconds to seconds."""
+    return ns / 1e9
+
+
+def check_statistics(s1: Statistics, s2: Statistics) -> None:
+    s1_dict = s1.stats_dict
+    s2_dict = s2.stats_dict
+    for metric in s1_dict.keys():
+        for stat_name, value in s1_dict[metric].items():
+            if stat_name != "unit":
+                assert s2_dict[metric][stat_name] == pytest.approx(value)
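
A minimal usage sketch of the new tests/test_utils.py helpers, mirroring the pattern the parser tests above follow. This is illustrative only and not part of the patch: the `from tests.test_utils import ...` path is an assumption about how the test package is laid out, and the keyword arguments are the same ones the tests pass to LLMMetrics, so adjust both to your environment before running.

    from genai_perf.metrics import LLMMetrics
    from genai_perf.metrics.statistics import Statistics

    # Assumed import path for the helpers added in this patch.
    from tests.test_utils import check_statistics, ns_to_sec

    # Timestamps in these tests are in nanoseconds, so throughputs are
    # expressed per second via ns_to_sec, matching the expected_metrics
    # dicts used in the parametrized test cases above.
    metric_kwargs = dict(
        request_latencies=[7, 9],
        request_throughputs=[1 / ns_to_sec(5)],
        time_to_first_tokens=[2, 2],
        inter_token_latencies=[2, 4],
        output_token_throughputs_per_request=[3 / ns_to_sec(7), 1 / ns_to_sec(3)],
        output_token_throughputs=[3 / ns_to_sec(5)],
        output_sequence_lengths=[3, 3],
        input_sequence_lengths=[3, 4],
    )

    # check_statistics compares every derived stat except "unit" with
    # pytest.approx, so two Statistics objects built from identical
    # metrics must compare equal.
    expected = Statistics(LLMMetrics(**metric_kwargs))
    observed = Statistics(LLMMetrics(**metric_kwargs))
    check_statistics(expected, observed)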