Skip to content

Commit

Permalink
Add profile data parser for image retriever models (#43)
Browse files Browse the repository at this point in the history
* add profile data parser for image retriever

* add checks for metrics
  • Loading branch information
nv-hwoo authored Aug 20, 2024
1 parent 2a375f5 commit 5779fb7
Show file tree
Hide file tree
Showing 11 changed files with 901 additions and 655 deletions.
2 changes: 2 additions & 0 deletions genai-perf/genai_perf/export_data/console_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def _get_title(self):
return "Embeddings Metrics"
elif self._args.endpoint_type == "rankings":
return "Rankings Metrics"
elif self._args.endpoint_type == "image_retrieval":
return "Image Retrieval Metrics"
else:
return "LLM Metrics"

Expand Down
8 changes: 7 additions & 1 deletion genai-perf/genai_perf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
from genai_perf.llm_inputs.llm_inputs import LlmInputs
from genai_perf.plots.plot_config_parser import PlotConfigParser
from genai_perf.plots.plot_manager import PlotManager
from genai_perf.profile_data_parser import LLMProfileDataParser, ProfileDataParser
from genai_perf.profile_data_parser import (
ImageRetrievalProfileDataParser,
LLMProfileDataParser,
ProfileDataParser,
)
from genai_perf.tokenizer import Tokenizer, get_tokenizer


Expand Down Expand Up @@ -95,6 +99,8 @@ def generate_inputs(args: Namespace, tokenizer: Tokenizer) -> None:
def calculate_metrics(args: Namespace, tokenizer: Tokenizer) -> ProfileDataParser:
if args.endpoint_type in ["embeddings", "rankings"]:
return ProfileDataParser(args.profile_export_file)
elif args.endpoint_type == "image_retrieval":
return ImageRetrievalProfileDataParser(args.profile_export_file)
else:
return LLMProfileDataParser(
filename=args.profile_export_file,
Expand Down
1 change: 1 addition & 0 deletions genai-perf/genai_perf/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from genai_perf.metrics.image_retrieval_metrics import ImageRetrievalMetrics
from genai_perf.metrics.llm_metrics import LLMMetrics
from genai_perf.metrics.metrics import MetricMetadata, Metrics
from genai_perf.metrics.statistics import Statistics
Expand Down
60 changes: 60 additions & 0 deletions genai-perf/genai_perf/metrics/image_retrieval_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import List, Optional

from genai_perf.metrics.metrics import MetricMetadata, Metrics


class ImageRetrievalMetrics(Metrics):
    """Container for the core Image Retrieval performance metrics.

    Extends the base request metrics with two image-level series: the image
    throughput and the per-image latency of each request.
    """

    # NOTE(review): genai_perf/metrics/statistics.py labels the
    # "image_throughput" statistic as "pages/sec" while the metadata below
    # says "images/sec" -- confirm which unit is intended.
    IMAGE_RETRIEVAL_REQUEST_METRICS = [
        MetricMetadata("image_throughput", "images/sec"),
        MetricMetadata("image_latency", "ms/image"),
    ]

    def __init__(
        self,
        request_throughputs: Optional[List[float]] = None,
        request_latencies: Optional[List[int]] = None,
        image_throughputs: Optional[List[float]] = None,
        image_latencies: Optional[List[float]] = None,
    ) -> None:
        """Store the metric series.

        Args:
            request_throughputs: Request throughput values.
            request_latencies: Latency of each request.
            image_throughputs: Image throughput of each request.
            image_latencies: Per-image latency of each request.

        Every argument defaults to a new empty list. ``None`` is used as the
        default instead of ``[]`` so that instances created without arguments
        never share (and accidentally mutate) the same list object.
        """
        super().__init__(
            [] if request_throughputs is None else request_throughputs,
            [] if request_latencies is None else request_latencies,
        )
        self.image_throughputs = (
            [] if image_throughputs is None else image_throughputs
        )
        self.image_latencies = [] if image_latencies is None else image_latencies

        # add base name mapping
        self._base_names["image_throughputs"] = "image_throughput"
        self._base_names["image_latencies"] = "image_latency"

    @property
    def request_metrics(self) -> List[MetricMetadata]:
        """Return the base request metrics followed by the image metrics."""
        base_metrics = super().request_metrics  # base metrics
        return base_metrics + self.IMAGE_RETRIEVAL_REQUEST_METRICS
3 changes: 3 additions & 0 deletions genai-perf/genai_perf/metrics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def _add_units(self, key) -> None:
self._stats_dict[key]["unit"] = "ms"
elif key == "request_throughput":
self._stats_dict[key]["unit"] = "requests/sec"
elif key == "image_throughput":
self._stats_dict[key]["unit"] = "pages/sec"
elif key.startswith("output_token_throughput"):
self._stats_dict[key]["unit"] = "tokens/sec"
elif "sequence_length" in key:
Expand Down Expand Up @@ -168,6 +170,7 @@ def _is_time_metric(self, field: str) -> bool:
"inter_token_latency",
"time_to_first_token",
"request_latency",
"image_latency",
]
return field in time_metrics

Expand Down
3 changes: 3 additions & 0 deletions genai-perf/genai_perf/profile_data_parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from genai_perf.profile_data_parser.image_retrieval_profile_data_parser import (
ImageRetrievalProfileDataParser,
)
from genai_perf.profile_data_parser.llm_profile_data_parser import LLMProfileDataParser
from genai_perf.profile_data_parser.profile_data_parser import (
ProfileDataParser,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#!/usr/bin/env python3

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from pathlib import Path

from genai_perf.metrics import ImageRetrievalMetrics
from genai_perf.profile_data_parser.profile_data_parser import ProfileDataParser
from genai_perf.utils import load_json_str


class ImageRetrievalProfileDataParser(ProfileDataParser):
    """Aggregate the Image Retrieval performance statistics found in the
    Perf Analyzer profile results.
    """

    def __init__(self, filename: Path) -> None:
        super().__init__(filename)

    def _parse_requests(self, requests: dict) -> ImageRetrievalMetrics:
        """Extract the core metrics from every request in the profile data.

        Each request must provide "timestamp", "response_timestamps" and
        "request_inputs"; the JSON "payload" of the inputs is decoded and the
        "image_url" entries of the first message's content are counted.
        Timestamps are handled as nanoseconds.

        NOTE: a request without any "image_url" entry, or one whose last
        response timestamp equals its request timestamp, raises
        ZeroDivisionError.
        """
        earliest_request = float("inf")
        latest_response = 0
        latencies_ns = []
        throughputs_per_request = []
        latencies_per_image = []

        for entry in requests:
            sent_at = entry["timestamp"]
            response_times = entry["response_timestamps"]
            inputs = entry["request_inputs"]

            # widen the window covering the whole benchmark
            earliest_request = min(earliest_request, sent_at)
            final_response_at = response_times[-1]
            latest_response = max(latest_response, final_response_at)

            # time between sending the request and its final response
            elapsed_ns = final_response_at - sent_at
            latencies_ns.append(elapsed_ns)

            # count the images carried by the first message of the payload
            message_parts = load_json_str(inputs["payload"])["messages"][0]["content"]
            image_count = sum(
                1 for part in message_parts if part["type"] == "image_url"
            )

            # images per second for this request
            elapsed_s = elapsed_ns / 1e9
            throughputs_per_request.append(image_count / elapsed_s)

            # time spent per image for this request
            latencies_per_image.append(elapsed_ns / image_count)

        # overall request throughput across the benchmark window (in seconds)
        duration_s = (latest_response - earliest_request) / 1e9

        return ImageRetrievalMetrics(
            [len(requests) / duration_s],
            latencies_ns,
            throughputs_per_request,
            latencies_per_image,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ResponseFormat(Enum):
OPENAI_EMBEDDINGS = auto()
OPENAI_VISION = auto()
RANKINGS = auto()
IMAGE_RETRIEVAL = auto()
TRITON = auto()


Expand Down Expand Up @@ -75,6 +76,8 @@ def _get_profile_metadata(self, data: dict) -> None:
self._response_format = ResponseFormat.OPENAI_EMBEDDINGS
elif data["endpoint"] == "v1/ranking":
self._response_format = ResponseFormat.RANKINGS
elif data["endpoint"] == "v1/infer":
self._response_format = ResponseFormat.IMAGE_RETRIEVAL
else:
# (TPA-66) add PA metadata to handle this case
# When endpoint field is either empty or custom endpoint, fall
Expand All @@ -93,6 +96,8 @@ def _get_profile_metadata(self, data: dict) -> None:
self._response_format = ResponseFormat.OPENAI_EMBEDDINGS
elif "ranking" in response:
self._response_format = ResponseFormat.RANKINGS
elif "image_retrieval" in response:
self._response_format = ResponseFormat.IMAGE_RETRIEVAL
else:
raise RuntimeError("Unknown OpenAI response format.")

Expand Down
140 changes: 140 additions & 0 deletions genai-perf/tests/test_image_retrieval_profile_data_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
from pathlib import Path
from typing import cast
from unittest.mock import mock_open, patch

import pytest
from genai_perf.metrics import ImageRetrievalMetrics, Statistics
from genai_perf.profile_data_parser import ImageRetrievalProfileDataParser

from .test_utils import check_statistics, ns_to_sec


def check_image_retrieval_metrics(
    m1: ImageRetrievalMetrics, m2: ImageRetrievalMetrics
) -> None:
    """Assert that two metrics objects hold matching series.

    Latency series are compared exactly; throughput series are compared
    through ``pytest.approx``.
    """
    # (attribute name, compare approximately?) -- checked in this order
    comparisons = (
        ("request_latencies", False),
        ("request_throughputs", True),
        ("image_latencies", False),
        ("image_throughputs", True),
    )
    for attr, approximate in comparisons:
        actual = getattr(m1, attr)
        expected = getattr(m2, attr)
        if approximate:
            assert actual == pytest.approx(expected)
        else:
            assert actual == expected


class TestImageRetrievalProfileDataParser:
    """Tests for ImageRetrievalProfileDataParser on a mocked profile export."""

    # Two requests sent under concurrency 10: the first carries two images
    # (timestamps 1 -> 3), the second carries three images (timestamps 3 -> 7).
    image_retrieval_profile_data = {
        "experiments": [
            {
                "experiment": {"mode": "concurrency", "value": 10},
                "requests": [
                    {
                        "timestamp": 1,
                        "request_inputs": {
                            "payload": '{"messages":[{"role":"user","content":[{"type":"image_url","image_url":{"url":"image1"}},{"type":"image_url","image_url":{"url":"image2"}}]}],"model":"yolox"}'
                        },
                        "response_timestamps": [3],
                        "response_outputs": [
                            {
                                "response": '{"object":"list","data":[],"model":"yolox","usage":null}'
                            }
                        ],
                    },
                    {
                        "timestamp": 3,
                        "request_inputs": {
                            "payload": '{"messages":[{"role":"user","content":[{"type":"image_url","image_url":{"url":"image1"}},{"type":"image_url","image_url":{"url":"image2"}},{"type":"image_url","image_url":{"url":"image3"}}]}],"model":"yolox"}'
                        },
                        "response_timestamps": [7],
                        "response_outputs": [
                            {
                                "response": '{"object":"list","data":[],"model":"yolox","usage":null}'
                            }
                        ],
                    },
                ],
            }
        ],
        "version": "",
        "service_kind": "openai",
        "endpoint": "v1/infer",
    }

    @patch("pathlib.Path.exists", return_value=True)
    @patch(
        "builtins.open",
        new_callable=mock_open,
        read_data=json.dumps(image_retrieval_profile_data),
    )
    @pytest.mark.parametrize(
        "infer_mode, load_level, expected_metrics",
        [
            (
                "concurrency",
                "10",
                {
                    "request_throughputs": [1 / ns_to_sec(3)],
                    "request_latencies": [2, 4],
                    "image_throughputs": [1 / ns_to_sec(1), 3 / ns_to_sec(4)],
                    "image_latencies": [1, 4 / 3],
                },
            ),
        ],
    )
    def test_image_retrieval_profile_data(
        self,
        mock_file,
        mock_exists,
        infer_mode,
        load_level,
        expected_metrics,
    ) -> None:
        """Collect image retrieval metrics from profile export data and check values.

        Metrics
        * request throughputs
            - [2 / (7 - 1)] = [1 / ns_to_sec(3)]
        * request latencies
            - [3 - 1, 7 - 3] = [2, 4]
        * image throughputs
            - [2 / (3 - 1), 3 / (7 - 3)] = [1 / ns_to_sec(1), 3 / ns_to_sec(4)]
        * image latencies
            - [(3 - 1) / 2, (7 - 3) / 3] = [1, 4/3]
        """
        # Stacked patch decorators inject their mocks bottom-up, so the
        # "builtins.open" mock arrives first and the Path.exists mock second.
        parser = ImageRetrievalProfileDataParser(
            filename=Path("image_retrieval_profile_export.json")
        )

        # Use the parametrized values so that additional parameter sets are
        # actually exercised instead of always querying concurrency 10.
        statistics = parser.get_statistics(
            infer_mode=infer_mode, load_level=load_level
        )
        metrics = cast(ImageRetrievalMetrics, statistics.metrics)

        expected = ImageRetrievalMetrics(**expected_metrics)
        expected_statistics = Statistics(expected)

        check_image_retrieval_metrics(metrics, expected)
        check_statistics(statistics, expected_statistics)
Loading

0 comments on commit 5779fb7

Please sign in to comment.