Skip to content

Commit

Permalink
Fixing output format for all endpoints (#495)
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-braf authored and debermudez committed Mar 12, 2024
1 parent 5fcf512 commit 15c2b8c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 13 deletions.
32 changes: 22 additions & 10 deletions src/c++/perf_analyzer/genai-pa/genai_pa/llm_inputs/llm_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class LlmInputs:

EMPTY_JSON_IN_VLLM_PA_FORMAT = {"data": []}
EMPTY_JSON_IN_TRTLLM_PA_FORMAT = {"data": []}
EMPTY_JSON_IN_OPENAI_PA_FORMAT = {"data": [{"payload": []}]}
EMPTY_JSON_IN_OPENAI_PA_FORMAT = {"data": []}

dataset_url_map = {OPEN_ORCA: OPEN_ORCA_URL, CNN_DAILY_MAIL: CNN_DAILYMAIL_URL}

Expand Down Expand Up @@ -396,7 +396,8 @@ def _populate_openai_chat_completions_output_json(
pa_json = LlmInputs._create_empty_openai_pa_json()

for index, entry in enumerate(dataset_json["rows"]):
pa_json["data"][0]["payload"].append({"messages": []})
pa_json["data"].append({"payload": []})
pa_json["data"][index]["payload"].append({"messages": []})

for header, content in entry.items():
new_message = LlmInputs._create_new_openai_chat_completions_message(
Expand Down Expand Up @@ -427,7 +428,8 @@ def _populate_openai_completions_output_json(
pa_json = LlmInputs._create_empty_openai_pa_json()

for index, entry in enumerate(dataset_json["rows"]):
pa_json["data"][0]["payload"].append({"prompt": []})
pa_json["data"].append({"payload": []})
pa_json["data"][index]["payload"].append({"prompt": [""]})

for header, content in entry.items():
new_prompt = LlmInputs._create_new_prompt(
Expand Down Expand Up @@ -460,7 +462,7 @@ def _populate_vllm_output_json(
pa_json = LlmInputs._create_empty_vllm_pa_json()

for index, entry in enumerate(dataset_json["rows"]):
pa_json["data"].append({"text_input": []})
pa_json["data"].append({"text_input": [""]})

for header, content in entry.items():
new_text_input = LlmInputs._create_new_text_input(
Expand Down Expand Up @@ -495,7 +497,7 @@ def _populate_trtllm_output_json(
pa_json = LlmInputs._create_empty_trtllm_pa_json()

for index, entry in enumerate(dataset_json["rows"]):
pa_json["data"].append({"text_input": []})
pa_json["data"].append({"text_input": [""]})

for header, content in entry.items():
new_text_input = LlmInputs._create_new_text_input(
Expand Down Expand Up @@ -603,7 +605,7 @@ def _add_new_message_to_json(
cls, pa_json: Dict, index: int, new_message: Optional[Dict]
) -> Dict:
if new_message:
pa_json["data"][0]["payload"][index]["messages"].append(new_message)
pa_json["data"][index]["payload"][0]["messages"].append(new_message)

return pa_json

Expand All @@ -612,7 +614,12 @@ def _add_new_text_input_to_json(
cls, pa_json: Dict, index: int, new_text_input: str
) -> Dict:
if new_text_input:
pa_json["data"][index]["text_input"].append(new_text_input)
if pa_json["data"][index]["text_input"][0]:
pa_json["data"][index]["text_input"][0] = (
pa_json["data"][index]["text_input"][0] + f" {new_text_input}"
)
else:
pa_json["data"][index]["text_input"][0] = new_text_input

return pa_json

Expand All @@ -621,7 +628,12 @@ def _add_new_prompt_to_json(
cls, pa_json: Dict, index: int, new_prompt: str
) -> Dict:
if new_prompt:
pa_json["data"][0]["payload"][index]["prompt"].append(new_prompt)
if pa_json["data"][index]["payload"][0]["prompt"][0]:
pa_json["data"][index]["payload"][0]["prompt"][0] = (
pa_json["data"][index]["payload"][0]["prompt"][0] + f" {new_prompt}"
)
else:
pa_json["data"][index]["payload"][0]["prompt"][0] = new_prompt

return pa_json

Expand All @@ -635,9 +647,9 @@ def _add_optional_tags_to_openai_json(
model_name: str = "",
) -> Dict:
if add_model_name:
pa_json["data"][0]["payload"][index]["model"] = model_name
pa_json["data"][index]["payload"][0]["model"] = model_name
if add_stream:
pa_json["data"][0]["payload"][index]["stream"] = [True]
pa_json["data"][index]["payload"][0]["stream"] = [True]

return pa_json

Expand Down
7 changes: 4 additions & 3 deletions src/c++/perf_analyzer/genai-pa/tests/test_llm_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_convert_default_json_to_pa_format(self, default_configured_url):
)

assert pa_json is not None
assert len(pa_json["data"][0]["payload"]) == LlmInputs.DEFAULT_LENGTH
assert len(pa_json["data"]) == LlmInputs.DEFAULT_LENGTH

def test_create_openai_llm_inputs_cnn_dailymail(self):
"""
Expand All @@ -170,7 +170,7 @@ def test_create_openai_llm_inputs_cnn_dailymail(self):
os.remove(DEFAULT_INPUT_DATA_JSON)

assert pa_json is not None
assert len(pa_json["data"][0]["payload"]) == LlmInputs.DEFAULT_LENGTH
assert len(pa_json["data"]) == LlmInputs.DEFAULT_LENGTH

def test_write_to_file(self):
"""
Expand All @@ -180,6 +180,7 @@ def test_write_to_file(self):
input_type=InputType.URL,
dataset_name=OPEN_ORCA,
output_format=OutputFormat.OPENAI_CHAT_COMPLETIONS,
model_name="open_orca",
add_model_name=True,
add_stream=True,
)
Expand Down Expand Up @@ -224,7 +225,7 @@ def test_create_openai_to_completions(self):
os.remove(DEFAULT_INPUT_DATA_JSON)

assert pa_json is not None
assert len(pa_json["data"][0]["payload"]) == LlmInputs.DEFAULT_LENGTH
assert len(pa_json["data"]) == LlmInputs.DEFAULT_LENGTH

def test_create_openai_to_trtllm(self):
"""
Expand Down

0 comments on commit 15c2b8c

Please sign in to comment.