Skip to content

Commit

Permalink
Removing InputFormat
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-braf committed Mar 6, 2024
1 parent 6cc3e43 commit 3606d30
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 44 deletions.
28 changes: 5 additions & 23 deletions src/c++/perf_analyzer/genai-pa/genai_pa/llm_inputs/llm_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@ class InputType(Enum):
SYNTHETIC = auto()


class InputFormat(Enum):
OPENAI = auto()
TRTLLM = auto()
VLLM = auto()


class OutputFormat(Enum):
OPENAI_CHAT_COMPLETIONS = auto()
OPENAI_COMPLETIONS = auto()
Expand Down Expand Up @@ -70,7 +64,6 @@ class LlmInputs:
def create_llm_inputs(
cls,
input_type: InputType,
input_format: InputFormat,
output_format: OutputFormat,
dataset_name: str = "",
model_name: str = "",
Expand All @@ -88,8 +81,6 @@ def create_llm_inputs(
-------------------
input_type:
Specify how the input is received (file or URL)
input_format:
Specify the input format
output_format:
Specify the output format
Expand Down Expand Up @@ -123,9 +114,7 @@ def create_llm_inputs(
"Using file/synthetic to supply LLM Input is not supported at this time"
)

generic_dataset_json = LlmInputs._convert_input_dataset_to_generic_json(
input_format, dataset
)
generic_dataset_json = LlmInputs._convert_input_dataset_to_generic_json(dataset)

json_in_pa_format = LlmInputs._convert_generic_json_to_output_format(
output_format, generic_dataset_json, add_model_name, add_stream, model_name
Expand Down Expand Up @@ -181,23 +170,16 @@ def _download_dataset(cls, configured_url, starting_index, length) -> Response:
return dataset

@classmethod
def _convert_input_dataset_to_generic_json(
cls, input_format: InputFormat, dataset: Response
) -> Dict:
def _convert_input_dataset_to_generic_json(cls, dataset: Response) -> Dict:
dataset_json = dataset.json()
try:
LlmInputs._check_for_error_in_json_of_dataset(dataset_json)
except Exception as e:
raise GenAiPAException(e)

if input_format == InputFormat.OPENAI:
generic_dataset_json = LlmInputs._convert_openai_to_generic_input_json(
dataset_json
)
else:
raise GenAiPAException(
f"Input format {input_format} is not supported at this time"
)
generic_dataset_json = LlmInputs._convert_openai_to_generic_input_json(
dataset_json
)

return generic_dataset_json

Expand Down
25 changes: 4 additions & 21 deletions src/c++/perf_analyzer/genai-pa/tests/test_llm_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@
import pytest
from genai_pa.constants import CNN_DAILY_MAIL, DEFAULT_INPUT_DATA_JSON, OPEN_ORCA
from genai_pa.exceptions import GenAiPAException
from genai_pa.llm_inputs.llm_inputs import (
InputFormat,
InputType,
LlmInputs,
OutputFormat,
)
from genai_pa.llm_inputs.llm_inputs import InputType, LlmInputs, OutputFormat


class TestLlmInputs:
Expand Down Expand Up @@ -103,7 +98,6 @@ def test_llm_inputs_error_in_server_response(self):
with pytest.raises(GenAiPAException):
_ = LlmInputs.create_llm_inputs(
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
dataset_name=OPEN_ORCA,
output_format=OutputFormat.OPENAI_CHAT_COMPLETIONS,
starting_index=LlmInputs.DEFAULT_STARTING_INDEX,
Expand All @@ -119,9 +113,7 @@ def test_llm_inputs_with_defaults(self, default_configured_url):
LlmInputs.DEFAULT_STARTING_INDEX,
LlmInputs.DEFAULT_LENGTH,
)
dataset_json = LlmInputs._convert_input_dataset_to_generic_json(
input_format=InputFormat.OPENAI, dataset=dataset
)
dataset_json = LlmInputs._convert_input_dataset_to_generic_json(dataset=dataset)

assert dataset_json is not None
assert len(dataset_json["rows"]) == LlmInputs.DEFAULT_LENGTH
Expand All @@ -140,9 +132,7 @@ def test_llm_inputs_with_non_default_length(self):
LlmInputs.DEFAULT_STARTING_INDEX,
length=(int(LlmInputs.DEFAULT_LENGTH / 2)),
)
dataset_json = LlmInputs._convert_input_dataset_to_generic_json(
input_format=InputFormat.OPENAI, dataset=dataset
)
dataset_json = LlmInputs._convert_input_dataset_to_generic_json(dataset=dataset)

assert dataset_json is not None
assert len(dataset_json["rows"]) == LlmInputs.DEFAULT_LENGTH / 2
Expand All @@ -156,9 +146,7 @@ def test_convert_default_json_to_pa_format(self, default_configured_url):
LlmInputs.DEFAULT_STARTING_INDEX,
LlmInputs.DEFAULT_LENGTH,
)
dataset_json = LlmInputs._convert_input_dataset_to_generic_json(
input_format=InputFormat.OPENAI, dataset=dataset
)
dataset_json = LlmInputs._convert_input_dataset_to_generic_json(dataset=dataset)
pa_json = LlmInputs._convert_generic_json_to_output_format(
output_format=OutputFormat.OPENAI_CHAT_COMPLETIONS,
generic_dataset=dataset_json,
Expand All @@ -175,7 +163,6 @@ def test_create_openai_llm_inputs_cnn_dailymail(self):
"""
pa_json = LlmInputs.create_llm_inputs(
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
dataset_name=CNN_DAILY_MAIL,
output_format=OutputFormat.OPENAI_CHAT_COMPLETIONS,
)
Expand All @@ -191,7 +178,6 @@ def test_write_to_file(self):
"""
pa_json = LlmInputs.create_llm_inputs(
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
dataset_name=OPEN_ORCA,
output_format=OutputFormat.OPENAI_CHAT_COMPLETIONS,
add_model_name=True,
Expand All @@ -212,7 +198,6 @@ def test_create_openai_to_vllm(self):
"""
pa_json = LlmInputs.create_llm_inputs(
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
output_format=OutputFormat.VLLM,
dataset_name=OPEN_ORCA,
add_model_name=False,
Expand All @@ -230,7 +215,6 @@ def test_create_openai_to_completions(self):
"""
pa_json = LlmInputs.create_llm_inputs(
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
output_format=OutputFormat.OPENAI_COMPLETIONS,
dataset_name=OPEN_ORCA,
add_model_name=False,
Expand All @@ -248,7 +232,6 @@ def test_create_openai_to_trtllm(self):
"""
pa_json = LlmInputs.create_llm_inputs(
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
output_format=OutputFormat.TRTLLM,
dataset_name=OPEN_ORCA,
add_model_name=False,
Expand Down

0 comments on commit 3606d30

Please sign in to comment.