diff --git a/model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py b/model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py index 85c8cdb68..61327c15c 100755 --- a/model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py +++ b/model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py @@ -95,13 +95,16 @@ def get_configs(self) -> Generator[RunConfig, None, None]: logger.info("") yield from self._execute_optuna_search() logger.info("") - logger.info( - "Done with Optuna mode search. Gathering concurrency sweep measurements for reports" - ) - logger.info("") - yield from self._sweep_concurrency_over_top_results() - logger.info("") - logger.info("Done gathering concurrency sweep measurements for reports") + if self._config.concurrency_sweep_disable: + logger.info("Done with Optuna mode search.") + else: + logger.info( + "Done with Optuna mode search. Gathering concurrency sweep measurements for reports" + ) + logger.info("") + yield from self._sweep_concurrency_over_top_results() + logger.info("") + logger.info("Done gathering concurrency sweep measurements for reports") logger.info("") def _execute_optuna_search(self) -> Generator[RunConfig, None, None]: diff --git a/model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py b/model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py index 463999249..3ac34daff 100755 --- a/model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py +++ b/model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py @@ -99,13 +99,16 @@ def get_configs(self) -> Generator[RunConfig, None, None]: logger.info("") yield from self._execute_quick_search() logger.info("") - logger.info( - "Done with quick mode search. Gathering concurrency sweep measurements for reports" - ) - logger.info("") - yield from self._sweep_concurrency_over_top_results() - logger.info("") - logger.info("Done gathering concurrency sweep measurements for reports") + if self._config.concurrency_sweep_disable: + logger.info("Done with quick mode search.") + else: + logger.info( + "Done with quick mode search. Gathering concurrency sweep measurements for reports" + ) + logger.info("") + yield from self._sweep_concurrency_over_top_results() + logger.info("") + logger.info("Done gathering concurrency sweep measurements for reports") logger.info("") def _execute_quick_search(self) -> Generator[RunConfig, None, None]: diff --git a/model_analyzer/config/input/config_command_profile.py b/model_analyzer/config/input/config_command_profile.py index 47fcc403d..76314dc2a 100755 --- a/model_analyzer/config/input/config_command_profile.py +++ b/model_analyzer/config/input/config_command_profile.py @@ -43,6 +43,7 @@ DEFAULT_CHECKPOINT_DIRECTORY, DEFAULT_CLIENT_PROTOCOL, DEFAULT_COLLECT_CPU_METRICS, + DEFAULT_CONCURRENCY_SWEEP_DISABLE, DEFAULT_DURATION_SECONDS, DEFAULT_EXPORT_PATH, DEFAULT_FILENAME_MODEL_GPU, @@ -1012,6 +1013,16 @@ def _add_run_search_configs(self): description="Enables the searching of request rate (instead of concurrency).", ) ) + self._add_config( + ConfigField( + "concurrency_sweep_disable", + flags=["--concurrency-sweep-disable"], + field_type=ConfigPrimitive(bool), + parser_args={"action": "store_true"}, + default_value=DEFAULT_CONCURRENCY_SWEEP_DISABLE, + description="Disables the sweeping of concurrencies for the top-N models after quick/optuna search completion.", + ) + ) def _add_triton_configs(self): """ diff --git a/model_analyzer/config/input/config_defaults.py b/model_analyzer/config/input/config_defaults.py index 881e8945a..2e43db3dc 100755 --- a/model_analyzer/config/input/config_defaults.py +++ b/model_analyzer/config/input/config_defaults.py @@ -60,6 +60,7 @@ DEFAULT_OPTUNA_MAX_TRIALS = 200 DEFAULT_USE_CONCURRENCY_FORMULA = False DEFAULT_REQUEST_RATE_SEARCH_ENABLE = False +DEFAULT_CONCURRENCY_SWEEP_DISABLE = False DEFAULT_TRITON_LAUNCH_MODE = "local" DEFAULT_TRITON_DOCKER_IMAGE = "nvcr.io/nvidia/tritonserver:24.04-py3" DEFAULT_TRITON_HTTP_ENDPOINT = "localhost:8000" diff --git a/tests/test_cli.py b/tests/test_cli.py index d4bfde6f6..f055dce86 100755 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -66,6 +66,8 @@ def get_test_options(): OptionStruct("bool", "profile","--skip-detailed-reports"), OptionStruct("bool", "profile","--always-report-gpu-metrics"), OptionStruct("bool", "profile","--use-concurrency-formula"), + OptionStruct("bool", "profile","--concurrency-sweep-disable"), + #Int/Float options # Options format: