Skip to content

Commit

Permalink
Add triton generate handling
Browse files Browse the repository at this point in the history
  • Loading branch information
IzzyPutterman committed Dec 19, 2024
1 parent a264168 commit fd52cc3
Showing 1 changed file with 7 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from genai_perf.inputs.converters.base_converter import BaseConverter
from genai_perf.inputs.input_constants import DEFAULT_OUTPUT_TOKENS_MEAN
from genai_perf.inputs.inputs_config import InputsConfig
from genai_perf.inputs.retrievers.generic_dataset import GenericDataset
from genai_perf.inputs.retrievers.generic_dataset import DataRow, GenericDataset
from genai_perf.utils import sample_bounded_normal


Expand All @@ -56,6 +56,7 @@ def convert(
"text_input": prompt,
}
self._add_request_params(payload, config)
self._override_extra(payload, row)
request_body["data"].append({"payload": [payload]})

return request_body
Expand All @@ -73,3 +74,8 @@ def _add_request_params(self, payload: Dict, config: InputsConfig) -> None:
)
for key, value in config.extra_inputs.items():
payload[key] = value

def _override_extra(self, payload: Dict, row: DataRow) -> None:
    """Apply per-row overrides from ``row.extra_args`` onto ``payload``.

    Only the ``"max_tokens"`` key is honoured; every other entry in
    ``row.extra_args`` is ignored. ``convert`` calls this after
    ``_add_request_params``, so a row-level value replaces any
    ``"max_tokens"`` already present in the payload.

    Args:
        payload: Request payload dictionary; mutated in place.
        row: Dataset row whose ``extra_args`` mapping may carry overrides.
    """
    row_overrides = row.extra_args
    # Direct key lookup instead of scanning every (key, value) pair.
    if "max_tokens" in row_overrides:
        payload["max_tokens"] = row_overrides["max_tokens"]

0 comments on commit fd52cc3

Please sign in to comment.