Skip to content

Commit

Permalink
Refactor converters
Browse files Browse the repository at this point in the history
  • Loading branch information
lkomali committed Jan 15, 2025
1 parent 6964ee0 commit 6096b2c
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 73 deletions.
25 changes: 21 additions & 4 deletions genai-perf/genai_perf/inputs/converters/base_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from genai_perf.exceptions import GenAIPerfException
from genai_perf.inputs.input_constants import ModelSelectionStrategy
from genai_perf.inputs.inputs_config import InputsConfig
from genai_perf.inputs.retrievers.generic_dataset import GenericDataset
from genai_perf.inputs.retrievers.generic_dataset import DataRow, GenericDataset


class BaseConverter:
Expand Down Expand Up @@ -71,8 +71,25 @@ def _add_request_params(
for key, value in config.extra_inputs.items():
payload[key] = value

def _add_payload_params(
self, payload: Dict[Any, Any], optional_data: Dict[Any, Any]
) -> None:
def _add_payload_params(self, payload: Dict[Any, Any], optional_data) -> None:
for key, value in optional_data.items():
payload[key] = value

def _add_extra_params(
self, payload: Dict[Any, Any], config: InputsConfig, row: DataRow
) -> None:
self._add_request_params(payload, config)
self._add_payload_params(payload, row.optional_data)

def _finalize_payload(
self, payload: Dict[Any, Any], row, triton_format=False
) -> Dict[str, Any]:
record: Dict[str, Any] = {}
if not triton_format:
record["payload"] = [payload]
else:
record.update(payload)
if row.timestamp:
record["timestamp"] = [row.timestamp]

return record
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,6 @@ def convert(
"input": [{"type": "image_url", "url": img} for img in row.images]
}
self._add_payload_params(payload, row.optional_data)
record: Dict[str, Any] = {"payload": [payload]}
if row.timestamp:
record["timestamp"] = [row.timestamp]

request_body["data"].append(record)
request_body["data"].append(self._finalize_payload(payload, row))

return request_body
9 changes: 2 additions & 7 deletions genai-perf/genai_perf/inputs/converters/nvclip_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,7 @@ def convert(
"input": input_items,
}

self._add_request_params(payload, config)
self._add_payload_params(payload, row.optional_data)
record: Dict[str, Any] = {"payload": [payload]}
if row.timestamp:
record["timestamp"] = [row.timestamp]

request_body["data"].append(record)
self._add_extra_params(payload, config, row)
request_body["data"].append(self._finalize_payload(payload, row))

return request_body
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,7 @@ def convert(
for file_data in generic_dataset.files_data.values():
for index, row in enumerate(file_data.rows):
payload = self._create_payload(index, row, config)
record: Dict[str, Any] = {"payload": [payload]}
if row.timestamp:
record["timestamp"] = [row.timestamp]

request_body["data"].append(record)
request_body["data"].append(self._finalize_payload(payload, row))

return request_body

Expand All @@ -86,8 +82,7 @@ def _create_payload(
],
}

self._add_request_params(payload, config)
self._add_payload_params(payload, row.optional_data)
self._add_extra_params(payload, config, row)
return payload

def _retrieve_content(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,8 @@ def convert(
"model": model_name,
"prompt": prompt,
}
self._add_request_params(payload, config)
self._add_payload_params(payload, row.optional_data)
record: Dict[str, Any] = {"payload": [payload]}
if row.timestamp:
record["timestamp"] = [row.timestamp]

request_body["data"].append(record)
self._add_extra_params(payload, config, row)
request_body["data"].append(self._finalize_payload(payload, row))

return request_body

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,7 @@ def convert(
"input": row.texts,
}

self._add_request_params(payload, config)
self._add_payload_params(payload, row.optional_data)
record: Dict[str, Any] = {"payload": [payload]}
if row.timestamp:
record["timestamp"] = [row.timestamp]

request_body["data"].append(record)
self._add_extra_params(payload, config, row)
request_body["data"].append(self._finalize_payload(payload, row))

return request_body
9 changes: 2 additions & 7 deletions genai-perf/genai_perf/inputs/converters/rankings_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,8 @@ def convert(
"model": model_name,
}

self._add_request_params(payload, config)
self._add_payload_params(payload, passage_entry.optional_data)
record: Dict[str, Any] = {"payload": [payload]}
if passage_entry.timestamp:
record["timestamp"] = [passage_entry.timestamp]

request_body["data"].append(record)
self._add_extra_params(payload, config, passage_entry)
request_body["data"].append(self._finalize_payload(payload, passage_entry))

return request_body

Expand Down
11 changes: 4 additions & 7 deletions genai-perf/genai_perf/inputs/converters/tensorrtllm_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,10 @@ def convert(
"max_tokens": [DEFAULT_TENSORRTLLM_MAX_TOKENS], # default
}

self._add_request_params(payload, config)
self._add_payload_params(payload, row.optional_data)
record: Dict[str, Any] = payload
if row.timestamp:
record["timestamp"] = [row.timestamp]

request_body["data"].append(record)
self._add_extra_params(payload, config, row)
request_body["data"].append(
self._finalize_payload(payload, row, triton_format=True)
)

return request_body

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,10 @@ def convert(
"request_output_len": [DEFAULT_TENSORRTLLM_MAX_TOKENS],
}

self._add_request_params(payload, config)
self._add_payload_params(payload, row.optional_data)
record: Dict[str, Any] = payload
if row.timestamp:
record["timestamp"] = [row.timestamp]

request_body["data"].append(record)
self._add_extra_params(payload, config, row)
request_body["data"].append(
self._finalize_payload(payload, row, triton_format=True)
)

return request_body

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ def convert(
payload = {
"text_input": prompt,
}
self._add_request_params(payload, config)
self._add_payload_params(payload, row.optional_data)
record: Dict[str, Any] = {"payload": [payload]}
if row.timestamp:
record["timestamp"] = [row.timestamp]

request_body["data"].append(record)
self._add_extra_params(payload, config, row)
request_body["data"].append(
self._finalize_payload(
payload,
row,
)
)

return request_body

Expand Down
12 changes: 4 additions & 8 deletions genai-perf/genai_perf/inputs/converters/vllm_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,10 @@ def convert(
"text_input": text,
"exclude_input_in_output": [True], # default
}
optional_data = row.optional_data
self._add_request_params(payload, config)
self._add_payload_params(payload, optional_data)
record: Dict[str, Any] = payload
if row.timestamp:
record["timestamp"] = [row.timestamp]

request_body["data"].append(record)
self._add_extra_params(payload, config, row)
request_body["data"].append(
self._finalize_payload(payload, row, triton_format=True)
)

return request_body

Expand Down

0 comments on commit 6096b2c

Please sign in to comment.