diff --git a/model_analyzer/config/input/config_command_profile.py b/model_analyzer/config/input/config_command_profile.py
index 328770fbe..bdce45027 100755
--- a/model_analyzer/config/input/config_command_profile.py
+++ b/model_analyzer/config/input/config_command_profile.py
@@ -17,6 +17,7 @@
 import argparse
 import logging
 import os
+from typing import Dict
 
 import numba.cuda
 import psutil
@@ -1502,66 +1503,27 @@ def _autofill_values(self):
                 }
             else:
                 new_model["parameters"] = {}
-                if "batch_sizes" in model.parameters():
-                    new_model["parameters"].update(
-                        {"batch_sizes": model.parameters()["batch_sizes"]}
-                    )
-                else:
-                    new_model["parameters"].update({"batch_sizes": self.batch_sizes})
-
-                if "concurrency" in model.parameters():
-                    new_model["parameters"].update(
-                        {"concurrency": model.parameters()["concurrency"]}
-                    )
-                else:
-                    new_model["parameters"].update({"concurrency": self.concurrency})
-
-                if "periodic_concurrency" in model.parameters():
-                    new_model["parameters"].update(
-                        {
-                            "periodic_concurrency": model.parameters()[
-                                "periodic_concurrency"
-                            ]
-                        }
-                    )
-                else:
-                    new_model["parameters"].update(
-                        {"periodic_concurrency": self.periodic_concurrency}
-                    )
-
-                if "request_rate" in model.parameters():
-                    new_model["parameters"].update(
-                        {"request_rate": model.parameters()["request_rate"]}
-                    )
-                else:
-                    new_model["parameters"].update({"request_rate": self.request_rate})
-
-                if "request_period" in model.parameters():
-                    new_model["parameters"].update(
-                        {"request_period": model.parameters()["request_period"]}
-                    )
-                else:
-                    new_model["parameters"].update(
-                        {"request_period": self.request_rate}
-                    )
-
-                if "text_input_length" in model.parameters():
-                    new_model["parameters"].update(
-                        {"text_input_length": model.parameters()["text_input_length"]}
-                    )
-                else:
-                    new_model["parameters"].update(
-                        {"text_input_length": self.text_input_length}
-                    )
-
-                if "max_token_count" in model.parameters():
-                    new_model["max_token_count"].update(
-                        {"max_token_count": model.parameters()["max_token_count"]}
-                    )
-                else:
-                    new_model["parameters"].update(
-                        {"max_token_count": self.text_input_length}
-                    )
+                new_model["parameters"].update(
+                    self._set_model_parameter(model, "batch_sizes")
+                )
+                new_model["parameters"].update(
+                    self._set_model_parameter(model, "concurrency")
+                )
+                new_model["parameters"].update(
+                    self._set_model_parameter(model, "periodic_concurrency")
+                )
+                new_model["parameters"].update(
+                    self._set_model_parameter(model, "request_rate")
+                )
+                new_model["parameters"].update(
+                    self._set_model_parameter(model, "request_period")
+                )
+                new_model["parameters"].update(
+                    self._set_model_parameter(model, "max_token_count")
+                )
+                new_model["parameters"].update(
+                    self._set_model_parameter(model, "text_input_length")
+                )
 
                 if (
                     new_model["parameters"]["request_rate"]
@@ -1604,6 +1566,14 @@ def _autofill_values(self):
             new_profile_models[model.model_name()] = new_model
         self._fields["profile_models"].set_value(new_profile_models)
 
+    def _set_model_parameter(
+        self, model: ConfigModelProfileSpec, parameter_name: str
+    ) -> Dict:
+        if parameter_name in model.parameters():
+            return {parameter_name: model.parameters()[parameter_name]}
+        else:
+            return {parameter_name: getattr(self, parameter_name)}
+
    def _using_request_rate(self) -> bool:
         if self.request_rate or self.request_rate_search_enable:
             return True
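
Review note: the extracted `_set_model_parameter` helper replaces seven nearly identical if/else blocks with a single "per-model override, else config-wide default" lookup. As a side effect it also removes two issues visible in the deleted code: the `max_token_count` branch updated `new_model["max_token_count"]` instead of `new_model["parameters"]`, and the `request_period` / `max_token_count` fallbacks used the wrong defaults (`self.request_rate` and `self.text_input_length`). Below is a minimal, self-contained sketch of the same pattern; the `Config` and `Model` classes are illustrative stand-ins for `ConfigCommandProfile` and `ConfigModelProfileSpec`, not the real classes from this diff.

```python
# Sketch only: stand-in classes to illustrate the override-or-default lookup
# that _set_model_parameter centralizes. Names here are hypothetical.
from typing import Any, Dict


class Model:
    def __init__(self, params: Dict[str, Any]):
        self._params = params

    def parameters(self) -> Dict[str, Any]:
        return self._params


class Config:
    # Config-wide defaults used when a model does not override a parameter.
    batch_sizes = [1]
    concurrency = [1, 2, 4]

    def _set_model_parameter(self, model: Model, parameter_name: str) -> Dict[str, Any]:
        # Prefer the per-model value; otherwise fall back to the attribute
        # of the same name on the config object (as the diff does via getattr).
        if parameter_name in model.parameters():
            return {parameter_name: model.parameters()[parameter_name]}
        return {parameter_name: getattr(self, parameter_name)}


config = Config()
model = Model({"concurrency": [8]})

parameters: Dict[str, Any] = {}
parameters.update(config._set_model_parameter(model, "batch_sizes"))  # default -> [1]
parameters.update(config._set_model_parameter(model, "concurrency"))  # per-model -> [8]
print(parameters)  # {'batch_sizes': [1], 'concurrency': [8]}
```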