Skip to content

Commit

Permalink
Add synthetic data retriever class (#77)
Browse files Browse the repository at this point in the history
* add synthetic data retriever

* fix pytest
  • Loading branch information
nv-hwoo committed Sep 23, 2024
1 parent ec6dd46 commit 4ad1277
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 0 deletions.
64 changes: 64 additions & 0 deletions genai-perf/genai_perf/inputs/synthetic_data_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


from typing import Any, Dict, List

from genai_perf.inputs.input_constants import OutputFormat
from genai_perf.inputs.synthetic_image_generator import SyntheticImageGenerator
from genai_perf.inputs.synthetic_prompt_generator import SyntheticPromptGenerator


class SyntheticDataRetriever:
    """
    Produces a synthetic dataset driven by the settings on a config object.
    """

    def __init__(self, config):
        # The config is only stored here; its fields are read when
        # retrieve_data() is called.
        self.config = config

    def retrieve_data(self) -> List[Dict[str, Any]]:
        """
        Build and return ``config.num_prompts`` synthetic entries.

        Every entry is a dict holding a generated prompt under "text_input".
        When the configured output format is OutputFormat.OPENAI_VISION, the
        entry also carries a generated image under "image".
        """
        return [self._build_entry() for _ in range(self.config.num_prompts)]

    def _build_entry(self) -> Dict[str, Any]:
        """
        Generate one dataset entry: a prompt, plus an image for vision output.
        """
        cfg = self.config
        entry: Dict[str, Any] = {
            "text_input": SyntheticPromptGenerator.create_synthetic_prompt(
                cfg.tokenizer,
                cfg.prompt_tokens_mean,
                cfg.prompt_tokens_stddev,
            )
        }

        # Only the vision output format gets an image attached to the prompt.
        if cfg.output_format == OutputFormat.OPENAI_VISION:
            entry["image"] = SyntheticImageGenerator.create_synthetic_image(
                image_width_mean=cfg.image_width_mean,
                image_width_stddev=cfg.image_width_stddev,
                image_height_mean=cfg.image_height_mean,
                image_height_stddev=cfg.image_height_stddev,
                image_format=cfg.image_format,
            )

        return entry
83 changes: 83 additions & 0 deletions genai-perf/tests/test_synthetic_data_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import patch

import pytest
from genai_perf.inputs.input_constants import OutputFormat
from genai_perf.inputs.inputs_config import InputsConfig
from genai_perf.inputs.synthetic_data_retriever import SyntheticDataRetriever


class TestSyntheticDataRetriever:
    """
    Tests for SyntheticDataRetriever with the prompt and image generators
    patched, so the expected dataset contents are fixed values.
    """

    @patch(
        "genai_perf.inputs.synthetic_data_retriever.SyntheticPromptGenerator.create_synthetic_prompt",
        return_value="test prompt",
    )
    @pytest.mark.parametrize(
        "output_format",
        [
            (OutputFormat.OPENAI_COMPLETIONS),
            (OutputFormat.OPENAI_CHAT_COMPLETIONS),
            (OutputFormat.VLLM),
            (OutputFormat.TENSORRTLLM),
        ],
    )
    def test_synthetic_text(self, mock_prompt, output_format):
        """Text-only output formats yield entries with just "text_input"."""
        config = InputsConfig(
            num_prompts=3,
            output_format=output_format,
        )
        synthetic_retriever = SyntheticDataRetriever(config)
        dataset = synthetic_retriever.retrieve_data()

        assert len(dataset) == 3
        assert dataset == [
            {"text_input": "test prompt"},
            {"text_input": "test prompt"},
            {"text_input": "test prompt"},
        ]
        # One prompt is generated per requested entry.
        assert mock_prompt.call_count == 3

    @patch(
        "genai_perf.inputs.synthetic_data_retriever.SyntheticPromptGenerator.create_synthetic_prompt",
        return_value="test prompt",
    )
    @patch(
        "genai_perf.inputs.synthetic_data_retriever.SyntheticImageGenerator.create_synthetic_image",
        return_value="_base64_encoding",
    )
    def test_synthetic_text_and_image(self, mock_image, mock_prompt):
        """The OPENAI_VISION format yields "text_input" and "image" entries."""
        # Stacked @patch decorators inject mocks bottom-up: the decorator
        # closest to the function (the image patch) supplies the first mock
        # argument, and the outer prompt patch supplies the second.
        config = InputsConfig(
            num_prompts=3,
            output_format=OutputFormat.OPENAI_VISION,
        )
        synthetic_retriever = SyntheticDataRetriever(config)
        dataset = synthetic_retriever.retrieve_data()

        assert len(dataset) == 3
        assert dataset == [
            {
                "text_input": "test prompt",
                "image": "_base64_encoding",
            },
            {
                "text_input": "test prompt",
                "image": "_base64_encoding",
            },
            {
                "text_input": "test prompt",
                "image": "_base64_encoding",
            },
        ]
        # Each entry needs exactly one prompt and one image.
        assert mock_prompt.call_count == 3
        assert mock_image.call_count == 3

0 comments on commit 4ad1277

Please sign in to comment.