Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LlmInputs - trtllm output format support #492

Merged
merged 4 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 105 additions & 23 deletions src/c++/perf_analyzer/genai-pa/genai_pa/llm_inputs/llm_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@ class InputType(Enum):
SYNTHETIC = auto()


class InputFormat(Enum):
    """
    Formats in which an input dataset can be supplied.
    """

    # Member values are assigned by auto() in declaration order.
    OPENAI = auto()
    TRTLLM = auto()
    VLLM = auto()


class OutputFormat(Enum):
OPENAI_CHAT_COMPLETIONS = auto()
OPENAI_COMPLETIONS = auto()
Expand All @@ -58,6 +52,8 @@ class LlmInputs:
DEFAULT_LENGTH = 100
MINIMUM_LENGTH = 1

DEFAULT_TRTLLM_MAX_TOKENS = 256

EMPTY_JSON_IN_VLLM_PA_FORMAT = {"data": []}
EMPTY_JSON_IN_TRTLLM_PA_FORMAT = {"data": []}
EMPTY_JSON_IN_OPENAI_PA_FORMAT = {"data": [{"payload": []}]}
Expand All @@ -68,7 +64,6 @@ class LlmInputs:
def create_llm_inputs(
cls,
input_type: InputType,
input_format: InputFormat,
output_format: OutputFormat,
dataset_name: str = "",
model_name: str = "",
Expand All @@ -86,8 +81,6 @@ def create_llm_inputs(
-------------------
input_type:
Specify how the input is received (file or URL)
input_format:
Specify the input format
output_format:
Specify the output format

Expand Down Expand Up @@ -121,9 +114,7 @@ def create_llm_inputs(
"Using file/synthetic to supply LLM Input is not supported at this time"
)

generic_dataset_json = LlmInputs._convert_input_dataset_to_generic_json(
input_format, dataset
)
generic_dataset_json = LlmInputs._convert_input_dataset_to_generic_json(dataset)

json_in_pa_format = LlmInputs._convert_generic_json_to_output_format(
output_format, generic_dataset_json, add_model_name, add_stream, model_name
Expand Down Expand Up @@ -179,23 +170,16 @@ def _download_dataset(cls, configured_url, starting_index, length) -> Response:
return dataset

@classmethod
def _convert_input_dataset_to_generic_json(
cls, input_format: InputFormat, dataset: Response
) -> Dict:
def _convert_input_dataset_to_generic_json(cls, dataset: Response) -> Dict:
dataset_json = dataset.json()
try:
LlmInputs._check_for_error_in_json_of_dataset(dataset_json)
except Exception as e:
raise GenAiPAException(e)

if input_format == InputFormat.OPENAI:
generic_dataset_json = LlmInputs._convert_openai_to_generic_input_json(
dataset_json
)
else:
raise GenAiPAException(
f"Input format {input_format} is not supported at this time"
)
generic_dataset_json = LlmInputs._convert_openai_to_generic_input_json(
debermudez marked this conversation as resolved.
Show resolved Hide resolved
dataset_json
)

return generic_dataset_json

Expand Down Expand Up @@ -254,6 +238,10 @@ def _convert_generic_json_to_output_format(
output_json = LlmInputs._convert_generic_json_to_vllm_format(
generic_dataset, add_model_name, add_stream, model_name
)
elif output_format == OutputFormat.TRTLLM:
output_json = LlmInputs._convert_generic_json_to_trtllm_format(
generic_dataset, add_model_name, add_stream, model_name
)
else:
raise GenAiPAException(
f"Output format {output_format} is not currently supported"
Expand Down Expand Up @@ -337,6 +325,32 @@ def _convert_generic_json_to_vllm_format(

return pa_json

@classmethod
def _convert_generic_json_to_trtllm_format(
    cls,
    dataset_json: Dict,
    add_model_name: bool,
    add_stream: bool,
    model_name: str = "",
) -> Dict:
    """
    Convert a generic dataset JSON into a JSON in the TRTLLM PA format.

    The header roles are determined from the dataset first and are then
    passed, together with the optional-tag settings, to the TRTLLM
    populate step, whose result is returned.
    """
    header_roles = LlmInputs._determine_json_feature_roles(dataset_json)
    system_role_headers, user_role_headers, text_input_headers = header_roles

    return LlmInputs._populate_trtllm_output_json(
        dataset_json,
        system_role_headers,
        user_role_headers,
        text_input_headers,
        add_model_name,
        add_stream,
        model_name,
    )

@classmethod
def _write_json_to_file(cls, json_in_pa_format: Dict):
try:
Expand Down Expand Up @@ -469,6 +483,42 @@ def _populate_vllm_output_json(

return pa_json

@classmethod
def _populate_trtllm_output_json(
    cls,
    dataset_json: Dict,
    system_role_headers: List[str],
    user_role_headers: List[str],
    text_input_headers: List[str],
    add_model_name: bool,
    add_stream: bool,
    model_name: str = "",
) -> Dict:
    """
    Build the TRTLLM PA JSON from the rows of a generic dataset JSON.

    Every row gets its own entry in ``pa_json["data"]``, seeded with an
    empty ``text_input`` list. Each header/content pair of the row is
    turned into a text input and added to that entry; afterwards the
    required and the optional TRTLLM tags are added to the entry.
    """
    pa_json = LlmInputs._create_empty_trtllm_pa_json()

    for row_index, row in enumerate(dataset_json["rows"]):
        # Seed the entry for this row before any text input is added to it.
        pa_json["data"].append({"text_input": []})

        for column_header, column_content in row.items():
            text_input = LlmInputs._create_new_text_input(
                column_header,
                system_role_headers,
                user_role_headers,
                text_input_headers,
                column_content,
            )
            pa_json = LlmInputs._add_new_text_input_to_json(
                pa_json, row_index, text_input
            )

        # Tags are added once per row, after all of its columns were handled.
        pa_json = LlmInputs._add_required_tags_to_trtllm_json(pa_json, row_index)
        pa_json = LlmInputs._add_optional_tags_to_trtllm_json(
            pa_json, row_index, add_model_name, add_stream, model_name
        )

    return pa_json

@classmethod
def _create_empty_openai_pa_json(cls) -> Dict:
empty_pa_json = deepcopy(LlmInputs.EMPTY_JSON_IN_OPENAI_PA_FORMAT)
Expand All @@ -481,6 +531,12 @@ def _create_empty_vllm_pa_json(cls) -> Dict:

return empty_pa_json

@classmethod
def _create_empty_trtllm_pa_json(cls) -> Dict:
    """
    Return a deep copy of the empty TRTLLM PA JSON template.
    """
    # Deep copy so that filling the result never touches the class constant.
    return deepcopy(LlmInputs.EMPTY_JSON_IN_TRTLLM_PA_FORMAT)

@classmethod
def _create_new_openai_chat_completions_message(
cls,
Expand Down Expand Up @@ -603,6 +659,32 @@ def _add_optional_tags_to_vllm_json(

return pa_json

@classmethod
def _add_optional_tags_to_trtllm_json(
cls,
pa_json: Dict,
index: int,
add_model_name: bool,
add_stream: bool,
model_name: str = "",
) -> Dict:
if add_model_name:
pa_json["data"][index]["model"] = model_name
if add_stream:
pa_json["data"][index]["stream"] = [True]

return pa_json

@classmethod
def _add_required_tags_to_trtllm_json(
    cls,
    pa_json: Dict,
    index: int,
) -> Dict:
    """
    Add the tags every TRTLLM entry must carry to one entry of the PA JSON.

    Sets ``max_tokens`` of the entry at ``pa_json["data"][index]`` to
    ``LlmInputs.DEFAULT_TRTLLM_MAX_TOKENS``. The entry is modified in
    place and the same ``pa_json`` object is returned.
    """
    max_tokens = LlmInputs.DEFAULT_TRTLLM_MAX_TOKENS
    pa_json["data"][index]["max_tokens"] = max_tokens

    return pa_json

@classmethod
def _check_for_dataset_name_if_input_type_is_url(
cls, input_type: InputType, dataset_name: str
Expand Down
41 changes: 21 additions & 20 deletions src/c++/perf_analyzer/genai-pa/tests/test_llm_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@
import pytest
from genai_pa.constants import CNN_DAILY_MAIL, DEFAULT_INPUT_DATA_JSON, OPEN_ORCA
from genai_pa.exceptions import GenAiPAException
from genai_pa.llm_inputs.llm_inputs import (
InputFormat,
InputType,
LlmInputs,
OutputFormat,
)
from genai_pa.llm_inputs.llm_inputs import InputType, LlmInputs, OutputFormat


class TestLlmInputs:
Expand Down Expand Up @@ -103,7 +98,6 @@ def test_llm_inputs_error_in_server_response(self):
with pytest.raises(GenAiPAException):
_ = LlmInputs.create_llm_inputs(
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
dataset_name=OPEN_ORCA,
output_format=OutputFormat.OPENAI_CHAT_COMPLETIONS,
starting_index=LlmInputs.DEFAULT_STARTING_INDEX,
Expand All @@ -119,9 +113,7 @@ def test_llm_inputs_with_defaults(self, default_configured_url):
LlmInputs.DEFAULT_STARTING_INDEX,
LlmInputs.DEFAULT_LENGTH,
)
dataset_json = LlmInputs._convert_input_dataset_to_generic_json(
input_format=InputFormat.OPENAI, dataset=dataset
)
dataset_json = LlmInputs._convert_input_dataset_to_generic_json(dataset=dataset)

assert dataset_json is not None
assert len(dataset_json["rows"]) == LlmInputs.DEFAULT_LENGTH
Expand All @@ -140,9 +132,7 @@ def test_llm_inputs_with_non_default_length(self):
LlmInputs.DEFAULT_STARTING_INDEX,
length=(int(LlmInputs.DEFAULT_LENGTH / 2)),
)
dataset_json = LlmInputs._convert_input_dataset_to_generic_json(
input_format=InputFormat.OPENAI, dataset=dataset
)
dataset_json = LlmInputs._convert_input_dataset_to_generic_json(dataset=dataset)

assert dataset_json is not None
assert len(dataset_json["rows"]) == LlmInputs.DEFAULT_LENGTH / 2
Expand All @@ -156,9 +146,7 @@ def test_convert_default_json_to_pa_format(self, default_configured_url):
LlmInputs.DEFAULT_STARTING_INDEX,
LlmInputs.DEFAULT_LENGTH,
)
dataset_json = LlmInputs._convert_input_dataset_to_generic_json(
input_format=InputFormat.OPENAI, dataset=dataset
)
dataset_json = LlmInputs._convert_input_dataset_to_generic_json(dataset=dataset)
pa_json = LlmInputs._convert_generic_json_to_output_format(
output_format=OutputFormat.OPENAI_CHAT_COMPLETIONS,
generic_dataset=dataset_json,
Expand All @@ -175,7 +163,6 @@ def test_create_openai_llm_inputs_cnn_dailymail(self):
"""
pa_json = LlmInputs.create_llm_inputs(
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
dataset_name=CNN_DAILY_MAIL,
output_format=OutputFormat.OPENAI_CHAT_COMPLETIONS,
)
Expand All @@ -191,7 +178,6 @@ def test_write_to_file(self):
"""
pa_json = LlmInputs.create_llm_inputs(
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
dataset_name=OPEN_ORCA,
output_format=OutputFormat.OPENAI_CHAT_COMPLETIONS,
add_model_name=True,
Expand All @@ -212,7 +198,6 @@ def test_create_openai_to_vllm(self):
"""
pa_json = LlmInputs.create_llm_inputs(
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
output_format=OutputFormat.VLLM,
dataset_name=OPEN_ORCA,
add_model_name=False,
Expand All @@ -230,7 +215,6 @@ def test_create_openai_to_completions(self):
"""
pa_json = LlmInputs.create_llm_inputs(
input_type=InputType.URL,
input_format=InputFormat.OPENAI,
output_format=OutputFormat.OPENAI_COMPLETIONS,
dataset_name=OPEN_ORCA,
add_model_name=False,
Expand All @@ -241,3 +225,20 @@ def test_create_openai_to_completions(self):

assert pa_json is not None
assert len(pa_json["data"][0]["payload"]) == LlmInputs.DEFAULT_LENGTH

def test_create_openai_to_trtllm(self):
    """
    Check that the OPEN_ORCA dataset is converted to the trtllm format
    """
    trtllm_json = LlmInputs.create_llm_inputs(
        input_type=InputType.URL,
        output_format=OutputFormat.TRTLLM,
        dataset_name=OPEN_ORCA,
        add_model_name=False,
        add_stream=True,
    )

    # Clean up before asserting so a failed assertion does not leave the
    # file behind (presumably written by create_llm_inputs).
    os.remove(DEFAULT_INPUT_DATA_JSON)

    assert trtllm_json is not None
    # One "data" entry is expected per requested dataset row.
    assert len(trtllm_json["data"]) == LlmInputs.DEFAULT_LENGTH
Loading