Support multi-modal input from file for OpenAI Chat Completions #749

Merged · 1 commit · Jul 11, 2024
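This PR replaces the hard-coded snowman test image with images read from the user's input file: each JSONL line may now carry an `image` field next to `text_input`, which genai-perf base64-encodes and forwards in OpenAI's multi-modal content format. A minimal input file might look like this (filenames are illustrative, not from the PR):

```jsonl
{"text_input": "What is in this image?", "image": "image1.png"}
{"text_input": "And what about this one?", "image": "image2.png"}
```

The diff below shows the `LlmInputs` changes that wire this through.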
```diff
@@ -326,13 +326,26 @@ def get_generic_dataset_json(
             )
         else:
             if input_type == PromptSource.DATASET:
+                # (TMA-1990) support VLM input from public dataset
+                if output_format == OutputFormat.OPENAI_VISION:
+                    raise GenAIPerfException(
+                        f"{OutputFormat.OPENAI_VISION.to_lowercase()} currently "
+                        "does not support dataset as input."
+                    )
                 dataset = cls._get_input_dataset_from_url(
                     dataset_name, starting_index, length
                 )
                 generic_dataset_json = cls._convert_input_url_dataset_to_generic_json(
                     dataset
                 )
             elif input_type == PromptSource.SYNTHETIC:
+                # (TMA-1989) support synthetic image generation for VLM input
+                if output_format == OutputFormat.OPENAI_VISION:
+                    raise GenAIPerfException(
+                        f"{OutputFormat.OPENAI_VISION.to_lowercase()} currently "
+                        "does not support synthetic input."
+                    )
+
                 synthetic_dataset = cls._get_input_dataset_from_synthetic(
                     tokenizer,
                     prompt_tokens_mean,
@@ -347,6 +360,9 @@ def get_generic_dataset_json(
             elif input_type == PromptSource.FILE:
                 input_filename = cast(Path, input_filename)
                 input_file_dataset = cls._get_input_dataset_from_file(input_filename)
+                input_file_dataset = cls._encode_images_in_input_dataset(
+                    input_file_dataset
+                )
                 generic_dataset_json = (
                     cls._convert_input_synthetic_or_file_dataset_to_generic_json(
                         input_file_dataset
@@ -355,10 +371,12 @@ def get_generic_dataset_json(
         else:
             raise GenAIPerfException("Input source is not recognized.")

+        # When the generic_dataset_json contains multi-modal data (e.g. images),
+        # convert the format of the content to OpenAI multi-modal format:
+        # see https://platform.openai.com/docs/guides/vision
         if output_format == OutputFormat.OPENAI_VISION:
-            snowman_image = make_snowman_image()
-            generic_dataset_json = cls._add_images_to_generic_json(
-                generic_dataset_json, snowman_image
+            generic_dataset_json = cls._convert_to_openai_multi_modal_content(
+                generic_dataset_json
             )

         return generic_dataset_json
```
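For reference, the content shape that `_convert_to_openai_multi_modal_content` produces (per the vision guide linked in the new comment) looks roughly like the sketch below; the base64 payload is a stand-in, not real image data:

```python
# One converted row: "text_input" becomes a list of typed content parts.
converted_row = {
    "text_input": [
        {"type": "text", "text": "What is in this image?"},
        {
            "type": "image_url",
            # Data URL produced earlier by _encode_images_in_input_dataset;
            # payload truncated for readability.
            "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."},
        },
    ]
}
```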
```diff
@@ -549,29 +567,37 @@ def _add_rows_to_generic_json(
     @classmethod
     def _get_input_dataset_from_file(cls, input_filename: Path) -> Dict:
         """
-        Reads the input prompts from a JSONL file and converts them into the required dataset format.
+        Reads the input prompts and images from a JSONL file and converts them
+        into the required dataset format.

         Parameters
         ----------
         input_filename : Path
-            The path to the input file containing the prompts in JSONL format.
+            The path to the input file containing the prompts and/or images in
+            JSONL format.

         Returns
         -------
         Dict
-            The dataset in the required format with the prompts read from the file.
+            The dataset in the required format with the prompts and/or images
+            read from the file.
         """
         cls.verify_file(input_filename)
-        input_file_prompts = cls._get_prompts_from_input_file(input_filename)
+        prompts, images = cls._get_prompts_from_input_file(input_filename)
         dataset_json: Dict[str, Any] = {}
         dataset_json["features"] = [{"name": "text_input"}]
-        dataset_json["rows"] = [
-            {"row": {"text_input": prompt}} for prompt in input_file_prompts
-        ]
+        dataset_json["rows"] = []
+        for prompt, image in zip(prompts, images):
+            content = {"text_input": prompt}
+            content.update({"image": image} if image else {})
+            dataset_json["rows"].append({"row": content})

         return dataset_json

     @classmethod
-    def _get_prompts_from_input_file(cls, input_filename: Path) -> List[str]:
+    def _get_prompts_from_input_file(
+        cls, input_filename: Path
+    ) -> Tuple[List[str], List[str]]:
         """
         Reads the input prompts from a JSONL file and returns a list of prompts.
```
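With these changes, the dict returned by `_get_input_dataset_from_file` for the two-line example file above would be, sketched:

```python
dataset_json = {
    "features": [{"name": "text_input"}],
    "rows": [
        {"row": {"text_input": "What is in this image?", "image": "image1.png"}},
        {"row": {"text_input": "And what about this one?", "image": "image2.png"}},
    ],
}
```

Note that a line without an `image` field simply omits the key, thanks to the `content.update(... if image else {})` guard.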

```diff
@@ -582,43 +608,57 @@ def _get_prompts_from_input_file(cls, input_filename: Path) -> List[str]:

         Returns
         -------
-        List[str]
-            A list of prompts read from the file.
+        Tuple[List[str], List[str]]
+            A list of prompts and images read from the file.
         """
         prompts = []
+        images = []
         with open(input_filename, mode="r", newline=None) as file:
             for line in file:
                 if line.strip():
                     prompts.append(json.loads(line).get("text_input", "").strip())
-        return prompts
+                    images.append(json.loads(line).get("image", "").strip())
+        return prompts, images

     @classmethod
     def verify_file(cls, input_filename: Path) -> None:
         if not input_filename.exists():
             raise FileNotFoundError(f"The file '{input_filename}' does not exist.")

     @classmethod
-    def _add_images_to_generic_json(
-        cls, generic_dataset_json: Dict[str, List[Dict]], img: Image
+    def _convert_to_openai_multi_modal_content(
+        cls, generic_dataset_json: Dict[str, List[Dict]]
     ) -> Dict[str, List[Dict]]:
-        # (TMA-1985) Support multiple image formats
-        img_format = ImageFormat.PNG
-        img_base64 = cls._encode_image(img, img_format)
+        """
+        Converts to multi-modal content format of OpenAI Chat Completions API.
+        """
         for row in generic_dataset_json["rows"]:
-            if isinstance(row["text_input"], str):
+            if row["image"]:
                 row["text_input"] = [
                     {
                         "type": "text",
                         "text": row["text_input"],
                     },
                     {
                         "type": "image_url",
-                        "image_url": {"url": f"data:image/png;base64,{img_base64}"},
+                        "image_url": {"url": row["image"]},
                     },
                 ]

         return generic_dataset_json

     @classmethod
+    def _encode_images_in_input_dataset(cls, input_file_dataset: Dict) -> Dict:
+        for row in input_file_dataset["rows"]:
+            filename = row["row"].get("image")
+            if filename:
+                img = Image.open(filename)
+                # (TMA-1985) Support multiple image formats
+                img_base64 = cls._encode_image(img, ImageFormat.PNG)
+                row["row"]["image"] = f"data:image/png;base64,{img_base64}"
+
+        return input_file_dataset
+
+    @classmethod
     def _encode_image(cls, img: Image, format=ImageFormat.PNG):
         """Encodes an image into base64 encoding."""
```
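The body of `_encode_image` is collapsed in this view; a plausible minimal implementation of such a helper (an assumption for illustration, not the PR's actual code) would be:

```python
import base64
from io import BytesIO

from PIL import Image


def encode_image(img: Image.Image, fmt: str = "PNG") -> str:
    """Serialize a PIL image to a base64 string (no data-URL prefix)."""
    buffer = BytesIO()            # in-memory file, avoids touching disk
    img.save(buffer, format=fmt)  # let Pillow handle the chosen format
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
```

`_encode_images_in_input_dataset` then prepends the `data:image/png;base64,` prefix to form the data URL.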
src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py — 42 changes: 35 additions & 7 deletions
```diff
@@ -16,6 +16,7 @@
 import os
 import random
 import statistics
+from collections import namedtuple
 from pathlib import Path
 from unittest.mock import mock_open, patch

@@ -32,6 +33,7 @@
     make_snowman_image,
 )
 from genai_perf.tokenizer import Tokenizer
+from PIL import Image

 mocked_openorca_data = {
     "features": [
```
```diff
@@ -555,14 +557,12 @@ def test_llm_inputs_with_defaults(self, default_configured_url):
     def test_add_image_inputs_openai_vision(self) -> None:
         generic_json = {
             "rows": [
-                {"text_input": "test input one"},
-                {"text_input": "test input two"},
+                {"text_input": "test input one", "image": "test_image1"},
+                {"text_input": "test input two", "image": "test_image2"},
             ]
         }
-        img = make_snowman_image()
-        encoded_img = LlmInputs._encode_image(img)

-        generic_json = LlmInputs._add_images_to_generic_json(generic_json, img)
+        generic_json = LlmInputs._convert_to_openai_multi_modal_content(generic_json)

         row1 = generic_json["rows"][0]["text_input"]
         assert row1 == [
@@ -572,7 +572,7 @@ def test_add_image_inputs_openai_vision(self) -> None:
             },
             {
                 "type": "image_url",
-                "image_url": {"url": f"data:image/png;base64,{encoded_img}"},
+                "image_url": {"url": "test_image1"},
             },
         ]

@@ -584,7 +584,7 @@ def test_add_image_inputs_openai_vision(self) -> None:
             },
             {
                 "type": "image_url",
-                "image_url": {"url": f"data:image/png;base64,{encoded_img}"},
+                "image_url": {"url": "test_image2"},
             },
         ]

```
```diff
@@ -725,6 +725,34 @@ def test_get_input_file_with_multiple_prompts(self, mock_file, mock_exists):
         for i, prompt in enumerate(expected_prompts):
             assert dataset["rows"][i]["row"]["text_input"] == prompt

+    @patch("pathlib.Path.exists", return_value=True)
+    @patch("PIL.Image.open", return_value=Image.new("RGB", (10, 10)))
+    @patch(
+        "builtins.open",
+        new_callable=mock_open,
+        read_data=(
+            '{"text_input": "prompt1", "image": "image1.png"}\n'
+            '{"text_input": "prompt2", "image": "image2.png"}\n'
+            '{"text_input": "prompt3", "image": "image3.png"}\n'
+        ),
+    )
+    def test_get_input_file_with_multi_modal_data(
+        self, mock_file, mock_image, mock_exists
+    ):
+        Data = namedtuple("Data", ["text_input", "image"])
+        expected_data = [
+            Data(text_input="prompt1", image="image1.png"),
+            Data(text_input="prompt2", image="image2.png"),
+            Data(text_input="prompt3", image="image3.png"),
+        ]
+        dataset = LlmInputs._get_input_dataset_from_file(Path("somefile.txt"))
+
+        assert dataset is not None
+        assert len(dataset["rows"]) == len(expected_data)
+        for i, data in enumerate(expected_data):
+            assert dataset["rows"][i]["row"]["text_input"] == data.text_input
+            assert dataset["rows"][i]["row"]["image"] == data.image
+
     @pytest.mark.parametrize(
         "seed, model_name_list, index,model_selection_strategy,expected_model",
         [
```

(The mock parameters follow `unittest.mock.patch` stacking order, bottom decorator first: `mock_file`, then `mock_image`, then `mock_exists`, matching the convention of `test_get_input_file_with_multiple_prompts` above.)
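Taken together, the new `PromptSource.FILE` path behaves roughly like the standalone sketch below; the module path and the direct use of private helpers are assumptions based on the tests above, not an API the PR documents:

```python
from pathlib import Path

from genai_perf.llm_inputs.llm_inputs import LlmInputs

# inputs.jsonl holds one JSON object per line, e.g.
# {"text_input": "What is in this image?", "image": "image1.png"}
dataset = LlmInputs._get_input_dataset_from_file(Path("inputs.jsonl"))

# Replace each image path with a base64 data URL, then reshape each row into
# the OpenAI Chat Completions multi-modal content format.
dataset = LlmInputs._encode_images_in_input_dataset(dataset)
generic = LlmInputs._convert_input_synthetic_or_file_dataset_to_generic_json(dataset)
generic = LlmInputs._convert_to_openai_multi_modal_content(generic)
```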