Support multi-modal input from file for OpenAI Chat Completions (#749)
nv-hwoo authored Jul 11, 2024
1 parent 5fd9004 commit 916fe91
Showing 2 changed files with 96 additions and 28 deletions.
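In short: with the openai_vision output format, genai-perf can now pair each prompt in a JSONL input file with an image file, inline that image as a base64 PNG data URL, and emit OpenAI's Chat Completions multi-modal content shape. A hypothetical input line (field names match the tests in this commit; the prompt text is made up):

{"text_input": "Describe this picture.", "image": "image1.png"}

After _encode_images_in_input_dataset and _convert_to_openai_multi_modal_content run, that row's text_input becomes, roughly:

[
    {"type": "text", "text": "Describe this picture."},
    {"type": "image_url", "image_url": {"url": "data:image/png;base64,<encoded PNG>"}},
]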
82 changes: 61 additions & 21 deletions src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py
@@ -326,13 +326,26 @@ def get_generic_dataset_json(
                 )
         else:
             if input_type == PromptSource.DATASET:
+                # (TMA-1990) support VLM input from public dataset
+                if output_format == OutputFormat.OPENAI_VISION:
+                    raise GenAIPerfException(
+                        f"{OutputFormat.OPENAI_VISION.to_lowercase()} currently "
+                        "does not support dataset as input."
+                    )
                 dataset = cls._get_input_dataset_from_url(
                     dataset_name, starting_index, length
                 )
                 generic_dataset_json = cls._convert_input_url_dataset_to_generic_json(
                     dataset
                 )
             elif input_type == PromptSource.SYNTHETIC:
+                # (TMA-1989) support synthetic image generation for VLM input
+                if output_format == OutputFormat.OPENAI_VISION:
+                    raise GenAIPerfException(
+                        f"{OutputFormat.OPENAI_VISION.to_lowercase()} currently "
+                        "does not support synthetic input."
+                    )
+
                 synthetic_dataset = cls._get_input_dataset_from_synthetic(
                     tokenizer,
                     prompt_tokens_mean,
@@ -347,6 +360,9 @@ def get_generic_dataset_json(
             elif input_type == PromptSource.FILE:
                 input_filename = cast(Path, input_filename)
                 input_file_dataset = cls._get_input_dataset_from_file(input_filename)
+                input_file_dataset = cls._encode_images_in_input_dataset(
+                    input_file_dataset
+                )
                 generic_dataset_json = (
                     cls._convert_input_synthetic_or_file_dataset_to_generic_json(
                         input_file_dataset
@@ -355,10 +371,12 @@
             else:
                 raise GenAIPerfException("Input source is not recognized.")

+        # When the generic_dataset_json contains multi-modal data (e.g. images),
+        # convert the format of the content to OpenAI multi-modal format:
+        # see https://platform.openai.com/docs/guides/vision
         if output_format == OutputFormat.OPENAI_VISION:
-            snowman_image = make_snowman_image()
-            generic_dataset_json = cls._add_images_to_generic_json(
-                generic_dataset_json, snowman_image
+            generic_dataset_json = cls._convert_to_openai_multi_modal_content(
+                generic_dataset_json
             )

         return generic_dataset_json
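Taken together, the hunks above make PromptSource.FILE the only input path that supports OutputFormat.OPENAI_VISION for now; dataset and synthetic inputs raise until TMA-1990 and TMA-1989 land. A condensed sketch of the resulting control flow (a simplification, not the verbatim method body; unrelated parameters omitted):

# Condensed sketch of get_generic_dataset_json after this commit.
if input_type == PromptSource.DATASET:
    if output_format == OutputFormat.OPENAI_VISION:
        raise GenAIPerfException("... does not support dataset as input.")
    ...
elif input_type == PromptSource.SYNTHETIC:
    if output_format == OutputFormat.OPENAI_VISION:
        raise GenAIPerfException("... does not support synthetic input.")
    ...
elif input_type == PromptSource.FILE:
    input_file_dataset = cls._get_input_dataset_from_file(input_filename)
    # image file path -> "data:image/png;base64,..." data URL
    input_file_dataset = cls._encode_images_in_input_dataset(input_file_dataset)
    ...

if output_format == OutputFormat.OPENAI_VISION:
    generic_dataset_json = cls._convert_to_openai_multi_modal_content(generic_dataset_json)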
@@ -549,29 +567,37 @@ def _add_rows_to_generic_json(
     @classmethod
     def _get_input_dataset_from_file(cls, input_filename: Path) -> Dict:
         """
-        Reads the input prompts from a JSONL file and converts them into the required dataset format.
+        Reads the input prompts and images from a JSONL file and converts them
+        into the required dataset format.

         Parameters
         ----------
         input_filename : Path
-            The path to the input file containing the prompts in JSONL format.
+            The path to the input file containing the prompts and/or images in
+            JSONL format.

         Returns
         -------
         Dict
-            The dataset in the required format with the prompts read from the file.
+            The dataset in the required format with the prompts and/or images
+            read from the file.
         """
         cls.verify_file(input_filename)
-        input_file_prompts = cls._get_prompts_from_input_file(input_filename)
+        prompts, images = cls._get_prompts_from_input_file(input_filename)
         dataset_json: Dict[str, Any] = {}
         dataset_json["features"] = [{"name": "text_input"}]
-        dataset_json["rows"] = [
-            {"row": {"text_input": prompt}} for prompt in input_file_prompts
-        ]
+        dataset_json["rows"] = []
+        for prompt, image in zip(prompts, images):
+            content = {"text_input": prompt}
+            content.update({"image": image} if image else {})
+            dataset_json["rows"].append({"row": content})

         return dataset_json

     @classmethod
-    def _get_prompts_from_input_file(cls, input_filename: Path) -> List[str]:
+    def _get_prompts_from_input_file(
+        cls, input_filename: Path
+    ) -> Tuple[List[str], List[str]]:
         """
         Reads the input prompts from a JSONL file and returns a list of prompts.
@@ -582,43 +608,57 @@ def _get_prompts_from_input_file(cls, input_filename: Path) -> List[str]:

         Returns
         -------
-        List[str]
-            A list of prompts read from the file.
+        Tuple[List[str], List[str]]
+            A list of prompts and images read from the file.
         """
         prompts = []
+        images = []
         with open(input_filename, mode="r", newline=None) as file:
             for line in file:
                 if line.strip():
                     prompts.append(json.loads(line).get("text_input", "").strip())
-        return prompts
+                    images.append(json.loads(line).get("image", "").strip())
+        return prompts, images

     @classmethod
     def verify_file(cls, input_filename: Path) -> None:
         if not input_filename.exists():
             raise FileNotFoundError(f"The file '{input_filename}' does not exist.")

     @classmethod
-    def _add_images_to_generic_json(
-        cls, generic_dataset_json: Dict[str, List[Dict]], img: Image
+    def _convert_to_openai_multi_modal_content(
+        cls, generic_dataset_json: Dict[str, List[Dict]]
     ) -> Dict[str, List[Dict]]:
-        # (TMA-1985) Support multiple image formats
-        img_format = ImageFormat.PNG
-        img_base64 = cls._encode_image(img, img_format)
+        """
+        Converts to multi-modal content format of OpenAI Chat Completions API.
+        """
         for row in generic_dataset_json["rows"]:
-            if isinstance(row["text_input"], str):
+            if row["image"]:
                 row["text_input"] = [
                     {
                         "type": "text",
                         "text": row["text_input"],
                     },
                     {
                         "type": "image_url",
-                        "image_url": {"url": f"data:image/png;base64,{img_base64}"},
+                        "image_url": {"url": row["image"]},
                     },
                 ]

         return generic_dataset_json

+    @classmethod
+    def _encode_images_in_input_dataset(cls, input_file_dataset: Dict) -> Dict:
+        for row in input_file_dataset["rows"]:
+            filename = row["row"].get("image")
+            if filename:
+                img = Image.open(filename)
+                # (TMA-1985) Support multiple image formats
+                img_base64 = cls._encode_image(img, ImageFormat.PNG)
+                row["row"]["image"] = f"data:image/png;base64,{img_base64}"
+
+        return input_file_dataset
+
     @classmethod
     def _encode_image(cls, img: Image, format=ImageFormat.PNG):
         """Encodes an image into base64 encoding."""
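The diff cuts off before the body of _encode_image. For orientation, here is a minimal sketch of what such a helper typically does with Pillow; this is an assumption for illustration, not the verbatim implementation:

import base64
import io

from PIL import Image


def encode_image(img: Image.Image, fmt: str = "PNG") -> str:
    """Illustrative stand-in for LlmInputs._encode_image (body truncated above):
    serialize a Pillow image to bytes and base64-encode them."""
    buf = io.BytesIO()
    img.save(buf, format=fmt)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


# Usage mirroring _encode_images_in_input_dataset for a single row:
img = Image.open("image1.png")  # hypothetical filename, as in the tests
data_url = f"data:image/png;base64,{encode_image(img)}"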
42 changes: 35 additions & 7 deletions src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py
@@ -16,6 +16,7 @@
 import os
 import random
 import statistics
+from collections import namedtuple
 from pathlib import Path
 from unittest.mock import mock_open, patch
@@ -32,6 +33,7 @@
     make_snowman_image,
 )
 from genai_perf.tokenizer import Tokenizer
+from PIL import Image

 mocked_openorca_data = {
     "features": [
@@ -555,14 +557,12 @@ def test_llm_inputs_with_defaults(self, default_configured_url):
     def test_add_image_inputs_openai_vision(self) -> None:
         generic_json = {
             "rows": [
-                {"text_input": "test input one"},
-                {"text_input": "test input two"},
+                {"text_input": "test input one", "image": "test_image1"},
+                {"text_input": "test input two", "image": "test_image2"},
             ]
         }
-        img = make_snowman_image()
-        encoded_img = LlmInputs._encode_image(img)

-        generic_json = LlmInputs._add_images_to_generic_json(generic_json, img)
+        generic_json = LlmInputs._convert_to_openai_multi_modal_content(generic_json)

         row1 = generic_json["rows"][0]["text_input"]
         assert row1 == [
@@ -572,7 +572,7 @@ def test_add_image_inputs_openai_vision(self) -> None:
             },
             {
                 "type": "image_url",
-                "image_url": {"url": f"data:image/png;base64,{encoded_img}"},
+                "image_url": {"url": "test_image1"},
             },
         ]

@@ -584,7 +584,7 @@ def test_add_image_inputs_openai_vision(self) -> None:
             },
             {
                 "type": "image_url",
-                "image_url": {"url": f"data:image/png;base64,{encoded_img}"},
+                "image_url": {"url": "test_image2"},
             },
         ]

@@ -725,6 +725,34 @@ def test_get_input_file_with_multiple_prompts(self, mock_file, mock_exists):
         for i, prompt in enumerate(expected_prompts):
             assert dataset["rows"][i]["row"]["text_input"] == prompt

+    @patch("pathlib.Path.exists", return_value=True)
+    @patch("PIL.Image.open", return_value=Image.new("RGB", (10, 10)))
+    @patch(
+        "builtins.open",
+        new_callable=mock_open,
+        read_data=(
+            '{"text_input": "prompt1", "image": "image1.png"}\n'
+            '{"text_input": "prompt2", "image": "image2.png"}\n'
+            '{"text_input": "prompt3", "image": "image3.png"}\n'
+        ),
+    )
+    def test_get_input_file_with_multi_modal_data(
+        self, mock_exists, mock_image, mock_file
+    ):
+        Data = namedtuple("Data", ["text_input", "image"])
+        expected_data = [
+            Data(text_input="prompt1", image="image1.png"),
+            Data(text_input="prompt2", image="image2.png"),
+            Data(text_input="prompt3", image="image3.png"),
+        ]
+        dataset = LlmInputs._get_input_dataset_from_file(Path("somefile.txt"))
+
+        assert dataset is not None
+        assert len(dataset["rows"]) == len(expected_data)
+        for i, data in enumerate(expected_data):
+            assert dataset["rows"][i]["row"]["text_input"] == data.text_input
+            assert dataset["rows"][i]["row"]["image"] == data.image
+
     @pytest.mark.parametrize(
         "seed, model_name_list, index,model_selection_strategy,expected_model",
         [
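For context on where the converted rows end up: per the OpenAI vision guide linked in the diff, the list that _convert_to_openai_multi_modal_content builds is the content array of a user message in a Chat Completions request. A hypothetical payload (the model name and prompt are placeholders; the request plumbing itself is outside this commit):

# Hypothetical request body assembled from one converted row.
payload = {
    "model": "gpt-4o",  # placeholder model name
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "prompt1"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/png;base64,..."},
                },
            ],
        }
    ],
}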
