diff --git a/src/c++/perf_analyzer/genai-perf/docs/multi_modal.md b/src/c++/perf_analyzer/genai-perf/docs/multi_modal.md new file mode 100644 index 000000000..bb9f33c60 --- /dev/null +++ b/src/c++/perf_analyzer/genai-perf/docs/multi_modal.md @@ -0,0 +1,122 @@ + + +# Profile Vision-Language Models with GenAI-Perf + +GenAI-Perf allows you to profile Vision-Language Models (VLM) running on +[OpenAI Chat Completions API](https://platform.openai.com/docs/guides/chat-completions)-compatible server +by sending [multi-modal content](https://platform.openai.com/docs/guides/vision) to the server. +Currently, you can send multi-modal contents with GenAI-Perf using the following two approaches: +1. The synthetic data generation approach, where GenAI-Perf generates the multi-modal data for you. +2. The Bring Your Own Data (BYOD) approach, where you provide GenAI-Perf with the data to send. + +Before we dive into the two approaches, +you can start OpenAI API compatible server with a VLM model using following command: + +```bash +docker run --runtime nvidia --gpus all \ + -p 8000:8000 --ipc=host \ + vllm/vllm-openai:latest \ + --model llava-hf/llava-v1.6-mistral-7b-hf --dtype float16 +``` + + +## Approach 1: Synthetic Multi-Modal Data Generation + +GenAI-Perf can generate synthetic multi-modal data such as texts or images using +the parameters provide by the user through CLI. + +```bash +genai-perf profile \ + -m llava-hf/llava-v1.6-mistral-7b-hf \ + --service-kind openai \ + --endpoint-type vision \ + --image-width-mean 512 \ + --image-width-stddev 30 \ + --image-height-mean 512 \ + --image-height-stddev 30 \ + --image-format png \ + --synthetic-input-tokens-mean 100 \ + --synthetic-input-tokens-stddev 0 \ + --streaming +``` + +> [!Note] +> Under the hood, GenAI-Perf generates synthetic images using a few source images +> under the `llm_inputs/source_images` directory. +> If you would like to add/remove/edit the source images, +> you can do so by directly editing the source images under the directory. +> GenAI-Perf will pickup the images under the directory automatically when +> generating the synthetic images. + + +## Approach 2: Bring Your Own Data (BYOD) + +Instead of letting GenAI-Perf create the synthetic data, +you can also provide GenAI-Perf with your own data using +[`--input-file`](../README.md#--input-file-path) CLI option. +The file needs to be in JSONL format and should contain both the prompt and +the filepath to the image to send. + +For instance, an example of input file would look something as following: +```bash +// input.jsonl +{"text_input": "What is in this image?", "image": "path/to/image1.png"} +{"text_input": "What is the color of the dog?", "image": "path/to/image2.jpeg"} +{"text_input": "Describe the scene in the picture.", "image": "path/to/image3.png"} +... +``` + +After you create the file, you can run GenAI-Perf using the following command: + +```bash +genai-perf profile \ + -m llava-hf/llava-v1.6-mistral-7b-hf \ + --service-kind openai \ + --endpoint-type vision \ + --input-file input.jsonl \ + --streaming +``` + +Running GenAI-Perf using either approach will give you an example output that +looks like below: + +```bash + LLM Metrics +┏━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┓ +┃ Statistic ┃ avg ┃ min ┃ max ┃ p99 ┃ p90 ┃ p75 ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━┩ +│ Time to first token (ms) │ 321.05 │ 291.30 │ 537.07 │ 497.88 │ 318.46 │ 317.35 │ +│ Inter token latency (ms) │ 12.28 │ 11.44 │ 12.88 │ 12.87 │ 12.81 │ 12.53 │ +│ Request latency (ms) │ 1,866.23 │ 1,044.70 │ 2,832.22 │ 2,779.63 │ 2,534.64 │ 2,054.03 │ +│ Output sequence length │ 126.68 │ 59.00 │ 204.00 │ 200.58 │ 177.80 │ 147.50 │ +│ Input sequence length │ 100.00 │ 100.00 │ 100.00 │ 100.00 │ 100.00 │ 100.00 │ +└──────────────────────────┴──────────┴──────────┴──────────┴──────────┴──────────┴──────────┘ +Output token throughput (per sec): 67.40 +Request throughput (per sec): 0.53 +``` diff --git a/src/c++/perf_analyzer/genai-perf/docs/tutorial.md b/src/c++/perf_analyzer/genai-perf/docs/tutorial.md index 1a37baf39..15cc53efe 100644 --- a/src/c++/perf_analyzer/genai-perf/docs/tutorial.md +++ b/src/c++/perf_analyzer/genai-perf/docs/tutorial.md @@ -71,7 +71,6 @@ export RELEASE="yy.mm" # e.g. export RELEASE="24.06" docker run -it --net=host --gpus=all nvcr.io/nvidia/tritonserver:${RELEASE}-py3-sdk # Run GenAI-Perf in the container: -```bash genai-perf profile \ -m gpt2 \ --service-kind triton \ @@ -145,7 +144,6 @@ export RELEASE="yy.mm" # e.g. export RELEASE="24.06" docker run -it --net=host --gpus=1 nvcr.io/nvidia/tritonserver:${RELEASE}-py3-sdk # Run GenAI-Perf in the container: -```bash genai-perf profile \ -m gpt2 \ --service-kind triton \ @@ -207,7 +205,6 @@ export RELEASE="yy.mm" # e.g. export RELEASE="24.06" docker run -it --net=host --gpus=all nvcr.io/nvidia/tritonserver:${RELEASE}-py3-sdk # Run GenAI-Perf in the container: -```bash genai-perf profile \ -m gpt2 \ --service-kind openai \ @@ -270,7 +267,6 @@ docker run -it --net=host --gpus=all nvcr.io/nvidia/tritonserver:${RELEASE}-py3- # Run GenAI-Perf in the container: -```bash genai-perf profile \ -m gpt2 \ --service-kind openai \ diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py index 39abc7ece..057c33562 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py @@ -20,11 +20,17 @@ from typing import Any, Dict, List, Optional, Tuple, cast import requests +from genai_perf import utils from genai_perf.constants import CNN_DAILY_MAIL, DEFAULT_INPUT_DATA_JSON, OPEN_ORCA from genai_perf.exceptions import GenAIPerfException +from genai_perf.llm_inputs.synthetic_image_generator import ( + ImageFormat, + SyntheticImageGenerator, +) from genai_perf.llm_inputs.synthetic_prompt_generator import SyntheticPromptGenerator from genai_perf.tokenizer import DEFAULT_TOKENIZER, Tokenizer, get_tokenizer from genai_perf.utils import load_json_str +from PIL import Image from requests import Response @@ -43,6 +49,7 @@ class OutputFormat(Enum): OPENAI_CHAT_COMPLETIONS = auto() OPENAI_COMPLETIONS = auto() OPENAI_EMBEDDINGS = auto() + OPENAI_VISION = auto() RANKINGS = auto() TENSORRTLLM = auto() VLLM = auto() @@ -75,6 +82,11 @@ class LlmInputs: DEFAULT_OUTPUT_TOKENS_STDDEV = 0 DEFAULT_NUM_PROMPTS = 100 + DEFAULT_IMAGE_WIDTH_MEAN = 100 + DEFAULT_IMAGE_WIDTH_STDDEV = 0 + DEFAULT_IMAGE_HEIGHT_MEAN = 100 + DEFAULT_IMAGE_HEIGHT_STDDEV = 0 + EMPTY_JSON_IN_VLLM_PA_FORMAT: Dict = {"data": []} EMPTY_JSON_IN_TENSORRTLLM_PA_FORMAT: Dict = {"data": []} EMPTY_JSON_IN_OPENAI_PA_FORMAT: Dict = {"data": []} @@ -97,6 +109,11 @@ def create_llm_inputs( output_tokens_deterministic: bool = False, prompt_tokens_mean: int = DEFAULT_PROMPT_TOKENS_MEAN, prompt_tokens_stddev: int = DEFAULT_PROMPT_TOKENS_STDDEV, + image_width_mean: int = DEFAULT_IMAGE_WIDTH_MEAN, + image_width_stddev: int = DEFAULT_IMAGE_WIDTH_STDDEV, + image_height_mean: int = DEFAULT_IMAGE_HEIGHT_MEAN, + image_height_stddev: int = DEFAULT_IMAGE_HEIGHT_STDDEV, + image_format: ImageFormat = ImageFormat.PNG, random_seed: int = DEFAULT_RANDOM_SEED, num_of_output_prompts: int = DEFAULT_NUM_PROMPTS, add_model_name: bool = False, @@ -139,6 +156,16 @@ def create_llm_inputs( The standard deviation of the length of the output to generate. This is only used if output_tokens_mean is provided. output_tokens_deterministic: If true, the output tokens will set the minimum and maximum tokens to be equivalent. + image_width_mean: + The mean width of images when generating synthetic image data. + image_width_stddev: + The standard deviation of width of images when generating synthetic image data. + image_height_mean: + The mean height of images when generating synthetic image data. + image_height_stddev: + The standard deviation of height of images when generating synthetic image data. + image_format: + The compression format of the images. batch_size: The number of inputs per request (currently only used for the embeddings and rankings endpoints) @@ -175,6 +202,11 @@ def create_llm_inputs( prompt_tokens_mean, prompt_tokens_stddev, num_of_output_prompts, + image_width_mean, + image_width_stddev, + image_height_mean, + image_height_stddev, + image_format, batch_size, input_filename, ) @@ -210,6 +242,11 @@ def get_generic_dataset_json( prompt_tokens_mean: int, prompt_tokens_stddev: int, num_of_output_prompts: int, + image_width_mean: int, + image_width_stddev: int, + image_height_mean: int, + image_height_stddev: int, + image_format: ImageFormat, batch_size: int, input_filename: Optional[Path], ) -> Dict: @@ -236,6 +273,16 @@ def get_generic_dataset_json( The standard deviation of the length of the prompt to generate num_of_output_prompts: The number of synthetic output prompts to generate + image_width_mean: + The mean width of images when generating synthetic image data. + image_width_stddev: + The standard deviation of width of images when generating synthetic image data. + image_height_mean: + The mean height of images when generating synthetic image data. + image_height_stddev: + The standard deviation of height of images when generating synthetic image data. + image_format: + The compression format of the images. batch_size: The number of inputs per request (currently only used for the embeddings and rankings endpoints) input_filename: @@ -280,6 +327,12 @@ def get_generic_dataset_json( ) else: if input_type == PromptSource.DATASET: + # (TMA-1990) support VLM input from public dataset + if output_format == OutputFormat.OPENAI_VISION: + raise GenAIPerfException( + f"{OutputFormat.OPENAI_VISION.to_lowercase()} currently " + "does not support dataset as input." + ) dataset = cls._get_input_dataset_from_url( dataset_name, starting_index, length ) @@ -292,6 +345,12 @@ def get_generic_dataset_json( prompt_tokens_mean, prompt_tokens_stddev, num_of_output_prompts, + image_width_mean, + image_width_stddev, + image_height_mean, + image_height_stddev, + image_format, + output_format, ) generic_dataset_json = ( cls._convert_input_synthetic_or_file_dataset_to_generic_json( @@ -301,6 +360,9 @@ def get_generic_dataset_json( elif input_type == PromptSource.FILE: input_filename = cast(Path, input_filename) input_file_dataset = cls._get_input_dataset_from_file(input_filename) + input_file_dataset = cls._encode_images_in_input_dataset( + input_file_dataset + ) generic_dataset_json = ( cls._convert_input_synthetic_or_file_dataset_to_generic_json( input_file_dataset @@ -309,6 +371,14 @@ def get_generic_dataset_json( else: raise GenAIPerfException("Input source is not recognized.") + # When the generic_dataset_json contains multi-modal data (e.g. images), + # convert the format of the content to OpenAI multi-modal format: + # see https://platform.openai.com/docs/guides/vision + if output_format == OutputFormat.OPENAI_VISION: + generic_dataset_json = cls._convert_to_openai_multi_modal_content( + generic_dataset_json + ) + return generic_dataset_json @classmethod @@ -405,17 +475,36 @@ def _get_input_dataset_from_synthetic( prompt_tokens_mean: int, prompt_tokens_stddev: int, num_of_output_prompts: int, + image_width_mean: int, + image_width_stddev: int, + image_height_mean: int, + image_height_stddev: int, + image_format: ImageFormat, + output_format: OutputFormat, ) -> Dict[str, Any]: dataset_json: Dict[str, Any] = {} dataset_json["features"] = [{"name": "text_input"}] dataset_json["rows"] = [] for _ in range(num_of_output_prompts): + row: Dict["str", Any] = {"row": {}} synthetic_prompt = cls._create_synthetic_prompt( tokenizer, prompt_tokens_mean, prompt_tokens_stddev, ) - dataset_json["rows"].append({"row": {"text_input": synthetic_prompt}}) + row["row"]["text_input"] = synthetic_prompt + + if output_format == OutputFormat.OPENAI_VISION: + synthetic_image = cls._create_synthetic_image( + image_width_mean=image_width_mean, + image_width_stddev=image_width_stddev, + image_height_mean=image_height_mean, + image_height_stddev=image_height_stddev, + image_format=image_format, + ) + row["row"]["image"] = synthetic_image + + dataset_json["rows"].append(row) return dataset_json @@ -497,29 +586,37 @@ def _add_rows_to_generic_json( @classmethod def _get_input_dataset_from_file(cls, input_filename: Path) -> Dict: """ - Reads the input prompts from a JSONL file and converts them into the required dataset format. + Reads the input prompts and images from a JSONL file and converts them + into the required dataset format. Parameters ---------- input_filename : Path - The path to the input file containing the prompts in JSONL format. + The path to the input file containing the prompts and/or images in + JSONL format. Returns ------- Dict - The dataset in the required format with the prompts read from the file. + The dataset in the required format with the prompts and/or images + read from the file. """ cls.verify_file(input_filename) - input_file_prompts = cls._get_prompts_from_input_file(input_filename) + prompts, images = cls._get_prompts_from_input_file(input_filename) dataset_json: Dict[str, Any] = {} dataset_json["features"] = [{"name": "text_input"}] - dataset_json["rows"] = [ - {"row": {"text_input": prompt}} for prompt in input_file_prompts - ] + dataset_json["rows"] = [] + for prompt, image in zip(prompts, images): + content = {"text_input": prompt} + content.update({"image": image} if image else {}) + dataset_json["rows"].append({"row": content}) + return dataset_json @classmethod - def _get_prompts_from_input_file(cls, input_filename: Path) -> List[str]: + def _get_prompts_from_input_file( + cls, input_filename: Path + ) -> Tuple[List[str], List[str]]: """ Reads the input prompts from a JSONL file and returns a list of prompts. @@ -530,21 +627,63 @@ def _get_prompts_from_input_file(cls, input_filename: Path) -> List[str]: Returns ------- - List[str] - A list of prompts read from the file. + Tuple[List[str], List[str]] + A list of prompts and images read from the file. """ prompts = [] + images = [] with open(input_filename, mode="r", newline=None) as file: for line in file: if line.strip(): prompts.append(load_json_str(line).get("text_input", "").strip()) - return prompts + images.append(load_json_str(line).get("image", "").strip()) + return prompts, images @classmethod def verify_file(cls, input_filename: Path) -> None: if not input_filename.exists(): raise FileNotFoundError(f"The file '{input_filename}' does not exist.") + @classmethod + def _convert_to_openai_multi_modal_content( + cls, generic_dataset_json: Dict[str, List[Dict]] + ) -> Dict[str, List[Dict]]: + """ + Converts to multi-modal content format of OpenAI Chat Completions API. + """ + for row in generic_dataset_json["rows"]: + if row["image"]: + row["text_input"] = [ + { + "type": "text", + "text": row["text_input"], + }, + { + "type": "image_url", + "image_url": {"url": row["image"]}, + }, + ] + + return generic_dataset_json + + @classmethod + def _encode_images_in_input_dataset(cls, input_file_dataset: Dict) -> Dict: + for row in input_file_dataset["rows"]: + filename = row["row"].get("image") + if filename: + img = Image.open(filename) + if img.format.lower() not in utils.get_enum_names(ImageFormat): + raise GenAIPerfException( + f"Unsupported image format '{img.format}' of " + f"the image '{filename}'." + ) + + img_base64 = utils.encode_image(img, img.format) + payload = f"data:image/{img.format.lower()};base64,{img_base64}" + row["row"]["image"] = payload + + return input_file_dataset + @classmethod def _convert_generic_json_to_output_format( cls, @@ -559,7 +698,10 @@ def _convert_generic_json_to_output_format( model_name: list = [], model_selection_strategy: ModelSelectionStrategy = ModelSelectionStrategy.ROUND_ROBIN, ) -> Dict: - if output_format == OutputFormat.OPENAI_CHAT_COMPLETIONS: + if ( + output_format == OutputFormat.OPENAI_CHAT_COMPLETIONS + or output_format == OutputFormat.OPENAI_VISION + ): output_json = cls._convert_generic_json_to_openai_chat_completions_format( generic_dataset, add_model_name, @@ -1424,3 +1566,20 @@ def _create_synthetic_prompt( return SyntheticPromptGenerator.create_synthetic_prompt( tokenizer, prompt_tokens_mean, prompt_tokens_stddev ) + + @classmethod + def _create_synthetic_image( + cls, + image_width_mean: int, + image_width_stddev: int, + image_height_mean: int, + image_height_stddev: int, + image_format: ImageFormat, + ) -> str: + return SyntheticImageGenerator.create_synthetic_image( + image_width_mean=image_width_mean, + image_width_stddev=image_width_stddev, + image_height_mean=image_height_mean, + image_height_stddev=image_height_stddev, + image_format=image_format, + ) diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/dlss.png b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/dlss.png new file mode 100644 index 000000000..cdba23dd3 Binary files /dev/null and b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/dlss.png differ diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h100.jpeg b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h100.jpeg new file mode 100644 index 000000000..aee985fdc Binary files /dev/null and b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h100.jpeg differ diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h200.jpeg b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h200.jpeg new file mode 100644 index 000000000..eb0633b27 Binary files /dev/null and b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/h200.jpeg differ diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/jensen.jpeg b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/jensen.jpeg new file mode 100644 index 000000000..c9c831680 Binary files /dev/null and b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/source_images/jensen.jpeg differ diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/synthetic_image_generator.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/synthetic_image_generator.py new file mode 100644 index 000000000..a2df14d87 --- /dev/null +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/synthetic_image_generator.py @@ -0,0 +1,82 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import glob +import random +from enum import Enum, auto +from pathlib import Path +from typing import Optional + +from genai_perf import utils +from PIL import Image + + +class ImageFormat(Enum): + PNG = auto() + JPEG = auto() + + +class SyntheticImageGenerator: + """A simple synthetic image generator that generates multiple synthetic + images from the source images. + """ + + @classmethod + def create_synthetic_image( + cls, + image_width_mean: int, + image_width_stddev: int, + image_height_mean: int, + image_height_stddev: int, + image_format: Optional[ImageFormat] = None, + ) -> str: + """Generate base64 encoded synthetic image using the source images.""" + if image_format is None: + image_format = random.choice(list(ImageFormat)) + width = cls._sample_random_positive_integer( + image_width_mean, image_width_stddev + ) + height = cls._sample_random_positive_integer( + image_height_mean, image_height_stddev + ) + + image = cls._sample_source_image() + image = image.resize(size=(width, height)) + + img_base64 = utils.encode_image(image, image_format.name) + return f"data:image/{image_format.name.lower()};base64,{img_base64}" + + @classmethod + def _sample_source_image(cls): + """Sample one image among the source images.""" + filepath = Path(__file__).parent.resolve() / "source_images" / "*" + filenames = glob.glob(str(filepath)) + return Image.open(random.choice(filenames)) + + @classmethod + def _sample_random_positive_integer(cls, mean: int, stddev: int) -> int: + n = int(abs(random.gauss(mean, stddev))) + return n if n != 0 else 1 # avoid zero diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/main.py b/src/c++/perf_analyzer/genai-perf/genai_perf/main.py index 912ee4725..9ff7b5b9a 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/main.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/main.py @@ -76,6 +76,11 @@ def generate_inputs(args: Namespace, tokenizer: Tokenizer) -> None: output_tokens_mean=args.output_tokens_mean, output_tokens_stddev=args.output_tokens_stddev, output_tokens_deterministic=args.output_tokens_mean_deterministic, + image_width_mean=args.image_width_mean, + image_width_stddev=args.image_width_stddev, + image_height_mean=args.image_height_mean, + image_height_stddev=args.image_height_stddev, + image_format=args.image_format, random_seed=args.random_seed, num_of_output_prompts=args.num_prompts, add_model_name=add_model_name, diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py index 901cf6ca2..776535d15 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py @@ -46,6 +46,7 @@ OutputFormat, PromptSource, ) +from genai_perf.llm_inputs.synthetic_image_generator import ImageFormat from genai_perf.plots.plot_config_parser import PlotConfigParser from genai_perf.plots.plot_manager import PlotManager from genai_perf.tokenizer import DEFAULT_TOKENIZER @@ -76,6 +77,7 @@ def to_lowercase(self): "completions": "v1/completions", "embeddings": "v1/embeddings", "rankings": "v1/ranking", + "vision": "v1/chat/completions", } @@ -115,6 +117,25 @@ def _check_compare_args( return args +def _check_image_input_args( + parser: argparse.ArgumentParser, args: argparse.Namespace +) -> argparse.Namespace: + """ + Sanity check the image input args + """ + if args.image_width_mean <= 0 or args.image_height_mean <= 0: + parser.error( + "Both --image-width-mean and --image-height-mean values must be positive." + ) + if args.image_width_stddev < 0 or args.image_height_stddev < 0: + parser.error( + "Both --image-width-stddev and --image-height-stddev values must be non-negative." + ) + + args = _convert_str_to_enum_entry(args, "image_format", ImageFormat) + return args + + def _check_conditional_args( parser: argparse.ArgumentParser, args: argparse.Namespace ) -> argparse.Namespace: @@ -138,6 +159,11 @@ def _check_conditional_args( elif args.endpoint_type == "rankings": args.output_format = OutputFormat.RANKINGS + # (TMA-1986) deduce vision format from chat completions + image CLI + # because there's no openai vision endpoint. + elif args.endpoint_type == "vision": + args.output_format = OutputFormat.OPENAI_VISION + if args.endpoint is not None: args.endpoint = args.endpoint.lstrip(" /") else: @@ -411,6 +437,51 @@ def _add_input_args(parser): ) +def _add_image_input_args(parser): + input_group = parser.add_argument_group("Image Input") + + input_group.add_argument( + "--image-width-mean", + type=int, + default=LlmInputs.DEFAULT_IMAGE_WIDTH_MEAN, + required=False, + help=f"The mean width of images when generating synthetic image data.", + ) + + input_group.add_argument( + "--image-width-stddev", + type=int, + default=LlmInputs.DEFAULT_IMAGE_WIDTH_STDDEV, + required=False, + help=f"The standard deviation of width of images when generating synthetic image data.", + ) + + input_group.add_argument( + "--image-height-mean", + type=int, + default=LlmInputs.DEFAULT_IMAGE_HEIGHT_MEAN, + required=False, + help=f"The mean height of images when generating synthetic image data.", + ) + + input_group.add_argument( + "--image-height-stddev", + type=int, + default=LlmInputs.DEFAULT_IMAGE_HEIGHT_STDDEV, + required=False, + help=f"The standard deviation of height of images when generating synthetic image data.", + ) + + input_group.add_argument( + "--image-format", + type=str, + choices=utils.get_enum_names(ImageFormat), + required=False, + help=f"The compression format of the images. " + "If format is not selected, format of generated image is selected at random", + ) + + def _add_profile_args(parser): profile_group = parser.add_argument_group("Profiling") load_management_group = profile_group.add_mutually_exclusive_group(required=False) @@ -499,7 +570,7 @@ def _add_endpoint_args(parser): endpoint_group.add_argument( "--endpoint-type", type=str, - choices=["chat", "completions", "embeddings", "rankings"], + choices=["chat", "completions", "embeddings", "rankings", "vision"], required=False, help=f"The endpoint-type to send requests to on the " 'server. This is only used with the "openai" service-kind.', @@ -658,6 +729,7 @@ def _parse_profile_args(subparsers) -> argparse.ArgumentParser: ) _add_endpoint_args(profile) _add_input_args(profile) + _add_image_input_args(profile) _add_profile_args(profile) _add_output_args(profile) _add_other_args(profile) @@ -737,6 +809,7 @@ def refine_args( args = _infer_prompt_source(args) args = _check_model_args(parser, args) args = _check_conditional_args(parser, args) + args = _check_image_input_args(parser, args) args = _check_load_manager_args(args) args = _set_artifact_paths(args) elif args.subcommand == Subcommand.COMPARE.to_lowercase(): diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py index 4ec1bec62..183f21fd2 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py @@ -218,6 +218,9 @@ def _get_openai_input_text(self, req_inputs: dict) -> str: return payload["messages"][0]["content"] elif self._response_format == ResponseFormat.OPENAI_COMPLETIONS: return payload["prompt"] + elif self._response_format == ResponseFormat.OPENAI_VISION: + content = payload["messages"][0]["content"] + return " ".join(c["text"] for c in content if c["type"] == "text") else: raise ValueError( "Failed to parse OpenAI request input in profile export file." diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py index d18d8f6fb..74eb48a23 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py @@ -39,6 +39,7 @@ class ResponseFormat(Enum): OPENAI_CHAT_COMPLETIONS = auto() OPENAI_COMPLETIONS = auto() OPENAI_EMBEDDINGS = auto() + OPENAI_VISION = auto() RANKINGS = auto() TRITON = auto() @@ -59,7 +60,15 @@ def _get_profile_metadata(self, data: dict) -> None: if data["endpoint"] == "rerank": self._response_format = ResponseFormat.HUGGINGFACE_RANKINGS elif data["endpoint"] == "v1/chat/completions": - self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS + # (TPA-66) add PA metadata to deduce the response format instead + # of parsing the request input payload in profile export json + # file. + request = data["experiments"][0]["requests"][0] + request_input = request["request_inputs"]["payload"] + if "image_url" in request_input: + self._response_format = ResponseFormat.OPENAI_VISION + else: + self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS elif data["endpoint"] == "v1/completions": self._response_format = ResponseFormat.OPENAI_COMPLETIONS elif data["endpoint"] == "v1/embeddings": @@ -67,13 +76,17 @@ def _get_profile_metadata(self, data: dict) -> None: elif data["endpoint"] == "v1/ranking": self._response_format = ResponseFormat.RANKINGS else: - # TPA-66: add PA metadata to handle this case + # (TPA-66) add PA metadata to handle this case # When endpoint field is either empty or custom endpoint, fall # back to parsing the response to extract the response format. request = data["experiments"][0]["requests"][0] + request_input = request["request_inputs"]["payload"] response = request["response_outputs"][0]["response"] if "chat.completion" in response: - self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS + if "image_url" in request_input: + self._response_format = ResponseFormat.OPENAI_VISION + else: + self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS elif "text_completion" in response: self._response_format = ResponseFormat.OPENAI_COMPLETIONS elif "embedding" in response: diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/test_end_to_end.py b/src/c++/perf_analyzer/genai-perf/genai_perf/test_end_to_end.py deleted file mode 100644 index a44304348..000000000 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/test_end_to_end.py +++ /dev/null @@ -1,92 +0,0 @@ -import itertools -import os -import subprocess -import sys - -# How to run: -# test_end_to_end.py -# Where target is "nim_chat" or "nim_completions" or "vllm_openai" or "triton_tensorrtllm" -# -# For all cases but vllm_openai, it assumes that the server will be on port 9999 -# -# This script will run a sweep of all combinations of values in the testing matrix -# by appending those options on to the genai-perf base command -# - - -testing_matrix = [ - ["--concurrency 1", "--concurrency 32", "--request-rate 1", "--request-rate 32"], - ["--streaming", ""], -] - -base_commands = { - "nim_chat": "genai-perf profile -s 999 -p 20000 -m llama-2-7b-chat -u http://localhost:9999 --service-kind openai --endpoint-type chat", - "nim_completions": "genai-perf profile -s 999 -p 20000 -m llama-2-7b -u http://localhost:9999 --service-kind openai --endpoint-type completions", - "vllm_openai": "genai-perf profile -s 999 -p 20000 -m mistralai/Mistral-7B-v0.1 --service-kind openai --endpoint-type chat", - "triton_tensorrtllm": "genai-perf profile -s 999 -p 20000 -m llama-2-7b -u 0.0.0.0:9999 --service-kind triton --backend tensorrtllm", - "triton_vllm": "genai-perf profile -s 999 -p 20000 -m gpt2_vllm --service-kind triton --backend vllm", -} -testname = "" - -if len(sys.argv) == 2: - # The second element in sys.argv is the input string - testname = sys.argv[1] -else: - options = " ".join(base_commands.keys()) - print(f"This script requires exactly one argument. It must be one of {options}") - exit(1) - -base_command = base_commands[testname] - - -def rename_files(files: list, substr: str) -> None: - for f in files: - name, ext = f.rsplit(".", 1) - # Insert the substring and reassemble the filename - new_filename = f"{testname}__{name}__{substr}.{ext}" - try: - os.rename(f, new_filename) - except FileNotFoundError: - # Just ignore the error, since if PA failed these files may not exist - pass - - -def print_summary(): - # FIXME -- print out a few basic metrics. Maybe from the csv? - pass - - -def sanity_check(): - # FIXME -- add in some sanity checking? Throughput isn't 0? - pass - - -# Loop through all combinations -for combination in itertools.product(*testing_matrix): - options_string = " ".join(combination) - command_with_options = f"{base_command} {options_string}" - command_array = command_with_options.split() - - file_options_string = "__".join(combination) - file_options_string = file_options_string.replace(" ", "") - file_options_string = file_options_string.replace("-", "") - output_file = testname + "__" + file_options_string + ".log" - - with open(output_file, "w") as outfile: - print(f"\nCMD: {command_with_options}") - print(f" Output log is {output_file}") - proc = subprocess.run(command_array, stdout=outfile, stderr=subprocess.STDOUT) - - if proc.returncode != 0: - print(f" Command failed with return code: {proc.returncode}") - else: - print(f" Command executed successfully!") - print_summary() - sanity_check() - - files = [ - "profile_export.json", - "profile_export_genai_pa.csv", - "llm_inputs.json", - ] - rename_files(files, file_options_string) diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py b/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py index 6f66230c4..4b625352a 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/utils.py @@ -34,10 +34,27 @@ # Skip type checking to avoid mypy error # Issue: https://github.com/python/mypy/issues/10632 import yaml # type: ignore +from PIL import Image logger = logging.getLogger(__name__) +def encode_image(img: Image, format: str): + """Encodes an image into base64 encoding.""" + # Lazy import for vision related endpoints + import base64 + from io import BytesIO + + # JPEG does not support P or RGBA mode (commonly used for PNG) so it needs + # to be converted to RGB before an image can be saved as JPEG format. + if format == "JPEG" and img.mode != "RGB": + img = img.convert("RGB") + + buffered = BytesIO() + img.save(buffered, format=format) + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + def remove_sse_prefix(msg: str) -> str: prefix = "data: " if msg.startswith(prefix): diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py b/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py index dbaacc32b..76ef3e321 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/wrapper.py @@ -93,6 +93,11 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s "synthetic_input_tokens_stddev", "subcommand", "tokenizer", + "image_width_mean", + "image_width_stddev", + "image_height_mean", + "image_height_stddev", + "image_format", ] utils.remove_file(args.profile_export_file) diff --git a/src/c++/perf_analyzer/genai-perf/pyproject.toml b/src/c++/perf_analyzer/genai-perf/pyproject.toml index 982ee24b7..68d5e3740 100644 --- a/src/c++/perf_analyzer/genai-perf/pyproject.toml +++ b/src/c++/perf_analyzer/genai-perf/pyproject.toml @@ -59,6 +59,7 @@ dependencies = [ "pytest-mock", "pyyaml", "responses", + "pillow", ] # CLI Entrypoint diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_cli.py b/src/c++/perf_analyzer/genai-perf/tests/test_cli.py index eb891fd02..2ef5d52ba 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_cli.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_cli.py @@ -31,16 +31,18 @@ import pytest from genai_perf import __version__, parser from genai_perf.llm_inputs.llm_inputs import ( + ImageFormat, ModelSelectionStrategy, OutputFormat, PromptSource, ) +from genai_perf.llm_inputs.synthetic_image_generator import ImageFormat from genai_perf.parser import PathType class TestCLIArguments: # ================================================ - # GENAI-PERF COMMAND + # PROFILE COMMAND # ================================================ expected_help_output = ( "CLI to profile LLMs and Generative AI models with Perf Analyzer" @@ -215,6 +217,23 @@ def test_help_version_arguments_output_and_exit( ["--synthetic-input-tokens-stddev", "7"], {"synthetic_input_tokens_stddev": 7}, ), + ( + ["--image-width-mean", "123"], + {"image_width_mean": 123}, + ), + ( + ["--image-width-stddev", "123"], + {"image_width_stddev": 123}, + ), + ( + ["--image-height-mean", "456"], + {"image_height_mean": 456}, + ), + ( + ["--image-height-stddev", "456"], + {"image_height_stddev": 456}, + ), + (["--image-format", "png"], {"image_format": ImageFormat.PNG}), (["-v"], {"verbose": True}), (["--verbose"], {"verbose": True}), (["-u", "test_url"], {"u": "test_url"}), @@ -732,6 +751,26 @@ def test_prompt_source_assertions(self, monkeypatch, mocker, capsys): captured = capsys.readouterr() assert expected_output in captured.err + @pytest.mark.parametrize( + "args", + [ + # negative numbers + ["--image-width-mean", "-123"], + ["--image-width-stddev", "-34"], + ["--image-height-mean", "-123"], + ["--image-height-stddev", "-34"], + # zeros + ["--image-width-mean", "0"], + ["--image-height-mean", "0"], + ], + ) + def test_positive_image_input_args(self, monkeypatch, args): + combined_args = ["genai-perf", "profile", "-m", "test_model"] + args + monkeypatch.setattr("sys.argv", combined_args) + + with pytest.raises(SystemExit) as excinfo: + parser.parse_args() + # ================================================ # COMPARE SUBCOMMAND # ================================================ diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py b/src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py index e4a29267d..f82e59312 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py @@ -249,6 +249,11 @@ def test_generate_json(self, monkeypatch) -> None: "random_seed": 0, "synthetic_input_tokens_mean": 550, "synthetic_input_tokens_stddev": 0, + "image_width_mean": 100, + "image_width_stddev": 0, + "image_height_mean": 100, + "image_height_stddev": 0, + "image_format": null, "concurrency": 1, "measurement_interval": 10000, "request_rate": null, diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py index c6351918e..028e72849 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py @@ -16,6 +16,7 @@ import os import random import statistics +from collections import namedtuple from pathlib import Path from unittest.mock import mock_open, patch @@ -30,7 +31,9 @@ OutputFormat, PromptSource, ) -from genai_perf.tokenizer import Tokenizer +from genai_perf.llm_inputs.synthetic_image_generator import ImageFormat +from genai_perf.tokenizer import DEFAULT_TOKENIZER, get_tokenizer +from PIL import Image mocked_openorca_data = { "features": [ @@ -78,6 +81,7 @@ class TestLlmInputs: ("triton", "tensorrtllm", OutputFormat.TENSORRTLLM), ("openai", "v1/completions", OutputFormat.OPENAI_COMPLETIONS), ("openai", "v1/chat/completions", OutputFormat.OPENAI_CHAT_COMPLETIONS), + ("openai", "v1/chat/completions", OutputFormat.OPENAI_VISION), ] @pytest.fixture @@ -550,6 +554,94 @@ def test_llm_inputs_with_defaults(self, default_configured_url): # else: # assert False, f"Unsupported output format: {output_format}" + def test_add_image_inputs_openai_vision(self) -> None: + generic_json = { + "rows": [ + {"text_input": "test input one", "image": "test_image1"}, + {"text_input": "test input two", "image": "test_image2"}, + ] + } + + generic_json = LlmInputs._convert_to_openai_multi_modal_content(generic_json) + + row1 = generic_json["rows"][0]["text_input"] + assert row1 == [ + { + "type": "text", + "text": "test input one", + }, + { + "type": "image_url", + "image_url": {"url": "test_image1"}, + }, + ] + + row2 = generic_json["rows"][1]["text_input"] + assert row2 == [ + { + "type": "text", + "text": "test input two", + }, + { + "type": "image_url", + "image_url": {"url": "test_image2"}, + }, + ] + + @patch( + "genai_perf.llm_inputs.llm_inputs.LlmInputs._create_synthetic_prompt", + return_value="This is test prompt", + ) + @patch( + "genai_perf.llm_inputs.llm_inputs.LlmInputs._create_synthetic_image", + return_value="test_image_base64", + ) + @pytest.mark.parametrize( + "output_format", + [ + OutputFormat.OPENAI_CHAT_COMPLETIONS, + OutputFormat.OPENAI_COMPLETIONS, + OutputFormat.OPENAI_EMBEDDINGS, + OutputFormat.RANKINGS, + OutputFormat.OPENAI_VISION, + OutputFormat.VLLM, + OutputFormat.TENSORRTLLM, + ], + ) + def test_get_input_dataset_from_synthetic( + self, mock_prompt, mock_image, output_format + ) -> None: + _placeholder = 123 # dummy value + num_prompts = 3 + + dataset_json = LlmInputs._get_input_dataset_from_synthetic( + tokenizer=get_tokenizer(DEFAULT_TOKENIZER), + prompt_tokens_mean=_placeholder, + prompt_tokens_stddev=_placeholder, + num_of_output_prompts=num_prompts, + image_width_mean=_placeholder, + image_width_stddev=_placeholder, + image_height_mean=_placeholder, + image_height_stddev=_placeholder, + image_format=ImageFormat.PNG, + output_format=output_format, + ) + + assert len(dataset_json["rows"]) == num_prompts + + for i in range(num_prompts): + row = dataset_json["rows"][i]["row"] + + if output_format == OutputFormat.OPENAI_VISION: + assert row == { + "text_input": "This is test prompt", + "image": "test_image_base64", + } + else: + assert row == { + "text_input": "This is test prompt", + } + # def test_trtllm_default_max_tokens(self, default_tokenizer: Tokenizer) -> None: # input_name = "max_tokens" # input_value = 256 @@ -687,6 +779,34 @@ def test_get_input_file_with_multiple_prompts(self, mock_file, mock_exists): for i, prompt in enumerate(expected_prompts): assert dataset["rows"][i]["row"]["text_input"] == prompt + @patch("pathlib.Path.exists", return_value=True) + @patch("PIL.Image.open", return_value=Image.new("RGB", (10, 10))) + @patch( + "builtins.open", + new_callable=mock_open, + read_data=( + '{"text_input": "prompt1", "image": "image1.png"}\n' + '{"text_input": "prompt2", "image": "image2.png"}\n' + '{"text_input": "prompt3", "image": "image3.png"}\n' + ), + ) + def test_get_input_file_with_multi_modal_data( + self, mock_exists, mock_image, mock_file + ): + Data = namedtuple("Data", ["text_input", "image"]) + expected_data = [ + Data(text_input="prompt1", image="image1.png"), + Data(text_input="prompt2", image="image2.png"), + Data(text_input="prompt3", image="image3.png"), + ] + dataset = LlmInputs._get_input_dataset_from_file(Path("somefile.txt")) + + assert dataset is not None + assert len(dataset["rows"]) == len(expected_data) + for i, data in enumerate(expected_data): + assert dataset["rows"][i]["row"]["text_input"] == data.text_input + assert dataset["rows"][i]["row"]["image"] == data.image + @pytest.mark.parametrize( "seed, model_name_list, index,model_selection_strategy,expected_model", [ diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py index 05de5b122..689e366cd 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py @@ -69,6 +69,7 @@ def test_llm_metric_system_metrics(self) -> None: output_sequence_lengths=[3, 4], input_sequence_lengths=[12, 34], ) + sys_metrics = m.system_metrics assert len(sys_metrics) == 2 assert sys_metrics[0].name == "output_token_throughput" diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py index 75976189d..d776a6a85 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py @@ -71,6 +71,9 @@ def write(self: Any, content: str) -> int: elif filename == "openai_profile_export.json": tmp_file = StringIO(json.dumps(self.openai_profile_data)) return tmp_file + elif filename == "openai_vlm_profile_export.json": + tmp_file = StringIO(json.dumps(self.openai_vlm_profile_data)) + return tmp_file elif filename == "empty_profile_export.json": tmp_file = StringIO(json.dumps(self.empty_profile_data)) return tmp_file @@ -322,6 +325,91 @@ def test_openai_llm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> N with pytest.raises(KeyError): pd.get_statistics(infer_mode="concurrency", load_level="40") + def test_openai_vlm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> None: + """Collect LLM metrics from profile export data and check values. + + Metrics + * time to first tokens + - experiment 1: [5 - 1, 7 - 2] = [4, 5] + * inter token latencies + - experiment 1: [((12 - 1) - 4)/(3 - 1), ((15 - 2) - 5)/(6 - 1)] + : [3.5, 1.6] + : [4, 2] # rounded + * output token throughputs per request + - experiment 1: [3/(12 - 1), 6/(15 - 2)] = [3/11, 6/13] + * output token throughputs + - experiment 1: [(3 + 6)/(15 - 1)] = [9/14] + * output sequence lengths + - experiment 1: [3, 6] + * input sequence lengths + - experiment 1: [3, 4] + """ + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("openai_vlm_profile_export.json"), + tokenizer=tokenizer, + ) + + # experiment 1 statistics + stat_obj = pd.get_statistics(infer_mode="concurrency", load_level="10") + metrics = stat_obj.metrics + stat = stat_obj.stats_dict + assert isinstance(metrics, LLMMetrics) + + assert metrics.time_to_first_tokens == [4, 5] + assert metrics.inter_token_latencies == [4, 2] + ottpr = [3 / ns_to_sec(11), 6 / ns_to_sec(13)] + assert metrics.output_token_throughputs_per_request == pytest.approx(ottpr) + ott = [9 / ns_to_sec(14)] + assert metrics.output_token_throughputs == pytest.approx(ott) + assert metrics.output_sequence_lengths == [3, 6] + assert metrics.input_sequence_lengths == [3, 4] + + assert stat["time_to_first_token"]["avg"] == pytest.approx(4.5) # type: ignore + assert stat["inter_token_latency"]["avg"] == pytest.approx(3) # type: ignore + assert stat["output_token_throughput_per_request"]["avg"] == pytest.approx( # type: ignore + np.mean(ottpr) + ) + assert stat["output_sequence_length"]["avg"] == 4.5 # type: ignore + assert stat["input_sequence_length"]["avg"] == 3.5 # type: ignore + + assert stat["time_to_first_token"]["p50"] == pytest.approx(4.5) # type: ignore + assert stat["inter_token_latency"]["p50"] == pytest.approx(3) # type: ignore + assert stat["output_token_throughput_per_request"]["p50"] == pytest.approx( # type: ignore + np.percentile(ottpr, 50) + ) + assert stat["output_sequence_length"]["p50"] == 4.5 # type: ignore + assert stat["input_sequence_length"]["p50"] == 3.5 # type: ignore + + assert stat["time_to_first_token"]["min"] == pytest.approx(4) # type: ignore + assert stat["inter_token_latency"]["min"] == pytest.approx(2) # type: ignore + min_ottpr = 3 / ns_to_sec(11) + assert stat["output_token_throughput_per_request"]["min"] == pytest.approx(min_ottpr) # type: ignore + assert stat["output_sequence_length"]["min"] == 3 # type: ignore + assert stat["input_sequence_length"]["min"] == 3 # type: ignore + + assert stat["time_to_first_token"]["max"] == pytest.approx(5) # type: ignore + assert stat["inter_token_latency"]["max"] == pytest.approx(4) # type: ignore + max_ottpr = 6 / ns_to_sec(13) + assert stat["output_token_throughput_per_request"]["max"] == pytest.approx(max_ottpr) # type: ignore + assert stat["output_sequence_length"]["max"] == 6 # type: ignore + assert stat["input_sequence_length"]["max"] == 4 # type: ignore + + assert stat["time_to_first_token"]["std"] == np.std([4, 5]) * (1) # type: ignore + assert stat["inter_token_latency"]["std"] == np.std([4, 2]) * (1) # type: ignore + assert stat["output_token_throughput_per_request"]["std"] == pytest.approx( # type: ignore + np.std(ottpr) + ) + assert stat["output_sequence_length"]["std"] == np.std([3, 6]) # type: ignore + assert stat["input_sequence_length"]["std"] == np.std([3, 4]) # type: ignore + + oott = 9 / ns_to_sec(14) + assert stat["output_token_throughput"]["avg"] == pytest.approx(oott) # type: ignore + + # check non-existing profile data + with pytest.raises(KeyError): + pd.get_statistics(infer_mode="concurrency", load_level="40") + def test_merged_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None: """Test merging the multiple sse response.""" res_timestamps = [0, 1, 2, 3] @@ -522,6 +610,73 @@ def test_empty_response(self, mock_read_write: pytest.MonkeyPatch) -> None: ], } + openai_vlm_profile_data = { + "service_kind": "openai", + "endpoint": "v1/chat/completions", + "experiments": [ + { + "experiment": { + "mode": "concurrency", + "value": 10, + }, + "requests": [ + { + "timestamp": 1, + "request_inputs": { + "payload": '{"messages":[{"role":"user","content":[{"type":"text","text":"This is test"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abcdef"}}]}],"model":"llava-1.6","stream":true}', + }, + # the first, and the last two responses will be ignored because they have no "content" + "response_timestamps": [3, 5, 8, 12, 13, 14], + "response_outputs": [ + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"I"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" like"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" dogs"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":null}]}\n\n' + }, + {"response": "data: [DONE]\n\n"}, + ], + }, + { + "timestamp": 2, + "request_inputs": { + "payload": '{"messages":[{"role":"user","content":[{"type":"text","text":"This is test too"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abcdef"}}]}],"model":"llava-1.6","stream":true}', + }, + # the first, and the last two responses will be ignored because they have no "content" + "response_timestamps": [4, 7, 11, 15, 18, 19], + "response_outputs": [ + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"I"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"don\'t"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"cook food"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":null}]}\n\n' + }, + {"response": "data: [DONE]\n\n"}, + ], + }, + ], + }, + ], + } + triton_profile_data = { "service_kind": "triton", "endpoint": "", diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_synthetic_image_generator.py b/src/c++/perf_analyzer/genai-perf/tests/test_synthetic_image_generator.py new file mode 100644 index 000000000..5a79794bb --- /dev/null +++ b/src/c++/perf_analyzer/genai-perf/tests/test_synthetic_image_generator.py @@ -0,0 +1,123 @@ +import base64 +import random +from io import BytesIO + +import pytest +from genai_perf.llm_inputs.synthetic_image_generator import ( + ImageFormat, + SyntheticImageGenerator, +) +from PIL import Image + + +def decode_image(base64_string): + _, data = base64_string.split(",") + decoded_data = base64.b64decode(data) + return Image.open(BytesIO(decoded_data)) + + +@pytest.mark.parametrize( + "expected_image_size", + [ + (100, 100), + (200, 200), + ], +) +def test_different_image_size(expected_image_size): + expected_width, expected_height = expected_image_size + base64_string = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=expected_width, + image_width_stddev=0, + image_height_mean=expected_height, + image_height_stddev=0, + image_format=ImageFormat.PNG, + ) + + image = decode_image(base64_string) + assert image.size == expected_image_size, "image not resized to the target size" + + +def test_negative_size_is_not_selected(): + # exception is raised, when PIL.Image.resize is called with negative values + _ = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=-1, + image_width_stddev=10, + image_height_mean=-1, + image_height_stddev=10, + image_format=ImageFormat.PNG, + ) + + +@pytest.mark.parametrize( + "width_mean, width_stddev, height_mean, height_stddev", + [ + (100, 15, 100, 15), + (123, 10, 456, 7), + ], +) +def test_generator_deterministic(width_mean, width_stddev, height_mean, height_stddev): + random.seed(123) + img1 = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=width_mean, + image_width_stddev=width_stddev, + image_height_mean=height_mean, + image_height_stddev=height_stddev, + image_format=ImageFormat.PNG, + ) + + random.seed(123) + img2 = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=width_mean, + image_width_stddev=width_stddev, + image_height_mean=height_mean, + image_height_stddev=height_stddev, + image_format=ImageFormat.PNG, + ) + + assert img1 == img2, "generator is nondererministic" + + +@pytest.mark.parametrize("image_format", [ImageFormat.PNG, ImageFormat.JPEG]) +def test_base64_encoding_with_different_formats(image_format): + img_base64 = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=100, + image_width_stddev=100, + image_height_mean=100, + image_height_stddev=100, + image_format=image_format, + ) + + # check prefix + expected_prefix = f"data:image/{image_format.name.lower()};base64," + assert img_base64.startswith(expected_prefix), "unexpected prefix" + + # check image format + data = img_base64[len(expected_prefix) :] + img_data = base64.b64decode(data) + img_bytes = BytesIO(img_data) + image = Image.open(img_bytes) + assert image.format == image_format.name + + +def test_random_image_format(): + random.seed(123) + img1 = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=100, + image_width_stddev=100, + image_height_mean=100, + image_height_stddev=100, + image_format=None, + ) + + random.seed(456) + img2 = SyntheticImageGenerator.create_synthetic_image( + image_width_mean=100, + image_width_stddev=100, + image_height_mean=100, + image_height_stddev=100, + image_format=None, + ) + + # check prefix + assert img1.startswith("data:image/png") + assert img2.startswith("data:image/jpeg") diff --git a/src/c++/perf_analyzer/test_command_line_parser.cc b/src/c++/perf_analyzer/test_command_line_parser.cc index 2d17bbc24..bebf3caec 100644 --- a/src/c++/perf_analyzer/test_command_line_parser.cc +++ b/src/c++/perf_analyzer/test_command_line_parser.cc @@ -373,9 +373,12 @@ CheckValidRange( std::vector& args, char* option_name, TestCLParser& parser, PAParamsPtr& act, bool& using_range, Range& range, size_t* max_threads) + PAParamsPtr& act, bool& using_range, Range& range, + size_t* max_threads) { SUBCASE("start:end provided") { + *max_threads = 400; *max_threads = 400; args.push_back(option_name); args.push_back("100:400"); // start:end @@ -394,6 +397,7 @@ CheckValidRange( SUBCASE("start:end:step provided") { + *max_threads = 400; *max_threads = 400; args.push_back(option_name); args.push_back("100:400:10"); // start:end:step