diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/dataset_retriever.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/dataset_retriever.py index 7322b4698..804365e1f 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/dataset_retriever.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/dataset_retriever.py @@ -28,10 +28,13 @@ from typing import Any, Dict, List import requests +from genai_perf import utils from genai_perf.exceptions import GenAIPerfException +from genai_perf.llm_inputs.inputs_utils import ImageFormat, OutputFormat +from genai_perf.llm_inputs.synthetic_image_generator import SyntheticImageGenerator from genai_perf.llm_inputs.synthetic_prompt_generator import SyntheticPromptGenerator from genai_perf.tokenizer import Tokenizer -from genai_perf.utils import load_json_str +from PIL import Image class DatasetRetriever: @@ -56,26 +59,58 @@ def from_url(url: str, starting_index: int, length: int) -> List[Dict[str, Any]] ] return formatted_rows + # (TMA-2018) decouple output_format from this method @staticmethod - def from_file(file_path: Path) -> List[Dict[str, str]]: - with open(file_path, "r") as file: - data = [load_json_str(line) for line in file] + def from_file(file_path: Path, output_format: OutputFormat) -> List[Dict[str, str]]: + contents = DatasetRetriever._load_file_content(file_path) + + dataset = [] + for content in contents: + data = {"text_input": content.get("text_input", "")} + + if output_format == OutputFormat.OPENAI_VISION: + img_filename = content.get("image", "") + encoded_img = DatasetRetriever._read_image_content(img_filename) + data["image"] = encoded_img + + dataset.append(data) + return dataset - for item in data: - if not isinstance(item, dict): + @staticmethod + def _load_file_content(file_path: Path) -> List[Dict[str, str]]: + contents = [] + with open(file_path, "r") as file: + for line in file: + content = utils.load_json_str(line) + if not isinstance(content, dict): raise GenAIPerfException( "File content is not in the expected format." ) - if "text_input" not in item: - raise GenAIPerfException( - f"Missing 'text_input' field in file item: {item}" - ) - if len(item) != 1: + if "text_input" not in content: raise GenAIPerfException( - f"Field other than 'text_input' field found in file item: {item}" + f"Missing 'text_input' field in file content: {content}" ) + contents.append(content) - return [{"text_input": item["text_input"]} for item in data] + return contents + + @staticmethod + def _read_image_content(filename: str) -> str: + try: + img = Image.open(filename) + except: + raise GenAIPerfException( + f"Error occurred while opening an image file: {filename}" + ) + + if img.format.lower() not in utils.get_enum_names(ImageFormat): + raise GenAIPerfException( + f"Unsupported image format '{img.format}' of " + f"the image '{filename}'." + ) + + img_base64 = utils.encode_image(img, img.format) + return f"data:image/{img.format.lower()};base64,{img_base64}" @staticmethod def from_directory(directory_path: Path) -> Dict: @@ -89,7 +124,7 @@ def from_directory(directory_path: Path) -> Dict: # Get the file name without suffix key = file_path.stem with open(file_path, "r") as file: - data[key] = [load_json_str(line) for line in file] + data[key] = [utils.load_json_str(line) for line in file] # Create rows with keys based on file names without suffix num_entries = len(next(iter(data.values()))) @@ -105,11 +140,29 @@ def from_synthetic( prompt_tokens_mean: int, prompt_tokens_stddev: int, num_of_output_prompts: int, + image_width_mean: int, + image_width_stddev: int, + image_height_mean: int, + image_height_stddev: int, + image_format: ImageFormat, + output_format: OutputFormat, ) -> List[Dict[str, str]]: - synthetic_prompts = [] + synthetic_dataset = [] for _ in range(num_of_output_prompts): - synthetic_prompt = SyntheticPromptGenerator.create_synthetic_prompt( + prompt = SyntheticPromptGenerator.create_synthetic_prompt( tokenizer, prompt_tokens_mean, prompt_tokens_stddev ) - synthetic_prompts.append({"text_input": synthetic_prompt}) - return synthetic_prompts + data = {"text_input": prompt} + + if output_format == OutputFormat.OPENAI_VISION: + image = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=image_width_mean, + image_width_stddev=image_width_stddev, + image_height_mean=image_height_mean, + image_height_stddev=image_height_stddev, + image_format=image_format, + ) + data["image"] = image + + synthetic_dataset.append(data) + return synthetic_dataset diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/inputs_utils.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/inputs_utils.py index 4b7401e0e..7656f0654 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/inputs_utils.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/inputs_utils.py @@ -45,6 +45,7 @@ class OutputFormat(Enum): OPENAI_CHAT_COMPLETIONS = auto() OPENAI_COMPLETIONS = auto() OPENAI_EMBEDDINGS = auto() + OPENAI_VISION = auto() RANKINGS = auto() TENSORRTLLM = auto() VLLM = auto() @@ -53,6 +54,11 @@ def to_lowercase(self): return self.name.lower() +class ImageFormat(Enum): + PNG = auto() + JPEG = auto() + + DEFAULT_STARTING_INDEX = 0 DEFAULT_LENGTH = 100 DEFAULT_TENSORRTLLM_MAX_TOKENS = 256 @@ -63,3 +69,9 @@ def to_lowercase(self): DEFAULT_OUTPUT_TOKENS_MEAN = -1 DEFAULT_OUTPUT_TOKENS_STDDEV = 0 DEFAULT_NUM_PROMPTS = 100 + +# Images +DEFAULT_IMAGE_WIDTH_MEAN = 100 +DEFAULT_IMAGE_WIDTH_STDDEV = 0 +DEFAULT_IMAGE_HEIGHT_MEAN = 100 +DEFAULT_IMAGE_HEIGHT_STDDEV = 0 diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/json_converter.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/json_converter.py index 15d06912d..4b93a36c5 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/json_converter.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/json_converter.py @@ -41,6 +41,7 @@ def to_generic(dataset: List[Dict[str, Any]]) -> Dict: for item in dataset: row_data = { "text_input": item.get("text_input", ""), + "image": item.get("image", ""), "system_prompt": item.get("system_prompt", ""), "response": item.get("response", ""), } diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py index 009a079b3..48e8afdde 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py @@ -32,6 +32,10 @@ from genai_perf.exceptions import GenAIPerfException from genai_perf.llm_inputs.dataset_retriever import DatasetRetriever from genai_perf.llm_inputs.inputs_utils import ( + DEFAULT_IMAGE_HEIGHT_MEAN, + DEFAULT_IMAGE_HEIGHT_STDDEV, + DEFAULT_IMAGE_WIDTH_MEAN, + DEFAULT_IMAGE_WIDTH_STDDEV, DEFAULT_LENGTH, DEFAULT_NUM_PROMPTS, DEFAULT_OUTPUT_TOKENS_MEAN, @@ -40,6 +44,7 @@ DEFAULT_PROMPT_TOKENS_STDDEV, DEFAULT_RANDOM_SEED, DEFAULT_STARTING_INDEX, + ImageFormat, ModelSelectionStrategy, OutputFormat, PromptSource, @@ -76,6 +81,11 @@ def create_llm_inputs( output_tokens_deterministic: bool = False, prompt_tokens_mean: int = DEFAULT_PROMPT_TOKENS_MEAN, prompt_tokens_stddev: int = DEFAULT_PROMPT_TOKENS_STDDEV, + image_width_mean: int = DEFAULT_IMAGE_WIDTH_MEAN, + image_width_stddev: int = DEFAULT_IMAGE_WIDTH_STDDEV, + image_height_mean: int = DEFAULT_IMAGE_HEIGHT_MEAN, + image_height_stddev: int = DEFAULT_IMAGE_HEIGHT_STDDEV, + image_format: ImageFormat = ImageFormat.PNG, random_seed: int = DEFAULT_RANDOM_SEED, num_of_output_prompts: int = DEFAULT_NUM_PROMPTS, add_model_name: bool = False, @@ -101,6 +111,12 @@ def create_llm_inputs( prompt_tokens_mean, prompt_tokens_stddev, num_of_output_prompts, + image_width_mean, + image_width_stddev, + image_height_mean, + image_height_stddev, + image_format, + output_format, ) elif input_type == PromptSource.FILE: input_filename = cast(Path, input_filename) @@ -108,7 +124,7 @@ def create_llm_inputs( # if output_format == OutputFormat.RANKINGS: # dataset = DatasetRetriever.from_directory(input_filename) # else: - dataset = DatasetRetriever.from_file(input_filename) + dataset = DatasetRetriever.from_file(input_filename, output_format) else: raise GenAIPerfException("Input source is not recognized.") @@ -147,6 +163,7 @@ def validate_args( PromptSource.DATASET, ], OutputFormat.RANKINGS: [PromptSource.DATASET, PromptSource.SYNTHETIC], + OutputFormat.OPENAI_VISION: [PromptSource.DATASET], } if input_type in unsupported_combinations.get(output_format, []): diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/output_format_converter.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/output_format_converter.py index 225e1a884..af540879d 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/output_format_converter.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/output_format_converter.py @@ -25,7 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import random -from typing import Dict, List +from typing import Any, Dict, List, Union from genai_perf.exceptions import GenAIPerfException from genai_perf.llm_inputs.inputs_utils import ( @@ -46,6 +46,7 @@ def create(output_format: OutputFormat): converters = { OutputFormat.OPENAI_CHAT_COMPLETIONS: OpenAIChatCompletionsConverter, OutputFormat.OPENAI_COMPLETIONS: OpenAICompletionsConverter, + OutputFormat.OPENAI_VISION: OpenAIChatCompletionsConverter, OutputFormat.OPENAI_EMBEDDINGS: OpenAIEmbeddingsConverter, OutputFormat.RANKINGS: RankingsConverter, OutputFormat.VLLM: VLLMConverter, @@ -105,8 +106,8 @@ def convert( for index, row in enumerate(generic_dataset["rows"]): model = self._select_model_name(model_name, index, model_selection_strategy) - text_content = row["row"]["text_input"] - messages = [{"role": "user", "content": text_content}] + content = self._generate_content(data=row["row"]) + messages = [{"role": "user", "content": content}] payload: Dict = {"messages": messages} if add_model_name: @@ -123,6 +124,28 @@ def convert( return pa_json + def _generate_content( + self, data: Dict[str, str] + ) -> Union[str, List[Dict[str, Any]]]: + """ + Generate either text only or multi-modal content for OpenAI Chat Completions API. + """ + content: str | List[Dict[str, Any]] = data["text_input"] + + # convert into multi-modal content format when image exists + if data["image"]: + content = [ + { + "type": "text", + "text": data["text_input"], + }, + { + "type": "image_url", + "image_url": {"url": data["image"]}, + }, + ] + return content + class OpenAICompletionsConverter(BaseConverter): def convert( diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/dlss.png b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/dlss.png new file mode 100644 index 000000000..cdba23dd3 Binary files /dev/null and b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/dlss.png differ diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h100.jpeg b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h100.jpeg new file mode 100644 index 000000000..aee985fdc Binary files /dev/null and b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h100.jpeg differ diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h200.jpeg b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h200.jpeg new file mode 100644 index 000000000..eb0633b27 Binary files /dev/null and b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h200.jpeg differ diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/jensen.jpeg b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/jensen.jpeg new file mode 100644 index 000000000..c9c831680 Binary files /dev/null and b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/jensen.jpeg differ diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/synthetic_image_generator.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/synthetic_image_generator.py new file mode 100644 index 000000000..a8335ba06 --- /dev/null +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/synthetic_image_generator.py @@ -0,0 +1,77 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import glob +import random +from pathlib import Path +from typing import Optional + +from genai_perf import utils +from genai_perf.llm_inputs.inputs_utils import ImageFormat +from PIL import Image + + +class SyntheticImageGenerator: + """A simple synthetic image generator that generates multiple synthetic + images from the source images. + """ + + @classmethod + def create_synthetic_image( + cls, + image_width_mean: int, + image_width_stddev: int, + image_height_mean: int, + image_height_stddev: int, + image_format: Optional[ImageFormat] = None, + ) -> str: + """Generate base64 encoded synthetic image using the source images.""" + if image_format is None: + image_format = random.choice(list(ImageFormat)) + width = cls._sample_random_positive_integer( + image_width_mean, image_width_stddev + ) + height = cls._sample_random_positive_integer( + image_height_mean, image_height_stddev + ) + + image = cls._sample_source_image() + image = image.resize(size=(width, height)) + + img_base64 = utils.encode_image(image, image_format.name) + return f"data:image/{image_format.name.lower()};base64,{img_base64}" + + @classmethod + def _sample_source_image(cls): + """Sample one image among the source images.""" + filepath = Path(__file__).parent.resolve() / "source_images" / "*" + filenames = glob.glob(str(filepath)) + return Image.open(random.choice(filenames)) + + @classmethod + def _sample_random_positive_integer(cls, mean: int, stddev: int) -> int: + n = int(abs(random.gauss(mean, stddev))) + return n if n != 0 else 1 # avoid zero diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/main.py b/src/c++/perf_analyzer/genai-perf/genai_perf/main.py index c8880aa29..59e483064 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/main.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/main.py @@ -77,6 +77,11 @@ def generate_inputs(args: Namespace, tokenizer: Tokenizer) -> None: output_tokens_mean=args.output_tokens_mean, output_tokens_stddev=args.output_tokens_stddev, output_tokens_deterministic=args.output_tokens_mean_deterministic, + image_width_mean=args.image_width_mean, + image_width_stddev=args.image_width_stddev, + image_height_mean=args.image_height_mean, + image_height_stddev=args.image_height_stddev, + image_format=args.image_format, random_seed=args.random_seed, num_of_output_prompts=args.num_prompts, add_model_name=add_model_name, diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py index 4c33d3502..f8990afae 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py @@ -42,12 +42,17 @@ ) from genai_perf.llm_inputs.inputs_utils import ( DEFAULT_BATCH_SIZE, + DEFAULT_IMAGE_HEIGHT_MEAN, + DEFAULT_IMAGE_HEIGHT_STDDEV, + DEFAULT_IMAGE_WIDTH_MEAN, + DEFAULT_IMAGE_WIDTH_STDDEV, DEFAULT_NUM_PROMPTS, DEFAULT_OUTPUT_TOKENS_MEAN, DEFAULT_OUTPUT_TOKENS_STDDEV, DEFAULT_PROMPT_TOKENS_MEAN, DEFAULT_PROMPT_TOKENS_STDDEV, DEFAULT_RANDOM_SEED, + ImageFormat, ModelSelectionStrategy, OutputFormat, PromptSource, @@ -80,6 +85,7 @@ def to_lowercase(self): _endpoint_type_map = { "chat": "v1/chat/completions", "completions": "v1/completions", + "vision": "v1/chat/completions", "embeddings": "v1/embeddings", "rankings": "v1/ranking", } @@ -121,6 +127,25 @@ def _check_compare_args( return args +def _check_image_input_args( + parser: argparse.ArgumentParser, args: argparse.Namespace +) -> argparse.Namespace: + """ + Sanity check the image input args + """ + if args.image_width_mean <= 0 or args.image_height_mean <= 0: + parser.error( + "Both --image-width-mean and --image-height-mean values must be positive." + ) + if args.image_width_stddev < 0 or args.image_height_stddev < 0: + parser.error( + "Both --image-width-stddev and --image-height-stddev values must be non-negative." + ) + + args = _convert_str_to_enum_entry(args, "image_format", ImageFormat) + return args + + def _check_conditional_args( parser: argparse.ArgumentParser, args: argparse.Namespace ) -> argparse.Namespace: @@ -143,6 +168,10 @@ def _check_conditional_args( args.output_format = OutputFormat.OPENAI_EMBEDDINGS elif args.endpoint_type == "rankings": args.output_format = OutputFormat.RANKINGS + # (TMA-1986) deduce vision format from chat completions + image CLI + # because there's no openai vision endpoint. + elif args.endpoint_type == "vision": + args.output_format = OutputFormat.OPENAI_VISION if args.endpoint is not None: args.endpoint = args.endpoint.lstrip(" /") @@ -417,6 +446,51 @@ def _add_input_args(parser): ) +def _add_image_input_args(parser): + input_group = parser.add_argument_group("Image Input") + + input_group.add_argument( + "--image-width-mean", + type=int, + default=DEFAULT_IMAGE_WIDTH_MEAN, + required=False, + help=f"The mean width of images when generating synthetic image data.", + ) + + input_group.add_argument( + "--image-width-stddev", + type=int, + default=DEFAULT_IMAGE_WIDTH_STDDEV, + required=False, + help=f"The standard deviation of width of images when generating synthetic image data.", + ) + + input_group.add_argument( + "--image-height-mean", + type=int, + default=DEFAULT_IMAGE_HEIGHT_MEAN, + required=False, + help=f"The mean height of images when generating synthetic image data.", + ) + + input_group.add_argument( + "--image-height-stddev", + type=int, + default=DEFAULT_IMAGE_HEIGHT_STDDEV, + required=False, + help=f"The standard deviation of height of images when generating synthetic image data.", + ) + + input_group.add_argument( + "--image-format", + type=str, + choices=utils.get_enum_names(ImageFormat), + required=False, + help=f"The compression format of the images. " + "If format is not selected, format of generated image is selected at random", + ) + + def _add_profile_args(parser): profile_group = parser.add_argument_group("Profiling") load_management_group = profile_group.add_mutually_exclusive_group(required=False) @@ -505,7 +579,7 @@ def _add_endpoint_args(parser): endpoint_group.add_argument( "--endpoint-type", type=str, - choices=["chat", "completions", "embeddings", "rankings"], + choices=["chat", "completions", "vision", "embeddings", "rankings"], required=False, help=f"The endpoint-type to send requests to on the " 'server. This is only used with the "openai" service-kind.', @@ -664,6 +738,7 @@ def _parse_profile_args(subparsers) -> argparse.ArgumentParser: ) _add_endpoint_args(profile) _add_input_args(profile) + _add_image_input_args(profile) _add_profile_args(profile) _add_output_args(profile) _add_other_args(profile) @@ -743,6 +818,7 @@ def refine_args( args = _infer_prompt_source(args) args = _check_model_args(parser, args) args = _check_conditional_args(parser, args) + args = _check_image_input_args(parser, args) args = _check_load_manager_args(args) args = _set_artifact_paths(args) elif args.subcommand == Subcommand.COMPARE.to_lowercase(): diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py index 4ec1bec62..183f21fd2 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py @@ -218,6 +218,9 @@ def _get_openai_input_text(self, req_inputs: dict) -> str: return payload["messages"][0]["content"] elif self._response_format == ResponseFormat.OPENAI_COMPLETIONS: return payload["prompt"] + elif self._response_format == ResponseFormat.OPENAI_VISION: + content = payload["messages"][0]["content"] + return " ".join(c["text"] for c in content if c["type"] == "text") else: raise ValueError( "Failed to parse OpenAI request input in profile export file." diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py index d18d8f6fb..74eb48a23 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py @@ -39,6 +39,7 @@ class ResponseFormat(Enum): OPENAI_CHAT_COMPLETIONS = auto() OPENAI_COMPLETIONS = auto() OPENAI_EMBEDDINGS = auto() + OPENAI_VISION = auto() RANKINGS = auto() TRITON = auto() @@ -59,7 +60,15 @@ def _get_profile_metadata(self, data: dict) -> None: if data["endpoint"] == "rerank": self._response_format = ResponseFormat.HUGGINGFACE_RANKINGS elif data["endpoint"] == "v1/chat/completions": - self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS + # (TPA-66) add PA metadata to deduce the response format instead + # of parsing the request input payload in profile export json + # file. + request = data["experiments"][0]["requests"][0] + request_input = request["request_inputs"]["payload"] + if "image_url" in request_input: + self._response_format = ResponseFormat.OPENAI_VISION + else: + self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS elif data["endpoint"] == "v1/completions": self._response_format = ResponseFormat.OPENAI_COMPLETIONS elif data["endpoint"] == "v1/embeddings": @@ -67,13 +76,17 @@ def _get_profile_metadata(self, data: dict) -> None: elif data["endpoint"] == "v1/ranking": self._response_format = ResponseFormat.RANKINGS else: - # TPA-66: add PA metadata to handle this case + # (TPA-66) add PA metadata to handle this case # When endpoint field is either empty or custom endpoint, fall # back to parsing the response to extract the response format. request = data["experiments"][0]["requests"][0] + request_input = request["request_inputs"]["payload"] response = request["response_outputs"][0]["response"] if "chat.completion" in response: - self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS + if "image_url" in request_input: + self._response_format = ResponseFormat.OPENAI_VISION + else: + self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS elif "text_completion" in response: self._response_format = ResponseFormat.OPENAI_COMPLETIONS elif "embedding" in response: diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py b/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py index 3bd7bccdd..e176058d0 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py @@ -34,10 +34,27 @@ # Skip type checking to avoid mypy error # Issue: https://github.com/python/mypy/issues/10632 import yaml # type: ignore +from PIL import Image logger = logging.getLogger(__name__) +def encode_image(img: Image, format: str): + """Encodes an image into base64 encoding.""" + # Lazy import for vision related endpoints + import base64 + from io import BytesIO + + # JPEG does not support P or RGBA mode (commonly used for PNG) so it needs + # to be converted to RGB before an image can be saved as JPEG format. + if format == "JPEG" and img.mode != "RGB": + img = img.convert("RGB") + + buffered = BytesIO() + img.save(buffered, format=format) + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + def remove_sse_prefix(msg: str) -> str: prefix = "data: " if msg.startswith(prefix): diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py b/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py index dbaacc32b..76ef3e321 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py @@ -93,6 +93,11 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s "synthetic_input_tokens_stddev", "subcommand", "tokenizer", + "image_width_mean", + "image_width_stddev", + "image_height_mean", + "image_height_stddev", + "image_format", ] utils.remove_file(args.profile_export_file) diff --git a/src/c++/perf_analyzer/genai-perf/pyproject.toml b/src/c++/perf_analyzer/genai-perf/pyproject.toml index 982ee24b7..68d5e3740 100644 --- a/src/c++/perf_analyzer/genai-perf/pyproject.toml +++ b/src/c++/perf_analyzer/genai-perf/pyproject.toml @@ -59,6 +59,7 @@ dependencies = [ "pytest-mock", "pyyaml", "responses", + "pillow", ] # CLI Entrypoint diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_cli.py b/src/c++/perf_analyzer/genai-perf/tests/test_cli.py index eb891fd02..41751c718 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_cli.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_cli.py @@ -30,6 +30,7 @@ import genai_perf.logging as logging import pytest from genai_perf import __version__, parser +from genai_perf.llm_inputs.inputs_utils import ImageFormat from genai_perf.llm_inputs.llm_inputs import ( ModelSelectionStrategy, OutputFormat, @@ -40,7 +41,7 @@ class TestCLIArguments: # ================================================ - # GENAI-PERF COMMAND + # PROFILE COMMAND # ================================================ expected_help_output = ( "CLI to profile LLMs and Generative AI models with Perf Analyzer" @@ -215,6 +216,23 @@ def test_help_version_arguments_output_and_exit( ["--synthetic-input-tokens-stddev", "7"], {"synthetic_input_tokens_stddev": 7}, ), + ( + ["--image-width-mean", "123"], + {"image_width_mean": 123}, + ), + ( + ["--image-width-stddev", "123"], + {"image_width_stddev": 123}, + ), + ( + ["--image-height-mean", "456"], + {"image_height_mean": 456}, + ), + ( + ["--image-height-stddev", "456"], + {"image_height_stddev": 456}, + ), + (["--image-format", "png"], {"image_format": ImageFormat.PNG}), (["-v"], {"verbose": True}), (["--verbose"], {"verbose": True}), (["-u", "test_url"], {"u": "test_url"}), @@ -732,6 +750,26 @@ def test_prompt_source_assertions(self, monkeypatch, mocker, capsys): captured = capsys.readouterr() assert expected_output in captured.err + @pytest.mark.parametrize( + "args", + [ + # negative numbers + ["--image-width-mean", "-123"], + ["--image-width-stddev", "-34"], + ["--image-height-mean", "-123"], + ["--image-height-stddev", "-34"], + # zeros + ["--image-width-mean", "0"], + ["--image-height-mean", "0"], + ], + ) + def test_positive_image_input_args(self, monkeypatch, args): + combined_args = ["genai-perf", "profile", "-m", "test_model"] + args + monkeypatch.setattr("sys.argv", combined_args) + + with pytest.raises(SystemExit) as excinfo: + parser.parse_args() + # ================================================ # COMPARE SUBCOMMAND # ================================================ diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_dataset_retriever.py b/src/c++/perf_analyzer/genai-perf/tests/test_dataset_retriever.py new file mode 100644 index 000000000..5011ae3f8 --- /dev/null +++ b/src/c++/perf_analyzer/genai-perf/tests/test_dataset_retriever.py @@ -0,0 +1,73 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from collections import namedtuple +from unittest.mock import patch + +from genai_perf.llm_inputs.dataset_retriever import DatasetRetriever +from genai_perf.llm_inputs.inputs_utils import ImageFormat, OutputFormat +from genai_perf.tokenizer import DEFAULT_TOKENIZER, get_tokenizer + + +class TestDatasetRetriever: + + @patch( + "genai_perf.llm_inputs.synthetic_prompt_generator.SyntheticPromptGenerator.create_synthetic_prompt", + side_effect=["prompt1", "prompt2", "prompt3"], + ) + @patch( + "genai_perf.llm_inputs.synthetic_image_generator.SyntheticImageGenerator.create_synthetic_image", + side_effect=["image1", "image2", "image3"], + ) + def test_from_synthetic_multi_modal(self, mock_prompts, mock_images): + Data = namedtuple("Data", ["text_input", "image"]) + expected_data = [ + Data(text_input="prompt1", image="image1"), + Data(text_input="prompt2", image="image2"), + Data(text_input="prompt3", image="image3"), + ] + num_prompts = 3 + + dataset = DatasetRetriever.from_synthetic( + tokenizer=get_tokenizer(DEFAULT_TOKENIZER), + prompt_tokens_mean=3, + prompt_tokens_stddev=0, + num_of_output_prompts=num_prompts, + image_width_mean=5, + image_width_stddev=0, + image_height_mean=5, + image_height_stddev=0, + image_format=ImageFormat.PNG, + output_format=OutputFormat.OPENAI_VISION, + ) + + assert len(dataset) == len(expected_data) + + for i, data in enumerate(expected_data): + assert dataset[i] == { + "text_input": data.text_input, + "image": data.image, + } diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py b/src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py index e4a29267d..f82e59312 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py @@ -249,6 +249,11 @@ def test_generate_json(self, monkeypatch) -> None: "random_seed": 0, "synthetic_input_tokens_mean": 550, "synthetic_input_tokens_stddev": 0, + "image_width_mean": 100, + "image_width_stddev": 0, + "image_height_mean": 100, + "image_height_stddev": 0, + "image_format": null, "concurrency": 1, "measurement_interval": 10000, "request_rate": null, diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py index bf351b5c2..c3b4d202c 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py @@ -25,6 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import random +from collections import namedtuple from pathlib import Path from unittest.mock import mock_open, patch @@ -36,6 +37,7 @@ from genai_perf.llm_inputs.inputs_utils import ( DEFAULT_LENGTH, DEFAULT_STARTING_INDEX, + ImageFormat, ModelSelectionStrategy, OutputFormat, PromptSource, @@ -89,6 +91,7 @@ class TestLlmInputs: ("triton", "tensorrtllm", OutputFormat.TENSORRTLLM), ("openai", "v1/completions", OutputFormat.OPENAI_COMPLETIONS), ("openai", "v1/chat/completions", OutputFormat.OPENAI_CHAT_COMPLETIONS), + ("openai", "v1/chat/completions", OutputFormat.OPENAI_VISION), ] @pytest.fixture @@ -219,12 +222,54 @@ def test_get_input_file_with_multiple_prompts(self, mock_file, mock_exists): for i, prompt in enumerate(expected_prompts): assert pa_json["data"][i]["payload"][0]["messages"][0]["content"] == prompt + @patch("pathlib.Path.exists", return_value=True) + @patch( + "genai_perf.llm_inputs.dataset_retriever.DatasetRetriever._read_image_content", + return_value="data:image/png;base64,...", + ) + @patch( + "builtins.open", + new_callable=mock_open, + read_data=( + '{"text_input": "prompt1", "image": "image1.png"}\n' + '{"text_input": "prompt2", "image": "image2.png"}\n' + '{"text_input": "prompt3", "image": "image3.png"}\n' + ), + ) + def test_get_input_file_with_multi_modal_data( + self, mock_exists, mock_image, mock_file + ): + Data = namedtuple("Data", ["text_input", "image"]) + expected_data = [ + Data(text_input="prompt1", image="data:image/png;base64,..."), + Data(text_input="prompt2", image="data:image/png;base64,..."), + Data(text_input="prompt3", image="data:image/png;base64,..."), + ] + + pa_json = LlmInputs.create_llm_inputs( + input_type=PromptSource.FILE, + input_filename=Path("somefile.txt"), + output_format=OutputFormat.OPENAI_VISION, + model_name=["test_model"], + image_format=ImageFormat.PNG, + ) + + assert pa_json is not None + assert len(pa_json["data"]) == len(expected_data) + for i, data in enumerate(expected_data): + content = pa_json["data"][i]["payload"][0]["messages"][0]["content"] + assert content[0]["type"] == "text" + assert content[0]["text"] == data.text_input + assert content[1]["type"] == "image_url" + assert content[1]["image_url"] == {"url": data.image} + @pytest.mark.parametrize( "input_type, output_format", [ (PromptSource.DATASET, OutputFormat.OPENAI_EMBEDDINGS), (PromptSource.DATASET, OutputFormat.VLLM), (PromptSource.DATASET, OutputFormat.RANKINGS), + (PromptSource.DATASET, OutputFormat.OPENAI_VISION), (PromptSource.SYNTHETIC, OutputFormat.OPENAI_EMBEDDINGS), (PromptSource.SYNTHETIC, OutputFormat.VLLM), (PromptSource.SYNTHETIC, OutputFormat.RANKINGS), diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py index 75976189d..844111ae1 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py @@ -71,6 +71,9 @@ def write(self: Any, content: str) -> int: elif filename == "openai_profile_export.json": tmp_file = StringIO(json.dumps(self.openai_profile_data)) return tmp_file + elif filename == "openai_vlm_profile_export.json": + tmp_file = StringIO(json.dumps(self.openai_vlm_profile_data)) + return tmp_file elif filename == "empty_profile_export.json": tmp_file = StringIO(json.dumps(self.empty_profile_data)) return tmp_file @@ -322,6 +325,90 @@ def test_openai_llm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> N with pytest.raises(KeyError): pd.get_statistics(infer_mode="concurrency", load_level="40") + def test_openai_vlm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> None: + """Collect LLM metrics from profile export data and check values. + Metrics + * time to first tokens + - experiment 1: [5 - 1, 7 - 2] = [4, 5] + * inter token latencies + - experiment 1: [((12 - 1) - 4)/(3 - 1), ((15 - 2) - 5)/(6 - 1)] + : [3.5, 1.6] + : [4, 2] # rounded + * output token throughputs per request + - experiment 1: [3/(12 - 1), 6/(15 - 2)] = [3/11, 6/13] + * output token throughputs + - experiment 1: [(3 + 6)/(15 - 1)] = [9/14] + * output sequence lengths + - experiment 1: [3, 6] + * input sequence lengths + - experiment 1: [3, 4] + """ + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("openai_vlm_profile_export.json"), + tokenizer=tokenizer, + ) + + # experiment 1 statistics + stat_obj = pd.get_statistics(infer_mode="concurrency", load_level="10") + metrics = stat_obj.metrics + stat = stat_obj.stats_dict + assert isinstance(metrics, LLMMetrics) + + assert metrics.time_to_first_tokens == [4, 5] + assert metrics.inter_token_latencies == [4, 2] + ottpr = [3 / ns_to_sec(11), 6 / ns_to_sec(13)] + assert metrics.output_token_throughputs_per_request == pytest.approx(ottpr) + ott = [9 / ns_to_sec(14)] + assert metrics.output_token_throughputs == pytest.approx(ott) + assert metrics.output_sequence_lengths == [3, 6] + assert metrics.input_sequence_lengths == [3, 4] + + assert stat["time_to_first_token"]["avg"] == pytest.approx(4.5) # type: ignore + assert stat["inter_token_latency"]["avg"] == pytest.approx(3) # type: ignore + assert stat["output_token_throughput_per_request"]["avg"] == pytest.approx( # type: ignore + np.mean(ottpr) + ) + assert stat["output_sequence_length"]["avg"] == 4.5 # type: ignore + assert stat["input_sequence_length"]["avg"] == 3.5 # type: ignore + + assert stat["time_to_first_token"]["p50"] == pytest.approx(4.5) # type: ignore + assert stat["inter_token_latency"]["p50"] == pytest.approx(3) # type: ignore + assert stat["output_token_throughput_per_request"]["p50"] == pytest.approx( # type: ignore + np.percentile(ottpr, 50) + ) + assert stat["output_sequence_length"]["p50"] == 4.5 # type: ignore + assert stat["input_sequence_length"]["p50"] == 3.5 # type: ignore + + assert stat["time_to_first_token"]["min"] == pytest.approx(4) # type: ignore + assert stat["inter_token_latency"]["min"] == pytest.approx(2) # type: ignore + min_ottpr = 3 / ns_to_sec(11) + assert stat["output_token_throughput_per_request"]["min"] == pytest.approx(min_ottpr) # type: ignore + assert stat["output_sequence_length"]["min"] == 3 # type: ignore + assert stat["input_sequence_length"]["min"] == 3 # type: ignore + + assert stat["time_to_first_token"]["max"] == pytest.approx(5) # type: ignore + assert stat["inter_token_latency"]["max"] == pytest.approx(4) # type: ignore + max_ottpr = 6 / ns_to_sec(13) + assert stat["output_token_throughput_per_request"]["max"] == pytest.approx(max_ottpr) # type: ignore + assert stat["output_sequence_length"]["max"] == 6 # type: ignore + assert stat["input_sequence_length"]["max"] == 4 # type: ignore + + assert stat["time_to_first_token"]["std"] == np.std([4, 5]) * (1) # type: ignore + assert stat["inter_token_latency"]["std"] == np.std([4, 2]) * (1) # type: ignore + assert stat["output_token_throughput_per_request"]["std"] == pytest.approx( # type: ignore + np.std(ottpr) + ) + assert stat["output_sequence_length"]["std"] == np.std([3, 6]) # type: ignore + assert stat["input_sequence_length"]["std"] == np.std([3, 4]) # type: ignore + + oott = 9 / ns_to_sec(14) + assert stat["output_token_throughput"]["avg"] == pytest.approx(oott) # type: ignore + + # check non-existing profile data + with pytest.raises(KeyError): + pd.get_statistics(infer_mode="concurrency", load_level="40") + def test_merged_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None: """Test merging the multiple sse response.""" res_timestamps = [0, 1, 2, 3] @@ -522,6 +609,73 @@ def test_empty_response(self, mock_read_write: pytest.MonkeyPatch) -> None: ], } + openai_vlm_profile_data = { + "service_kind": "openai", + "endpoint": "v1/chat/completions", + "experiments": [ + { + "experiment": { + "mode": "concurrency", + "value": 10, + }, + "requests": [ + { + "timestamp": 1, + "request_inputs": { + "payload": '{"messages":[{"role":"user","content":[{"type":"text","text":"This is test"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abcdef"}}]}],"model":"llava-1.6","stream":true}', + }, + # the first, and the last two responses will be ignored because they have no "content" + "response_timestamps": [3, 5, 8, 12, 13, 14], + "response_outputs": [ + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"I"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" like"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" dogs"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":null}]}\n\n' + }, + {"response": "data: [DONE]\n\n"}, + ], + }, + { + "timestamp": 2, + "request_inputs": { + "payload": '{"messages":[{"role":"user","content":[{"type":"text","text":"This is test too"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abcdef"}}]}],"model":"llava-1.6","stream":true}', + }, + # the first, and the last two responses will be ignored because they have no "content" + "response_timestamps": [4, 7, 11, 15, 18, 19], + "response_outputs": [ + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"I"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"don\'t"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"cook food"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":null}]}\n\n' + }, + {"response": "data: [DONE]\n\n"}, + ], + }, + ], + }, + ], + } + triton_profile_data = { "service_kind": "triton", "endpoint": "", diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_synthetic_image_generator.py b/src/c++/perf_analyzer/genai-perf/tests/test_synthetic_image_generator.py new file mode 100644 index 000000000..2460f98c6 --- /dev/null +++ b/src/c++/perf_analyzer/genai-perf/tests/test_synthetic_image_generator.py @@ -0,0 +1,149 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import base64 +import random +from io import BytesIO + +import pytest +from genai_perf.llm_inputs.synthetic_image_generator import ( + ImageFormat, + SyntheticImageGenerator, +) +from PIL import Image + + +def decode_image(base64_string): + _, data = base64_string.split(",") + decoded_data = base64.b64decode(data) + return Image.open(BytesIO(decoded_data)) + + +@pytest.mark.parametrize( + "expected_image_size", + [ + (100, 100), + (200, 200), + ], +) +def test_different_image_size(expected_image_size): + expected_width, expected_height = expected_image_size + base64_string = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=expected_width, + image_width_stddev=0, + image_height_mean=expected_height, + image_height_stddev=0, + image_format=ImageFormat.PNG, + ) + + image = decode_image(base64_string) + assert image.size == expected_image_size, "image not resized to the target size" + + +def test_negative_size_is_not_selected(): + # exception is raised, when PIL.Image.resize is called with negative values + _ = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=-1, + image_width_stddev=10, + image_height_mean=-1, + image_height_stddev=10, + image_format=ImageFormat.PNG, + ) + + +@pytest.mark.parametrize( + "width_mean, width_stddev, height_mean, height_stddev", + [ + (100, 15, 100, 15), + (123, 10, 456, 7), + ], +) +def test_generator_deterministic(width_mean, width_stddev, height_mean, height_stddev): + random.seed(123) + img1 = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=width_mean, + image_width_stddev=width_stddev, + image_height_mean=height_mean, + image_height_stddev=height_stddev, + image_format=ImageFormat.PNG, + ) + + random.seed(123) + img2 = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=width_mean, + image_width_stddev=width_stddev, + image_height_mean=height_mean, + image_height_stddev=height_stddev, + image_format=ImageFormat.PNG, + ) + + assert img1 == img2, "generator is nondererministic" + + +@pytest.mark.parametrize("image_format", [ImageFormat.PNG, ImageFormat.JPEG]) +def test_base64_encoding_with_different_formats(image_format): + img_base64 = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=100, + image_width_stddev=100, + image_height_mean=100, + image_height_stddev=100, + image_format=image_format, + ) + + # check prefix + expected_prefix = f"data:image/{image_format.name.lower()};base64," + assert img_base64.startswith(expected_prefix), "unexpected prefix" + + # check image format + data = img_base64[len(expected_prefix) :] + img_data = base64.b64decode(data) + img_bytes = BytesIO(img_data) + image = Image.open(img_bytes) + assert image.format == image_format.name + + +def test_random_image_format(): + random.seed(123) + img1 = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=100, + image_width_stddev=100, + image_height_mean=100, + image_height_stddev=100, + image_format=None, + ) + + random.seed(456) + img2 = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=100, + image_width_stddev=100, + image_height_mean=100, + image_height_stddev=100, + image_format=None, + ) + + # check prefix + assert img1.startswith("data:image/png") + assert img2.startswith("data:image/jpeg")