diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/dataset_retriever.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/dataset_retriever.py index cd00ade9a..804365e1f 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/dataset_retriever.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/dataset_retriever.py @@ -59,6 +59,7 @@ def from_url(url: str, starting_index: int, length: int) -> List[Dict[str, Any]] ] return formatted_rows + # (TMA-2018) decouple output_format from this method @staticmethod def from_file(file_path: Path, output_format: OutputFormat) -> List[Dict[str, str]]: contents = DatasetRetriever._load_file_content(file_path) @@ -69,7 +70,7 @@ def from_file(file_path: Path, output_format: OutputFormat) -> List[Dict[str, st if output_format == OutputFormat.OPENAI_VISION: img_filename = content.get("image", "") - encoded_img = DatasetRetriever._encode_image_to_base64(img_filename) + encoded_img = DatasetRetriever._read_image_content(img_filename) data["image"] = encoded_img dataset.append(data) @@ -94,7 +95,7 @@ def _load_file_content(file_path: Path) -> List[Dict[str, str]]: return contents @staticmethod - def _encode_image_to_base64(filename: str) -> str: + def _read_image_content(filename: str) -> str: try: img = Image.open(filename) except: diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py index a02afe2ea..48e8afdde 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py @@ -102,12 +102,6 @@ def create_llm_inputs( random.seed(random_seed) if input_type == PromptSource.DATASET: - # (TMA-1990) support VLM input from public dataset - if output_format == OutputFormat.OPENAI_VISION: - raise GenAIPerfException( - f"{OutputFormat.OPENAI_VISION.to_lowercase()} currently " - "does not support dataset as input." - ) dataset = DatasetRetriever.from_url( cls.dataset_url_map[dataset_name], starting_index, length ) @@ -169,6 +163,7 @@ def validate_args( PromptSource.DATASET, ], OutputFormat.RANKINGS: [PromptSource.DATASET, PromptSource.SYNTHETIC], + OutputFormat.OPENAI_VISION: [PromptSource.DATASET], } if input_type in unsupported_combinations.get(output_format, []): diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py index 68b7a1edf..c3b4d202c 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py @@ -224,7 +224,7 @@ def test_get_input_file_with_multiple_prompts(self, mock_file, mock_exists): @patch("pathlib.Path.exists", return_value=True) @patch( - "genai_perf.llm_inputs.dataset_retriever.DatasetRetriever._encode_image_to_base64", + "genai_perf.llm_inputs.dataset_retriever.DatasetRetriever._read_image_content", return_value="data:image/png;base64,...", ) @patch(