Draft client-side batching
Fix client-side batching

Remove unused file

Remove unused file

Add batching and non-batching test

Run pre-commit hooks

Fix tests, remove unneeded file

Remove commented out code

Remove prints

Remove unused import
dyastremsky committed Aug 21, 2024
1 parent 5779fb7 commit 11a3226
Showing 2 changed files with 258 additions and 56 deletions.
183 changes: 127 additions & 56 deletions genai-perf/genai_perf/llm_inputs/llm_inputs.py
@@ -169,7 +169,7 @@ def create_llm_inputs(
image_format:
The compression format of the images.
batch_size:
The number of inputs per request (currently only used for the embeddings and rankings endpoints)
The number of inputs per request (currently only used for the embeddings, image retrieval, and rankings endpoints)
Required Synthetic Prompt Generation Parameters
-----------------------------------------------
@@ -287,7 +287,7 @@ def get_generic_dataset_json(
image_format:
The compression format of the images.
batch_size:
The number of inputs per request (currently only used for the embeddings and rankings endpoints)
The number of inputs per request (currently only used for the embeddings, image retrieval, and rankings endpoints)
input_filename:
The path to the input file containing the prompts in JSONL format.
Returns
@@ -333,7 +333,9 @@ def get_generic_dataset_json(
f"{OutputFormat.IMAGE_RETRIEVAL.to_lowercase()} only supports a file as input."
)
input_filename = cast(Path, input_filename)
input_file_dataset = cls._get_input_dataset_from_file(input_filename)
input_file_dataset = cls._get_input_dataset_from_file(
input_filename, batch_size, num_of_output_prompts
)
input_file_dataset = cls._encode_images_in_input_dataset(input_file_dataset)
generic_dataset_json = (
cls._convert_input_synthetic_or_file_dataset_to_generic_json(
@@ -591,7 +593,12 @@ def _add_rows_to_generic_json(
return generic_input_json

@classmethod
def _get_input_dataset_from_file(cls, input_filename: Path) -> Dict:
def _get_input_dataset_from_file(
cls,
input_filename: Path,
batch_size: int = DEFAULT_BATCH_SIZE,
num_prompts: int = -1,
) -> Dict:
"""
Reads the input prompts and images from a JSONL file and converts them
into the required dataset format.
@@ -601,6 +608,10 @@ def _get_input_dataset_from_file(cls, input_filename: Path) -> Dict:
input_filename : Path
The path to the input file containing the prompts and/or images in
JSONL format.
batch_size : int
The number of inputs per request (currently only used for the embeddings, image retrieval, and rankings endpoints)
num_prompts : int
The number of dataset rows to build by sampling from the file. Used when batch_size is provided.
Returns
-------
@@ -610,16 +621,38 @@ def _get_input_dataset_from_file(cls, input_filename: Path) -> Dict:
"""
cls.verify_file(input_filename)
prompts, images = cls._get_prompts_from_input_file(input_filename)
if batch_size > len(prompts):
raise ValueError(
"Batch size cannot be larger than the number of available texts"
)
dataset_json: Dict[str, Any] = {}
dataset_json["features"] = [{"name": "text_input"}]
dataset_json["rows"] = []
for prompt, image in zip(prompts, images):
content = {}
if prompt is not None:
content["text_input"] = prompt
if image is not None:
content["image"] = image
dataset_json["rows"].append({"row": content})

if batch_size == LlmInputs.DEFAULT_BATCH_SIZE:
for prompt, image in zip(prompts, images):
content = {}
if prompt is not None:
content["text_input"] = prompt
if image is not None:
content["image"] = image
dataset_json["rows"].append({"row": content})
else:
for _ in range(num_prompts):
content_array = []
sampled_indices = random.sample(range(len(prompts)), batch_size)
sampled_texts_images = [
(prompts[i], images[i]) for i in sampled_indices
]

for prompt, image in sampled_texts_images:
content = {}
if prompt is not None:
content["text_input"] = prompt
if image is not None:
content["image"] = image
content_array.append(content)
dataset_json["rows"].append({"row": content_array})

return dataset_json
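
For context, here is a minimal sketch (not part of the diff; the prompt strings are made up) of the two row shapes _get_input_dataset_from_file now produces:

# Illustrative only: with the default batch size, each dataset row wraps a
# single entry dict.
unbatched_rows = [
    {"row": {"text_input": "What is deep learning?"}},
    {"row": {"text_input": "Define tokenization."}},
]

# With batch_size=2, each row instead wraps a list of entries sampled at
# random from the input file, so a single request carries multiple inputs.
batched_rows = [
    {
        "row": [
            {"text_input": "What is deep learning?"},
            {"text_input": "Define tokenization."},
        ]
    },
]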

@@ -660,21 +693,40 @@ def verify_file(cls, input_filename: Path) -> None:
@classmethod
def _encode_images_in_input_dataset(cls, input_file_dataset: Dict) -> Dict:
for row in input_file_dataset["rows"]:
filename = row["row"].get("image")
if filename:
img = Image.open(filename)
if img.format.lower() not in utils.get_enum_names(ImageFormat):
raise GenAIPerfException(
f"Unsupported image format '{img.format}' of "
f"the image '{filename}'."
)

img_base64 = utils.encode_image(img, img.format)
payload = f"data:image/{img.format.lower()};base64,{img_base64}"
row["row"]["image"] = payload
if isinstance(row["row"], list):
for content in row["row"]:
filename = content.get("image")
if filename:
payload = cls._encode_image(filename)
content["image"] = payload
else:
filename = row["row"].get("image")
if filename:
payload = cls._encode_image(filename)
row["row"]["image"] = payload

return input_file_dataset

@classmethod
def _encode_image(cls, filename: str) -> str:
img = Image.open(filename)
if img is None:
raise GenAIPerfException(f"Failed to open image '{filename}'.")
if img.format is None:
raise GenAIPerfException(
f"Failed to determine image format of '{filename}'."
)

if img.format.lower() not in utils.get_enum_names(ImageFormat):
raise GenAIPerfException(
f"Unsupported image format '{img.format}' of "
f"the image '{filename}'."
)

img_base64 = utils.encode_image(img, img.format)
payload = f"data:image/{img.format.lower()};base64,{img_base64}"
return payload
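
As an aside, here is a self-contained approximation of what _encode_image returns, assuming Pillow is installed; image_file_to_data_url is an illustrative name, and the base64 step inlines what utils.encode_image is expected to do:

import base64
from io import BytesIO

from PIL import Image


def image_file_to_data_url(filename: str) -> str:
    # Open the image and read its detected format (e.g. "JPEG", "PNG").
    img = Image.open(filename)
    if img.format is None:
        raise ValueError(f"Failed to determine image format of '{filename}'.")
    # Re-encode the image bytes and wrap them in an RFC 2397 data URL.
    buffer = BytesIO()
    img.save(buffer, format=img.format)
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/{img.format.lower()};base64,{encoded}"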

@classmethod
def _convert_generic_json_to_output_format(
cls,
@@ -1097,18 +1149,18 @@ def _populate_openai_chat_completions_output_json(
)
openai_json: Dict = {"payload": [{"messages": []}]}

        # multi-modal content format
        if "image" in entry:
            contents: List[Dict] = cls._extract_chat_contents(entry)
            openai_json = {
                "payload": [{"messages": [{"role": "user", "content": contents}]}]
            }
        else:
            for header, content in entry.items():
                message = cls._create_new_openai_chat_completions_message(
                    header, system_role_headers, user_role_headers, content
                )
                cls._add_message_to_json(openai_json, message)
        # Check if the entry is a list (batched entries) or a single entry
        if isinstance(entry, list):
            for item in entry:
                cls._process_row_content(
                    item, system_role_headers, user_role_headers, openai_json
                )
        elif isinstance(entry, dict):
            cls._process_row_content(
                entry, system_role_headers, user_role_headers, openai_json
            )
        else:
            raise GenAIPerfException(f"Unexpected data type in rows: {type(entry)}")

cls._add_optional_tags_to_openai_json(
openai_json,
@@ -1124,33 +1176,52 @@ def _populate_openai_chat_completions_output_json(
return pa_json

@classmethod
    def _extract_chat_contents(cls, entry: Dict) -> List[Dict]:
        contents = []
        for content_type, content in entry.items():
            if content_type == "text_input":
                contents.append(
                    {
                        "type": "text",
                        "text": content,
                    }
                )
            elif content_type == "image":
                contents.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": content,
                        },
                    }
                )
            else:
                raise GenAIPerfException(
                    "Failed to construct OpenAI chat completions message "
                    f"contents. Unknown content type: '{content_type}'."
                )
        return contents
    def _process_row_content(
        cls,
        entry: Dict,
        system_role_headers: List[str],
        user_role_headers: List[str],
        openai_json: Dict,
    ) -> None:
        if "image" in entry:
            contents = cls._extract_chat_contents(entry)
            if openai_json["payload"][0]["messages"]:
                openai_json["payload"][0]["messages"][0]["content"].extend(contents)
            else:
                openai_json["payload"][0]["messages"].append(
                    {"role": "user", "content": contents}
                )
        else:
            for header, content in entry.items():
                message = cls._create_new_openai_chat_completions_message(
                    header, system_role_headers, user_role_headers, content
                )
                cls._add_message_to_json(openai_json, message)

@classmethod
def _extract_chat_contents(cls, entry: Dict) -> List[Dict]:
contents: List = []
if isinstance(entry, list):
for item in entry:
for content_type, content in item.items():
cls._add_content(contents, content_type, content)
else:
for content_type, content in entry.items():
cls._add_content(contents, content_type, content)
return contents

@classmethod
def _add_content(cls, contents: List[Dict], content_type: str, content: str):
if content_type == "text_input":
contents.append({"type": "text", "text": content})
elif content_type == "image":
contents.append({"type": "image_url", "image_url": {"url": content}})
else:
raise GenAIPerfException(
"Failed to construct OpenAI chat completions message "
f"contents. Unknown content type: '{content_type}'."
)
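
Putting these helpers together: for a row entry holding one prompt and one image, _extract_chat_contents yields a list shaped like the following (the values are placeholders and the base64 payload is truncated):

contents = [
    {"type": "text", "text": "Describe this picture."},
    {
        "type": "image_url",
        "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."},
    },
]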

@classmethod
def _populate_openai_completions_output_json(
cls,
131 changes: 131 additions & 0 deletions genai-perf/tests/test_llm_inputs_image_retrieval.py
@@ -0,0 +1,131 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from unittest.mock import patch

from genai_perf.llm_inputs.llm_inputs import LlmInputs, OutputFormat, PromptSource


class TestLlmInputsImageRetrieval:

@patch(
"genai_perf.llm_inputs.llm_inputs.LlmInputs._encode_image",
return_value="data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/",
)
@patch("genai_perf.llm_inputs.llm_inputs.LlmInputs._get_input_dataset_from_file")
def test_image_retrieval(self, mock_get_input, mock_encode_image):
mock_get_input.return_value = {
"features": [{"name": "text_input"}],
"rows": [
{"row": [{"image": "genai_perf/llm_inputs/source_images/image1.jpg"}]}
],
}

pa_json = LlmInputs.create_llm_inputs(
input_type=PromptSource.FILE,
output_format=OutputFormat.IMAGE_RETRIEVAL,
input_filename=Path("dummy.jsonl"),
model_name=["test_model"],
add_model_name=True,
)

expected_json = {
"data": [
{
"payload": [
{
"model": "test_model",
"messages": [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/"
},
}
],
}
],
}
]
}
]
}

assert pa_json == expected_json

@patch("genai_perf.llm_inputs.llm_inputs.LlmInputs._get_input_dataset_from_file")
@patch("genai_perf.llm_inputs.llm_inputs.LlmInputs._encode_image")
def test_image_retrieval_batched(self, mock_encode_image, mock_get_input):
mock_get_input.return_value = {
"features": [{"name": "text_input"}],
"rows": [
{
"row": [
{"image": "genai_perf/llm_inputs/source_images/image1.jpg"},
{"image": "genai_perf/llm_inputs/source_images/image2.jpg"},
]
}
],
}
mock_encode_image.side_effect = [
"data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/",
"data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/",
]

pa_json = LlmInputs.create_llm_inputs(
input_type=PromptSource.FILE,
output_format=OutputFormat.IMAGE_RETRIEVAL,
input_filename=Path("dummy.jsonl"),
batch_size=2,
num_of_output_prompts=1,
model_name=["test_model"],
add_model_name=True,
)

expected_json = {
"data": [
{
"payload": [
{
"model": "test_model",
"messages": [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/"
},
},
{
"type": "image_url",
"image_url": {
"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/"
},
},
],
}
],
}
]
}
]
}

assert pa_json == expected_json
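
Both tests are standard pytest cases, so they should run with something like pytest genai-perf/tests/test_llm_inputs_image_retrieval.py from the genai-perf directory.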
