From 4048c7078df37882f4bac8b5e0bb9cf918cacc5e Mon Sep 17 00:00:00 2001 From: Hyunjae Woo <107147848+nv-hwoo@users.noreply.github.com> Date: Tue, 17 Sep 2024 17:31:24 -0700 Subject: [PATCH] Add file input retriever class for reading from file source (#86) * add file input retriever * move input source check to inputs.py * create unittests for file input retriever --- .../genai_perf/inputs/file_input_retriever.py | 185 ++++++++++++++ .../inputs/input_retriever_factory.py | 192 ++------------- genai-perf/genai_perf/inputs/inputs.py | 19 +- genai-perf/tests/test_file_input_retriever.py | 227 ++++++++++++++++++ genai-perf/tests/test_input_embeddings.py | 9 +- .../tests/test_input_image_retrieval.py | 4 +- genai-perf/tests/test_input_rankings.py | 68 ------ .../tests/test_input_retriever_factory.py | 93 +------ 8 files changed, 453 insertions(+), 344 deletions(-) create mode 100644 genai-perf/genai_perf/inputs/file_input_retriever.py create mode 100644 genai-perf/tests/test_file_input_retriever.py diff --git a/genai-perf/genai_perf/inputs/file_input_retriever.py b/genai-perf/genai_perf/inputs/file_input_retriever.py new file mode 100644 index 00000000..940b8398 --- /dev/null +++ b/genai-perf/genai_perf/inputs/file_input_retriever.py @@ -0,0 +1,185 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import random +from pathlib import Path +from typing import Any, Dict, List, Tuple + +from genai_perf.inputs.input_constants import DEFAULT_BATCH_SIZE, OutputFormat +from genai_perf.inputs.inputs_config import InputsConfig +from genai_perf.utils import load_json_str + + +class FileInputRetriever: + """ + A input retriever class that handles input data provided by the user through + file and directories. + """ + + def __init__(self, config: InputsConfig) -> None: + self.config = config + + # TODO: match return type to retriever interface + def retrieve_data(self) -> Dict[str, Any]: + if self.config.output_format == OutputFormat.OPENAI_EMBEDDINGS: + return self._get_input_dataset_from_embeddings_file() + elif self.config.output_format == OutputFormat.RANKINGS: + queries_filename = self.config.input_filename / "queries.jsonl" + passages_filename = self.config.input_filename / "passages.jsonl" + return self._get_input_dataset_from_rankings_files( + queries_filename, passages_filename + ) + elif self.config.output_format == OutputFormat.IMAGE_RETRIEVAL: + return self._get_input_dataset_from_file() + else: + return self._get_input_dataset_from_file() + + def _get_input_dataset_from_embeddings_file(self) -> Dict[str, Any]: + with open(self.config.input_filename, "r") as file: + file_content = [load_json_str(line) for line in file] + + texts = [item["text"] for item in file_content] + + if self.config.batch_size > len(texts): + raise ValueError( + "Batch size cannot be larger than the number of available texts" + ) + + dataset_json: Dict[str, Any] = {} + dataset_json["features"] = [{"name": "input"}] + dataset_json["rows"] = [] + + for _ in range(self.config.num_prompts): + sampled_texts = random.sample(texts, self.config.batch_size) + dataset_json["rows"].append({"row": {"payload": {"input": sampled_texts}}}) + + return dataset_json + + def _get_input_dataset_from_rankings_files( + self, + queries_filename: Path, + passages_filename: Path, + ) -> Dict[str, Any]: + + with open(queries_filename, "r") as file: + queries_content = [load_json_str(line) for line in file] + queries_texts = [item for item in queries_content] + + with open(passages_filename, "r") as file: + passages_content = [load_json_str(line) for line in file] + passages_texts = [item for item in passages_content] + + if self.config.batch_size > len(passages_texts): + raise ValueError( + "Batch size cannot be larger than the number of available passages" + ) + + dataset_json: Dict[str, Any] = {} + dataset_json["features"] = [{"name": "input"}] + dataset_json["rows"] = [] + + for _ in range(self.config.num_prompts): + sampled_texts = random.sample(passages_texts, self.config.batch_size) + query_sample = random.choice(queries_texts) + entry_dict: Dict = {} + entry_dict["query"] = query_sample + entry_dict["passages"] = sampled_texts + dataset_json["rows"].append({"row": {"payload": entry_dict}}) + return dataset_json + + def _get_input_dataset_from_file(self) -> Dict[str, Any]: + """ + Returns + ------- + Dict + The dataset in the required format with the prompts and/or images + read from the file. + """ + self._verify_file() + prompts, images = self._get_prompts_from_input_file() + if self.config.batch_size > len(prompts): + raise ValueError( + "Batch size cannot be larger than the number of available texts" + ) + dataset_json: Dict[str, Any] = {} + dataset_json["features"] = [{"name": "text_input"}] + dataset_json["rows"] = [] + + if self.config.batch_size == DEFAULT_BATCH_SIZE: + for prompt, image in zip(prompts, images): + content = {} + if prompt is not None: + content["text_input"] = prompt + if image is not None: + content["image"] = image + dataset_json["rows"].append({"row": content}) + else: + for _ in range(self.config.num_prompts): + content_array = [] + sampled_indices = random.sample( + range(len(prompts)), self.config.batch_size + ) + sampled_texts_images = [ + (prompts[i], images[i]) for i in sampled_indices + ] + + for prompt, image in sampled_texts_images: + content = {} + if prompt is not None: + content["text_input"] = prompt + if image is not None: + content["image"] = image + content_array.append(content) + dataset_json["rows"].append({"row": content_array}) + + return dataset_json + + def _verify_file(self) -> None: + if not self.config.input_filename.exists(): + raise FileNotFoundError( + f"The file '{self.config.input_filename}' does not exist." + ) + + def _get_prompts_from_input_file(self) -> Tuple[List[str], List[str]]: + """ + Reads the input prompts from a JSONL file and returns a list of prompts. + + Returns + ------- + Tuple[List[str], List[str]] + A list of prompts and images read from the file. + """ + prompts = [] + images = [] + with open(self.config.input_filename, mode="r", newline=None) as file: + for line in file: + if line.strip(): + # None if not provided + prompt = load_json_str(line).get("text_input") + image = load_json_str(line).get("image") + prompts.append(prompt.strip() if prompt else prompt) + images.append(image.strip() if image else image) + return prompts, images diff --git a/genai-perf/genai_perf/inputs/input_retriever_factory.py b/genai-perf/genai_perf/inputs/input_retriever_factory.py index 4da5fbf4..618f47c4 100644 --- a/genai-perf/genai_perf/inputs/input_retriever_factory.py +++ b/genai-perf/genai_perf/inputs/input_retriever_factory.py @@ -24,15 +24,13 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import random -from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List import requests from genai_perf import utils from genai_perf.exceptions import GenAIPerfException +from genai_perf.inputs.file_input_retriever import FileInputRetriever from genai_perf.inputs.input_constants import ( - DEFAULT_BATCH_SIZE, OutputFormat, PromptSource, dataset_url_map, @@ -43,7 +41,6 @@ SyntheticImageGenerator, ) from genai_perf.inputs.synthetic_prompt_generator import SyntheticPromptGenerator -from genai_perf.utils import load_json_str from PIL import Image from requests import Response @@ -62,44 +59,21 @@ def get_input_data(self) -> Dict: The generic dataset JSON """ - if self.config.output_format == OutputFormat.OPENAI_EMBEDDINGS: - if self.config.input_type != PromptSource.FILE: - raise GenAIPerfException( - f"{OutputFormat.OPENAI_EMBEDDINGS.to_lowercase()} only supports a file as input." - ) - input_file_dataset = self._get_input_dataset_from_embeddings_file() - generic_dataset_json = ( - self._convert_input_synthetic_or_file_dataset_to_generic_json( - input_file_dataset - ) - ) - elif self.config.output_format == OutputFormat.RANKINGS: - if self.config.input_type != PromptSource.FILE: - raise GenAIPerfException( - f"{OutputFormat.RANKINGS.to_lowercase()} only supports a directory as input." - ) - queries_filename = self.config.input_filename / "queries.jsonl" - passages_filename = self.config.input_filename / "passages.jsonl" - input_file_dataset = self._get_input_dataset_from_rankings_files( - queries_filename, passages_filename - ) - generic_dataset_json = ( - self._convert_input_synthetic_or_file_dataset_to_generic_json( - input_file_dataset - ) - ) - elif self.config.output_format == OutputFormat.IMAGE_RETRIEVAL: - if self.config.input_type != PromptSource.FILE: - raise GenAIPerfException( - f"{OutputFormat.IMAGE_RETRIEVAL.to_lowercase()} only supports a file as input." - ) - input_file_dataset = self._get_input_dataset_from_file() - input_file_dataset = self._encode_images_in_input_dataset( - input_file_dataset - ) + if self.config.output_format in [ + OutputFormat.OPENAI_EMBEDDINGS, + OutputFormat.RANKINGS, + OutputFormat.IMAGE_RETRIEVAL, + ]: + # TODO: remove once the factory fully integrates retrievers + file_retriever = FileInputRetriever(self.config) + input_data = file_retriever.retrieve_data() + + if self.config.output_format == OutputFormat.IMAGE_RETRIEVAL: + input_data = self._encode_images_in_input_dataset(input_data) + generic_dataset_json = ( self._convert_input_synthetic_or_file_dataset_to_generic_json( - input_file_dataset + input_data ) ) else: @@ -122,10 +96,11 @@ def get_input_data(self) -> Dict: ) ) elif self.config.input_type == PromptSource.FILE: - input_file_dataset = self._get_input_dataset_from_file() - input_file_dataset = self._encode_images_in_input_dataset( - input_file_dataset - ) + # TODO: remove once the factory fully integrates retrievers + file_retriever = FileInputRetriever(self.config) + input_data = file_retriever.retrieve_data() + + input_file_dataset = self._encode_images_in_input_dataset(input_data) generic_dataset_json = ( self._convert_input_synthetic_or_file_dataset_to_generic_json( input_file_dataset @@ -136,27 +111,6 @@ def get_input_data(self) -> Dict: return generic_dataset_json - def _get_input_dataset_from_embeddings_file(self) -> Dict[str, Any]: - with open(self.config.input_filename, "r") as file: - file_content = [load_json_str(line) for line in file] - - texts = [item["text"] for item in file_content] - - if self.config.batch_size > len(texts): - raise ValueError( - "Batch size cannot be larger than the number of available texts" - ) - - dataset_json: Dict[str, Any] = {} - dataset_json["features"] = [{"name": "input"}] - dataset_json["rows"] = [] - - for _ in range(self.config.num_prompts): - sampled_texts = random.sample(texts, self.config.batch_size) - dataset_json["rows"].append({"row": {"payload": {"input": sampled_texts}}}) - - return dataset_json - def _convert_input_synthetic_or_file_dataset_to_generic_json( self, dataset: Dict ) -> Dict[str, List[Dict]]: @@ -164,85 +118,6 @@ def _convert_input_synthetic_or_file_dataset_to_generic_json( return generic_dataset_json - def _get_input_dataset_from_rankings_files( - self, - queries_filename: Path, - passages_filename: Path, - ) -> Dict[str, Any]: - - with open(queries_filename, "r") as file: - queries_content = [load_json_str(line) for line in file] - queries_texts = [item for item in queries_content] - - with open(passages_filename, "r") as file: - passages_content = [load_json_str(line) for line in file] - passages_texts = [item for item in passages_content] - - if self.config.batch_size > len(passages_texts): - raise ValueError( - "Batch size cannot be larger than the number of available passages" - ) - - dataset_json: Dict[str, Any] = {} - dataset_json["features"] = [{"name": "input"}] - dataset_json["rows"] = [] - - for _ in range(self.config.num_prompts): - sampled_texts = random.sample(passages_texts, self.config.batch_size) - query_sample = random.choice(queries_texts) - entry_dict: Dict = {} - entry_dict["query"] = query_sample - entry_dict["passages"] = sampled_texts - dataset_json["rows"].append({"row": {"payload": entry_dict}}) - return dataset_json - - def _get_input_dataset_from_file(self) -> Dict: - """ - Returns - ------- - Dict - The dataset in the required format with the prompts and/or images - read from the file. - """ - self._verify_file() - prompts, images = self._get_prompts_from_input_file() - if self.config.batch_size > len(prompts): - raise ValueError( - "Batch size cannot be larger than the number of available texts" - ) - dataset_json: Dict[str, Any] = {} - dataset_json["features"] = [{"name": "text_input"}] - dataset_json["rows"] = [] - - if self.config.batch_size == DEFAULT_BATCH_SIZE: - for prompt, image in zip(prompts, images): - content = {} - if prompt is not None: - content["text_input"] = prompt - if image is not None: - content["image"] = image - dataset_json["rows"].append({"row": content}) - else: - for _ in range(self.config.num_prompts): - content_array = [] - sampled_indices = random.sample( - range(len(prompts)), self.config.batch_size - ) - sampled_texts_images = [ - (prompts[i], images[i]) for i in sampled_indices - ] - - for prompt, image in sampled_texts_images: - content = {} - if prompt is not None: - content["text_input"] = prompt - if image is not None: - content["image"] = image - content_array.append(content) - dataset_json["rows"].append({"row": content_array}) - - return dataset_json - def _encode_images_in_input_dataset(self, input_file_dataset: Dict) -> Dict: for row in input_file_dataset["rows"]: if isinstance(row["row"], list): @@ -323,27 +198,6 @@ def _add_rows_to_generic_json( return generic_input_json - def _get_prompts_from_input_file(self) -> Tuple[List[str], List[str]]: - """ - Reads the input prompts from a JSONL file and returns a list of prompts. - - Returns - ------- - Tuple[List[str], List[str]] - A list of prompts and images read from the file. - """ - prompts = [] - images = [] - with open(self.config.input_filename, mode="r", newline=None) as file: - for line in file: - if line.strip(): - # None if not provided - prompt = load_json_str(line).get("text_input") - image = load_json_str(line).get("image") - prompts.append(prompt.strip() if prompt else prompt) - images.append(image.strip() if image else image) - return prompts, images - def _encode_image(self, filename: str) -> str: img = Image.open(filename) if img is None: @@ -403,12 +257,6 @@ def _create_synthetic_image(self) -> str: image_format=self.config.image_format, ) - def _verify_file(self) -> None: - if not self.config.input_filename.exists(): - raise FileNotFoundError( - f"The file '{self.config.input_filename}' does not exist." - ) - def _query_server(self, configured_url: str) -> Response: try: response = requests.get(configured_url) diff --git a/genai-perf/genai_perf/inputs/inputs.py b/genai-perf/genai_perf/inputs/inputs.py index 59c3a21d..c271aabb 100644 --- a/genai-perf/genai_perf/inputs/inputs.py +++ b/genai-perf/genai_perf/inputs/inputs.py @@ -61,17 +61,12 @@ def create_inputs(self) -> Dict: return json_in_pa_format def _check_for_valid_args(self) -> None: + self._check_for_supported_input_type() self._check_for_dataset_name_if_input_type_is_url() self._check_for_tokenzier_if_input_type_is_synthetic() self._check_for_valid_starting_index() self._check_for_valid_length() - def _verify_file(self) -> None: - if not self.config.input_filename.exists(): - raise FileNotFoundError( - f"The file '{self.config.input_filename}' does not exist." - ) - def _convert_generic_json_to_output_format(self, generic_dataset) -> Dict: converter = OutputFormatConverterFactory.create(self.config.output_format) return converter.convert(generic_dataset, self.config) @@ -81,6 +76,18 @@ def _write_json_to_file(self, json_in_pa_format: Dict) -> None: with open(str(filename), "w") as f: f.write(json.dumps(json_in_pa_format, indent=2)) + def _check_for_supported_input_type(self) -> None: + if self.config.output_format in [ + OutputFormat.OPENAI_EMBEDDINGS, + OutputFormat.RANKINGS, + OutputFormat.IMAGE_RETRIEVAL, + ]: + if self.config.input_type != PromptSource.FILE: + raise GenAIPerfException( + f"{self.config.output_format.to_lowercase()} only supports " + "a file as input source." + ) + def _check_for_dataset_name_if_input_type_is_url(self) -> None: if ( self.config.input_type == PromptSource.DATASET diff --git a/genai-perf/tests/test_file_input_retriever.py b/genai-perf/tests/test_file_input_retriever.py new file mode 100644 index 00000000..59d04644 --- /dev/null +++ b/genai-perf/tests/test_file_input_retriever.py @@ -0,0 +1,227 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from collections import namedtuple +from pathlib import Path +from unittest.mock import mock_open, patch + +import pytest +from genai_perf.inputs.file_input_retriever import FileInputRetriever +from genai_perf.inputs.input_constants import ModelSelectionStrategy +from genai_perf.inputs.inputs_config import InputsConfig +from PIL import Image + + +class TestFileInputRetriever: + @patch("pathlib.Path.exists", return_value=True) + @patch( + "builtins.open", + new_callable=mock_open, + read_data="\n".join( + [ + '{"text": "What production company co-owned by Kevin Loader and Rodger Michell produced My Cousin Rachel?"}', + '{"text": "Who served as the 1st Vice President of Colombia under El Libertador?"}', + '{"text": "Are the Barton Mine and Hermiston-McCauley Mine located in The United States of America?"}', + '{"text": "what state did they film daddy\'s home 2"}', + ] + ), + ) + def test_get_input_dataset_from_embeddings_file(self, mock_file, mock_exists): + batch_size = 3 + config = InputsConfig( + input_filename=Path("embeddings.jsonl"), + batch_size=batch_size, + num_prompts=100, + ) + file_retriever = FileInputRetriever(config) + dataset = file_retriever._get_input_dataset_from_embeddings_file() + + assert dataset is not None + assert len(dataset["rows"]) == 100 + for row in dataset["rows"]: + assert "row" in row + assert "payload" in row["row"] + payload = row["row"]["payload"] + assert "input" in payload + assert isinstance(payload["input"], list) + assert len(payload["input"]) == batch_size + + # Try error case where batch size is larger than the number of available texts + with pytest.raises( + ValueError, + match="Batch size cannot be larger than the number of available texts", + ): + config.batch_size = 5 + file_retriever._get_input_dataset_from_embeddings_file() + + def open_side_effects(self, filepath, *args, **kwargs): + queries_content = "\n".join( + [ + '{"text": "What production company co-owned by Kevin Loader and Rodger Michell produced My Cousin Rachel?"}', + '{"text": "Who served as the 1st Vice President of Colombia under El Libertador?"}', + '{"text": "Are the Barton Mine and Hermiston-McCauley Mine located in The United States of America?"}', + ] + ) + passages_content = "\n".join( + [ + '{"text": "Eric Anderson (sociologist) Eric Anderson (born January 18, 1968) is an American sociologist"}', + '{"text": "Kevin Loader is a British film and television producer. "}', + '{"text": "Barton Mine, also known as Net Lake Mine, is an abandoned surface and underground mine in Northeastern Ontario"}', + ] + ) + + file_contents = { + "queries.jsonl": queries_content, + "passages.jsonl": passages_content, + } + return mock_open( + read_data=file_contents.get(filepath, file_contents["queries.jsonl"]) + )() + + mock_open_obj = mock_open() + mock_open_obj.side_effect = open_side_effects + + @patch("pathlib.Path.exists", return_value=True) + @patch("builtins.open", mock_open_obj) + def test_get_input_dataset_from_rankings_file(self, mock_file): + queries_filename = Path("queries.jsonl") + passages_filename = Path("passages.jsonl") + batch_size = 2 + config = InputsConfig( + batch_size=batch_size, + num_prompts=100, + ) + file_retriever = FileInputRetriever(config) + dataset = file_retriever._get_input_dataset_from_rankings_files( + queries_filename=queries_filename, passages_filename=passages_filename + ) + + assert dataset is not None + assert len(dataset["rows"]) == 100 + for row in dataset["rows"]: + assert "row" in row + assert "payload" in row["row"] + payload = row["row"]["payload"] + assert "query" in payload + assert "passages" in payload + assert isinstance(payload["passages"], list) + assert len(payload["passages"]) == batch_size + + # Try error case where batch size is larger than the number of available texts + with pytest.raises( + ValueError, + match="Batch size cannot be larger than the number of available passages", + ): + config.batch_size = 5 + file_retriever._get_input_dataset_from_rankings_files( + queries_filename, passages_filename + ) + + def test_get_input_file_without_file_existing(self): + file_retriever = FileInputRetriever( + InputsConfig(input_filename=Path("prompt.txt")) + ) + with pytest.raises(FileNotFoundError): + file_retriever._get_input_dataset_from_file() + + @patch("pathlib.Path.exists", return_value=True) + @patch( + "builtins.open", + new_callable=mock_open, + read_data='{"text_input": "single prompt"}\n', + ) + def test_get_input_file_with_single_prompt(self, mock_file, mock_exists): + expected_prompts = ["single prompt"] + file_retriever = FileInputRetriever( + InputsConfig( + model_name=["test_model_A"], + model_selection_strategy=ModelSelectionStrategy.ROUND_ROBIN, + input_filename=Path("prompt.txt"), + ) + ) + dataset = file_retriever._get_input_dataset_from_file() + + assert dataset is not None + assert len(dataset["rows"]) == len(expected_prompts) + for i, prompt in enumerate(expected_prompts): + assert dataset["rows"][i]["row"]["text_input"] == prompt + + @patch("pathlib.Path.exists", return_value=True) + @patch( + "builtins.open", + new_callable=mock_open, + read_data='{"text_input": "prompt1"}\n{"text_input": "prompt2"}\n{"text_input": "prompt3"}\n', + ) + def test_get_input_file_with_multiple_prompts(self, mock_file, mock_exists): + expected_prompts = ["prompt1", "prompt2", "prompt3"] + file_retriever = FileInputRetriever( + InputsConfig( + model_name=["test_model_A"], + model_selection_strategy=ModelSelectionStrategy.ROUND_ROBIN, + input_filename=Path("prompt.txt"), + ) + ) + dataset = file_retriever._get_input_dataset_from_file() + + assert dataset is not None + assert len(dataset["rows"]) == len(expected_prompts) + for i, prompt in enumerate(expected_prompts): + assert dataset["rows"][i]["row"]["text_input"] == prompt + + @patch("pathlib.Path.exists", return_value=True) + @patch("PIL.Image.open", return_value=Image.new("RGB", (10, 10))) + @patch( + "builtins.open", + new_callable=mock_open, + read_data=( + '{"text_input": "prompt1", "image": "image1.png"}\n' + '{"text_input": "prompt2", "image": "image2.png"}\n' + '{"text_input": "prompt3", "image": "image3.png"}\n' + ), + ) + def test_get_input_file_with_multi_modal_data( + self, mock_exists, mock_image, mock_file + ): + file_retriever = FileInputRetriever( + InputsConfig( + model_name=["test_model_A"], + model_selection_strategy=ModelSelectionStrategy.ROUND_ROBIN, + input_filename=Path("prompt.txt"), + ) + ) + Data = namedtuple("Data", ["text_input", "image"]) + expected_data = [ + Data(text_input="prompt1", image="image1.png"), + Data(text_input="prompt2", image="image2.png"), + Data(text_input="prompt3", image="image3.png"), + ] + dataset = file_retriever._get_input_dataset_from_file() + + assert dataset is not None + assert len(dataset["rows"]) == len(expected_data) + for i, data in enumerate(expected_data): + assert dataset["rows"][i]["row"]["text_input"] == data.text_input + assert dataset["rows"][i]["row"]["image"] == data.image diff --git a/genai-perf/tests/test_input_embeddings.py b/genai-perf/tests/test_input_embeddings.py index 280538d8..2172385d 100644 --- a/genai-perf/tests/test_input_embeddings.py +++ b/genai-perf/tests/test_input_embeddings.py @@ -28,12 +28,12 @@ from unittest.mock import mock_open, patch import pytest +from genai_perf.inputs.file_input_retriever import FileInputRetriever from genai_perf.inputs.input_constants import ( ModelSelectionStrategy, OutputFormat, PromptSource, ) -from genai_perf.inputs.input_retriever_factory import InputRetrieverFactory from genai_perf.inputs.inputs import Inputs from genai_perf.inputs.inputs_config import InputsConfig @@ -59,9 +59,8 @@ def test_get_input_dataset_from_embeddings_file(self, mock_file, mock_exists): batch_size=batch_size, num_prompts=100, ) - input_retriever_factory = InputRetrieverFactory(config) - - dataset = input_retriever_factory._get_input_dataset_from_embeddings_file() + file_retriever = FileInputRetriever(config) + dataset = file_retriever._get_input_dataset_from_embeddings_file() assert dataset is not None assert len(dataset["rows"]) == 100 @@ -79,7 +78,7 @@ def test_get_input_dataset_from_embeddings_file(self, mock_file, mock_exists): match="Batch size cannot be larger than the number of available texts", ): config.batch_size = 5 - input_retriever_factory._get_input_dataset_from_embeddings_file() + file_retriever._get_input_dataset_from_embeddings_file() def test_convert_generic_json_to_openai_embeddings_format(self): generic_dataset = { diff --git a/genai-perf/tests/test_input_image_retrieval.py b/genai-perf/tests/test_input_image_retrieval.py index 9c34d4ca..4a8535da 100644 --- a/genai-perf/tests/test_input_image_retrieval.py +++ b/genai-perf/tests/test_input_image_retrieval.py @@ -27,7 +27,7 @@ class TestInputsImageRetrieval: return_value="data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/", ) @patch( - "genai_perf.inputs.input_retriever_factory.InputRetrieverFactory._get_input_dataset_from_file" + "genai_perf.inputs.input_retriever_factory.FileInputRetriever._get_input_dataset_from_file" ) def test_image_retrieval(self, mock_get_input, mock_encode_image, mock_write_json): mock_get_input.return_value = { @@ -80,7 +80,7 @@ def test_image_retrieval(self, mock_get_input, mock_encode_image, mock_write_jso @patch("genai_perf.inputs.inputs.Inputs._write_json_to_file") @patch( - "genai_perf.inputs.input_retriever_factory.InputRetrieverFactory._get_input_dataset_from_file" + "genai_perf.inputs.input_retriever_factory.FileInputRetriever._get_input_dataset_from_file" ) @patch( "genai_perf.inputs.input_retriever_factory.InputRetrieverFactory._encode_image" diff --git a/genai-perf/tests/test_input_rankings.py b/genai-perf/tests/test_input_rankings.py index 89ad1979..79fcb964 100644 --- a/genai-perf/tests/test_input_rankings.py +++ b/genai-perf/tests/test_input_rankings.py @@ -24,85 +24,17 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from pathlib import Path -from unittest.mock import mock_open, patch - -import pytest from genai_perf.inputs.input_constants import ( ModelSelectionStrategy, OutputFormat, PromptSource, ) -from genai_perf.inputs.input_retriever_factory import InputRetrieverFactory from genai_perf.inputs.inputs import Inputs from genai_perf.inputs.inputs_config import InputsConfig class TestInputsRankings: - def open_side_effects(self, filepath, *args, **kwargs): - queries_content = "\n".join( - [ - '{"text": "What production company co-owned by Kevin Loader and Rodger Michell produced My Cousin Rachel?"}', - '{"text": "Who served as the 1st Vice President of Colombia under El Libertador?"}', - '{"text": "Are the Barton Mine and Hermiston-McCauley Mine located in The United States of America?"}', - ] - ) - passages_content = "\n".join( - [ - '{"text": "Eric Anderson (sociologist) Eric Anderson (born January 18, 1968) is an American sociologist"}', - '{"text": "Kevin Loader is a British film and television producer. "}', - '{"text": "Barton Mine, also known as Net Lake Mine, is an abandoned surface and underground mine in Northeastern Ontario"}', - ] - ) - - file_contents = { - "queries.jsonl": queries_content, - "passages.jsonl": passages_content, - } - return mock_open( - read_data=file_contents.get(filepath, file_contents["queries.jsonl"]) - )() - - mock_open_obj = mock_open() - mock_open_obj.side_effect = open_side_effects - - @patch("pathlib.Path.exists", return_value=True) - @patch("builtins.open", mock_open_obj) - def test_get_input_dataset_from_rankings_file(self, mock_file): - queries_filename = Path("queries.jsonl") - passages_filename = Path("passages.jsonl") - batch_size = 2 - config = InputsConfig( - batch_size=batch_size, - num_prompts=100, - ) - input_retriever_factory = InputRetrieverFactory(config) - dataset = input_retriever_factory._get_input_dataset_from_rankings_files( - queries_filename=queries_filename, passages_filename=passages_filename - ) - - assert dataset is not None - assert len(dataset["rows"]) == 100 - for row in dataset["rows"]: - assert "row" in row - assert "payload" in row["row"] - payload = row["row"]["payload"] - assert "query" in payload - assert "passages" in payload - assert isinstance(payload["passages"], list) - assert len(payload["passages"]) == batch_size - - # Try error case where batch size is larger than the number of available texts - with pytest.raises( - ValueError, - match="Batch size cannot be larger than the number of available passages", - ): - config.batch_size = 5 - input_retriever_factory._get_input_dataset_from_rankings_files( - queries_filename, passages_filename - ) - def test_convert_generic_json_to_openai_rankings_format(self): generic_dataset = { "rows": [ diff --git a/genai-perf/tests/test_input_retriever_factory.py b/genai-perf/tests/test_input_retriever_factory.py index 8a8fd4d6..cb81b03f 100644 --- a/genai-perf/tests/test_input_retriever_factory.py +++ b/genai-perf/tests/test_input_retriever_factory.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple -from pathlib import Path -from unittest.mock import mock_open, patch +from unittest.mock import patch import pytest import responses from genai_perf.exceptions import GenAIPerfException from genai_perf.inputs import input_constants as ic -from genai_perf.inputs.input_constants import ModelSelectionStrategy, OutputFormat +from genai_perf.inputs.input_constants import OutputFormat from genai_perf.inputs.input_retriever_factory import InputRetrieverFactory from genai_perf.inputs.inputs_config import InputsConfig from genai_perf.inputs.synthetic_image_generator import ImageFormat from genai_perf.tokenizer import DEFAULT_TOKENIZER, get_tokenizer -from PIL import Image mocked_openorca_data = { "features": [ @@ -196,89 +193,3 @@ def test_inputs_with_defaults(self, default_configured_url): assert dataset_json is not None assert len(dataset_json["rows"]) == TEST_LENGTH - - def test_get_input_file_without_file_existing(self): - input_retriever_factory = InputRetrieverFactory( - InputsConfig(input_filename=Path("prompt.txt")) - ) - with pytest.raises(FileNotFoundError): - input_retriever_factory._get_input_dataset_from_file() - - @patch("pathlib.Path.exists", return_value=True) - @patch( - "builtins.open", - new_callable=mock_open, - read_data='{"text_input": "single prompt"}\n', - ) - def test_get_input_file_with_single_prompt(self, mock_file, mock_exists): - expected_prompts = ["single prompt"] - input_retriever_factory = InputRetrieverFactory( - InputsConfig( - model_name=["test_model_A"], - model_selection_strategy=ModelSelectionStrategy.ROUND_ROBIN, - input_filename=Path("prompt.txt"), - ) - ) - dataset = input_retriever_factory._get_input_dataset_from_file() - - assert dataset is not None - assert len(dataset["rows"]) == len(expected_prompts) - for i, prompt in enumerate(expected_prompts): - assert dataset["rows"][i]["row"]["text_input"] == prompt - - @patch("pathlib.Path.exists", return_value=True) - @patch( - "builtins.open", - new_callable=mock_open, - read_data='{"text_input": "prompt1"}\n{"text_input": "prompt2"}\n{"text_input": "prompt3"}\n', - ) - def test_get_input_file_with_multiple_prompts(self, mock_file, mock_exists): - expected_prompts = ["prompt1", "prompt2", "prompt3"] - input_retriever_factory = InputRetrieverFactory( - InputsConfig( - model_name=["test_model_A"], - model_selection_strategy=ModelSelectionStrategy.ROUND_ROBIN, - input_filename=Path("prompt.txt"), - ) - ) - dataset = input_retriever_factory._get_input_dataset_from_file() - - assert dataset is not None - assert len(dataset["rows"]) == len(expected_prompts) - for i, prompt in enumerate(expected_prompts): - assert dataset["rows"][i]["row"]["text_input"] == prompt - - @patch("pathlib.Path.exists", return_value=True) - @patch("PIL.Image.open", return_value=Image.new("RGB", (10, 10))) - @patch( - "builtins.open", - new_callable=mock_open, - read_data=( - '{"text_input": "prompt1", "image": "image1.png"}\n' - '{"text_input": "prompt2", "image": "image2.png"}\n' - '{"text_input": "prompt3", "image": "image3.png"}\n' - ), - ) - def test_get_input_file_with_multi_modal_data( - self, mock_exists, mock_image, mock_file - ): - input_retriever_factory = InputRetrieverFactory( - InputsConfig( - model_name=["test_model_A"], - model_selection_strategy=ModelSelectionStrategy.ROUND_ROBIN, - input_filename=Path("prompt.txt"), - ) - ) - Data = namedtuple("Data", ["text_input", "image"]) - expected_data = [ - Data(text_input="prompt1", image="image1.png"), - Data(text_input="prompt2", image="image2.png"), - Data(text_input="prompt3", image="image3.png"), - ] - dataset = input_retriever_factory._get_input_dataset_from_file() - - assert dataset is not None - assert len(dataset["rows"]) == len(expected_data) - for i, data in enumerate(expected_data): - assert dataset["rows"][i]["row"]["text_input"] == data.text_input - assert dataset["rows"][i]["row"]["image"] == data.image