Skip to content

Commit

Permalink
Add command line arguments for synthetic image generation (#753)
Browse files Browse the repository at this point in the history
* Add CLI options for synthetic image generation

* read image format from file when --input-file is used

* move encode_image method to utils

* Lazy import some modules
  • Loading branch information
nv-hwoo authored Jul 15, 2024
1 parent 92b2f3d commit 8e5570e
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import json
import random
from copy import deepcopy
from enum import Enum, auto
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, cast

import requests
from genai_perf import utils
from genai_perf.constants import CNN_DAILY_MAIL, DEFAULT_INPUT_DATA_JSON, OPEN_ORCA
from genai_perf.exceptions import GenAIPerfException
from genai_perf.llm_inputs.synthetic_prompt_generator import SyntheticPromptGenerator
Expand Down Expand Up @@ -121,6 +120,11 @@ class LlmInputs:
DEFAULT_OUTPUT_TOKENS_STDDEV = 0
DEFAULT_NUM_PROMPTS = 100

DEFAULT_IMAGE_WIDTH_MEAN = 100
DEFAULT_IMAGE_WIDTH_STDDEV = 0
DEFAULT_IMAGE_HEIGHT_MEAN = 100
DEFAULT_IMAGE_HEIGHT_STDDEV = 0

EMPTY_JSON_IN_VLLM_PA_FORMAT: Dict = {"data": []}
EMPTY_JSON_IN_TENSORRTLLM_PA_FORMAT: Dict = {"data": []}
EMPTY_JSON_IN_OPENAI_PA_FORMAT: Dict = {"data": []}
Expand All @@ -143,6 +147,11 @@ def create_llm_inputs(
output_tokens_deterministic: bool = False,
prompt_tokens_mean: int = DEFAULT_PROMPT_TOKENS_MEAN,
prompt_tokens_stddev: int = DEFAULT_PROMPT_TOKENS_STDDEV,
image_width_mean: int = DEFAULT_IMAGE_WIDTH_MEAN,
image_width_stddev: int = DEFAULT_IMAGE_WIDTH_STDDEV,
image_height_mean: int = DEFAULT_IMAGE_HEIGHT_MEAN,
image_height_stddev: int = DEFAULT_IMAGE_HEIGHT_STDDEV,
image_format: ImageFormat = ImageFormat.PNG,
random_seed: int = DEFAULT_RANDOM_SEED,
num_of_output_prompts: int = DEFAULT_NUM_PROMPTS,
add_model_name: bool = False,
Expand Down Expand Up @@ -185,6 +194,16 @@ def create_llm_inputs(
The standard deviation of the length of the output to generate. This is only used if output_tokens_mean is provided.
output_tokens_deterministic:
If true, the output tokens will set the minimum and maximum tokens to be equivalent.
image_width_mean:
The mean width of images when generating synthetic image data.
image_width_stddev:
The standard deviation of width of images when generating synthetic image data.
image_height_mean:
The mean height of images when generating synthetic image data.
image_height_stddev:
The standard deviation of height of images when generating synthetic image data.
image_format:
The compression format of the images.
batch_size:
The number of inputs per request (currently only used for the embeddings and rankings endpoints)
Expand Down Expand Up @@ -221,6 +240,11 @@ def create_llm_inputs(
prompt_tokens_mean,
prompt_tokens_stddev,
num_of_output_prompts,
image_width_mean,
image_width_stddev,
image_height_mean,
image_height_stddev,
image_format,
batch_size,
input_filename,
)
Expand Down Expand Up @@ -256,6 +280,11 @@ def get_generic_dataset_json(
prompt_tokens_mean: int,
prompt_tokens_stddev: int,
num_of_output_prompts: int,
image_width_mean: int,
image_width_stddev: int,
image_height_mean: int,
image_height_stddev: int,
image_format: ImageFormat,
batch_size: int,
input_filename: Optional[Path],
) -> Dict:
Expand All @@ -282,6 +311,16 @@ def get_generic_dataset_json(
The standard deviation of the length of the prompt to generate
num_of_output_prompts:
The number of synthetic output prompts to generate
image_width_mean:
The mean width of images when generating synthetic image data.
image_width_stddev:
The standard deviation of width of images when generating synthetic image data.
image_height_mean:
The mean height of images when generating synthetic image data.
image_height_stddev:
The standard deviation of height of images when generating synthetic image data.
image_format:
The compression format of the images.
batch_size:
The number of inputs per request (currently only used for the embeddings and rankings endpoints)
input_filename:
Expand Down Expand Up @@ -653,18 +692,17 @@ def _encode_images_in_input_dataset(cls, input_file_dataset: Dict) -> Dict:
filename = row["row"].get("image")
if filename:
img = Image.open(filename)
# (TMA-1985) Support multiple image formats
img_base64 = cls._encode_image(img, ImageFormat.PNG)
row["row"]["image"] = f"data:image/png;base64,{img_base64}"
if img.format.lower() not in utils.get_enum_names(ImageFormat):
raise GenAIPerfException(
f"Unsupported image format '{img.format}' of "
f"the image '{filename}'."
)

return input_file_dataset
img_base64 = utils.encode_image(img, img.format)
payload = f"data:image/{img.format.lower()};base64,{img_base64}"
row["row"]["image"] = payload

@classmethod
def _encode_image(cls, img: Image.Image, format=ImageFormat.PNG):
"""Encodes an image into base64 encoding."""
buffered = BytesIO()
img.save(buffered, format=format.name)
return base64.b64encode(buffered.getvalue()).decode("utf-8")
return input_file_dataset

@classmethod
def _convert_generic_json_to_output_format(
Expand Down
5 changes: 5 additions & 0 deletions src/c++/perf_analyzer/genai-perf/genai_perf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ def generate_inputs(args: Namespace, tokenizer: Tokenizer) -> None:
output_tokens_mean=args.output_tokens_mean,
output_tokens_stddev=args.output_tokens_stddev,
output_tokens_deterministic=args.output_tokens_mean_deterministic,
image_width_mean=args.image_width_mean,
image_width_stddev=args.image_width_stddev,
image_height_mean=args.image_height_mean,
image_height_stddev=args.image_height_stddev,
image_format=args.image_format,
random_seed=args.random_seed,
num_of_output_prompts=args.num_prompts,
add_model_name=add_model_name,
Expand Down
67 changes: 67 additions & 0 deletions src/c++/perf_analyzer/genai-perf/genai_perf/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
OPEN_ORCA,
)
from genai_perf.llm_inputs.llm_inputs import (
ImageFormat,
LlmInputs,
ModelSelectionStrategy,
OutputFormat,
Expand Down Expand Up @@ -116,6 +117,25 @@ def _check_compare_args(
return args


def _check_image_input_args(
    parser: argparse.ArgumentParser, args: argparse.Namespace
) -> argparse.Namespace:
    """
    Validate the synthetic image size options and hand `image_format`
    to `_convert_str_to_enum_entry` together with `ImageFormat`.

    Reports through `parser.error` when a mean is not strictly positive
    or when a standard deviation is negative.
    """
    mean_options = ("image_width_mean", "image_height_mean")
    stddev_options = ("image_width_stddev", "image_height_stddev")

    # Generators keep the checks lazy: width is inspected before height.
    has_non_positive_mean = any(getattr(args, name) <= 0 for name in mean_options)
    if has_non_positive_mean:
        parser.error(
            "Both --image-width-mean and --image-height-mean values must be positive."
        )

    has_negative_stddev = any(getattr(args, name) < 0 for name in stddev_options)
    if has_negative_stddev:
        parser.error(
            "Both --image-width-stddev and --image-height-stddev values must be non-negative."
        )

    return _convert_str_to_enum_entry(args, "image_format", ImageFormat)


def _check_conditional_args(
parser: argparse.ArgumentParser, args: argparse.Namespace
) -> argparse.Namespace:
Expand Down Expand Up @@ -417,6 +437,51 @@ def _add_input_args(parser):
)


def _add_image_input_args(parser):
    """Register the "Image Input" argument group on `parser`.

    Adds the four integer size options (width/height mean and stddev)
    followed by the `--image-format` choice option.
    """
    image_group = parser.add_argument_group("Image Input")

    # (flag, default, help) for each integer option, in the order they
    # should be registered (and therefore appear in --help).
    size_options = (
        (
            "--image-width-mean",
            LlmInputs.DEFAULT_IMAGE_WIDTH_MEAN,
            "The mean width of images when generating synthetic image data.",
        ),
        (
            "--image-width-stddev",
            LlmInputs.DEFAULT_IMAGE_WIDTH_STDDEV,
            "The standard deviation of width of images when generating synthetic image data.",
        ),
        (
            "--image-height-mean",
            LlmInputs.DEFAULT_IMAGE_HEIGHT_MEAN,
            "The mean height of images when generating synthetic image data.",
        ),
        (
            "--image-height-stddev",
            LlmInputs.DEFAULT_IMAGE_HEIGHT_STDDEV,
            "The standard deviation of height of images when generating synthetic image data.",
        ),
    )
    for flag, default_value, help_text in size_options:
        image_group.add_argument(
            flag,
            type=int,
            default=default_value,
            required=False,
            help=help_text,
        )

    image_group.add_argument(
        "--image-format",
        type=str,
        choices=utils.get_enum_names(ImageFormat),
        default="png",
        required=False,
        help="The compression format of the images.",
    )


def _add_profile_args(parser):
profile_group = parser.add_argument_group("Profiling")
load_management_group = profile_group.add_mutually_exclusive_group(required=False)
Expand Down Expand Up @@ -664,6 +729,7 @@ def _parse_profile_args(subparsers) -> argparse.ArgumentParser:
)
_add_endpoint_args(profile)
_add_input_args(profile)
_add_image_input_args(profile)
_add_profile_args(profile)
_add_output_args(profile)
_add_other_args(profile)
Expand Down Expand Up @@ -743,6 +809,7 @@ def refine_args(
args = _infer_prompt_source(args)
args = _check_model_args(parser, args)
args = _check_conditional_args(parser, args)
args = _check_image_input_args(parser, args)
args = _check_load_manager_args(args)
args = _set_artifact_paths(args)
elif args.subcommand == Subcommand.COMPARE.to_lowercase():
Expand Down
12 changes: 12 additions & 0 deletions src/c++/perf_analyzer/genai-perf/genai_perf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@
# Skip type checking to avoid mypy error
# Issue: https://github.com/python/mypy/issues/10632
import yaml # type: ignore
from PIL import Image


def encode_image(img: "Image.Image", format: str) -> str:
    """Encode an image as a base64 string.

    The image is serialized into an in-memory buffer through ``img.save``
    using the requested format, and the resulting bytes are base64-encoded.

    Args:
        img: The image to encode. Only its ``save(fp, format=...)`` method
            is used here.
        format: Name of the image format, passed unchanged to ``img.save``
            (for example ``"PNG"``).

    Returns:
        The base64 encoding of the serialized image as a UTF-8 string.
    """
    # NOTE: ``Image`` is the PIL.Image *module*; the image class itself is
    # ``Image.Image``, which is what the annotation above refers to. The
    # annotation is quoted so it is not evaluated when the function is defined.

    # Lazy import for vision related endpoints
    import base64
    from io import BytesIO

    buffered = BytesIO()
    img.save(buffered, format=format)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def remove_sse_prefix(msg: str) -> str:
Expand Down
40 changes: 39 additions & 1 deletion src/c++/perf_analyzer/genai-perf/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import pytest
from genai_perf import __version__, parser
from genai_perf.llm_inputs.llm_inputs import (
ImageFormat,
ModelSelectionStrategy,
OutputFormat,
PromptSource,
Expand All @@ -40,7 +41,7 @@

class TestCLIArguments:
# ================================================
# GENAI-PERF COMMAND
# PROFILE COMMAND
# ================================================
expected_help_output = (
"CLI to profile LLMs and Generative AI models with Perf Analyzer"
Expand Down Expand Up @@ -215,6 +216,23 @@ def test_help_version_arguments_output_and_exit(
["--synthetic-input-tokens-stddev", "7"],
{"synthetic_input_tokens_stddev": 7},
),
(
["--image-width-mean", "123"],
{"image_width_mean": 123},
),
(
["--image-width-stddev", "123"],
{"image_width_stddev": 123},
),
(
["--image-height-mean", "456"],
{"image_height_mean": 456},
),
(
["--image-height-stddev", "456"],
{"image_height_stddev": 456},
),
(["--image-format", "png"], {"image_format": ImageFormat.PNG}),
(["-v"], {"verbose": True}),
(["--verbose"], {"verbose": True}),
(["-u", "test_url"], {"u": "test_url"}),
Expand Down Expand Up @@ -732,6 +750,26 @@ def test_prompt_source_assertions(self, monkeypatch, mocker, capsys):
captured = capsys.readouterr()
assert expected_output in captured.err

@pytest.mark.parametrize(
    "args",
    [
        # negative numbers
        ["--image-width-mean", "-123"],
        ["--image-width-stddev", "-34"],
        ["--image-height-mean", "-123"],
        ["--image-height-stddev", "-34"],
        # zeros
        ["--image-width-mean", "0"],
        ["--image-height-mean", "0"],
    ],
)
def test_positive_image_input_args(self, monkeypatch, args):
    """Each invalid image size option must make argument parsing exit."""
    argv = ["genai-perf", "profile", "-m", "test_model", *args]
    monkeypatch.setattr("sys.argv", argv)

    # Only the exit itself is checked, so the exception info is not bound.
    with pytest.raises(SystemExit):
        parser.parse_args()

# ================================================
# COMPARE SUBCOMMAND
# ================================================
Expand Down
5 changes: 5 additions & 0 deletions src/c++/perf_analyzer/genai-perf/tests/test_json_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@ def test_generate_json(self, monkeypatch) -> None:
"random_seed": 0,
"synthetic_input_tokens_mean": 550,
"synthetic_input_tokens_stddev": 0,
"image_width_mean": 100,
"image_width_stddev": 0,
"image_height_mean": 100,
"image_height_stddev": 0,
"image_format": "png",
"concurrency": 1,
"measurement_interval": 10000,
"request_rate": null,
Expand Down
1 change: 0 additions & 1 deletion src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
ModelSelectionStrategy,
OutputFormat,
PromptSource,
make_snowman_image,
)
from genai_perf.tokenizer import Tokenizer
from PIL import Image
Expand Down

0 comments on commit 8e5570e

Please sign in to comment.