diff --git a/model_analyzer/analyzer.py b/model_analyzer/analyzer.py index a759a6c24..fa7a9c49c 100755 --- a/model_analyzer/analyzer.py +++ b/model_analyzer/analyzer.py @@ -84,6 +84,7 @@ def __init__( ) self._search_parameters: Dict[str, SearchParameters] = {} + self._composing_search_parameters: Dict[str, SearchParameters] = {} def profile( self, client: TritonClient, gpus: List[GPUDevice], mode: str, verbose: bool @@ -119,6 +120,7 @@ def profile( self._create_metrics_manager(client, gpus) self._create_model_manager(client, gpus) self._populate_search_parameters() + self._populate_composing_search_parameters() if self._config.triton_launch_mode == "remote": self._warn_if_other_models_loaded_on_remote_server(client) @@ -428,9 +430,12 @@ def _populate_search_parameters(self): model.model_config_parameters(), is_bls_model=bool(self._config.bls_composing_models), ) + + def _populate_composing_search_parameters(self): for model in self._config.bls_composing_models: - self._search_parameters[model.model_name()] = SearchParameters( + self._composing_search_parameters[model.model_name()] = SearchParameters( self._config, model.parameters(), model.model_config_parameters(), + is_composing_model=True, ) diff --git a/model_analyzer/config/generate/search_parameters.py b/model_analyzer/config/generate/search_parameters.py index 41986f410..9c2acab3d 100755 --- a/model_analyzer/config/generate/search_parameters.py +++ b/model_analyzer/config/generate/search_parameters.py @@ -46,12 +46,14 @@ def __init__( parameters: Dict[str, Any] = {}, model_config_parameters: Dict[str, Any] = {}, is_bls_model: bool = False, + is_composing_model: bool = False, ): self._config = config self._parameters = parameters self._model_config_parameters = model_config_parameters self._search_parameters: Dict[str, SearchParameter] = {} self._is_bls_model = is_bls_model + self._is_composing_model = is_composing_model self._populate_search_parameters() @@ -123,7 +125,7 @@ def _populate_search_parameters(self) -> None: def _populate_parameters(self) -> None: self._populate_batch_sizes() - if not self._is_bls_model: + if not self._is_composing_model: self._populate_concurrency() # TODO: Populate request rate - TMA-1903 diff --git a/tests/test_search_parameters.py b/tests/test_search_parameters.py index 80f0541a3..0ff4375b1 100755 --- a/tests/test_search_parameters.py +++ b/tests/test_search_parameters.py @@ -482,10 +482,11 @@ def test_search_parameter_creation_bls_default(self): analyzer = Analyzer(config, MagicMock(), MagicMock(), MagicMock()) analyzer._populate_search_parameters() + analyzer._populate_composing_search_parameters() # ADD_SUB # ===================================================================== - # The top level model of a BLS can only search instance count + # The top level model of a BLS does not search max batch size (always 1) # max_batch_size max_batch_size = analyzer._search_parameters["add_sub"].get_parameter( @@ -497,7 +498,14 @@ def test_search_parameter_creation_bls_default(self): concurrency = analyzer._search_parameters["add_sub"].get_parameter( "concurrency" ) - self.assertIsNone(concurrency) + self.assertEqual(ParameterUsage.RUNTIME, concurrency.usage) + self.assertEqual(ParameterCategory.EXPONENTIAL, concurrency.category) + self.assertEqual( + log2(default.DEFAULT_RUN_CONFIG_MIN_CONCURRENCY), concurrency.min_range + ) + self.assertEqual( + log2(default.DEFAULT_RUN_CONFIG_MAX_CONCURRENCY), concurrency.max_range + ) # instance_group instance_group = analyzer._search_parameters["add_sub"].get_parameter( @@ -514,10 +522,10 @@ def test_search_parameter_creation_bls_default(self): # ADD/SUB (composing models) # ===================================================================== - # The top level model of a BLS can only search instance count + # Composing models do not search concurrency # max_batch_size - max_batch_size = analyzer._search_parameters["add"].get_parameter( + max_batch_size = analyzer._composing_search_parameters["add"].get_parameter( "max_batch_size" ) self.assertEqual(ParameterUsage.MODEL, max_batch_size.usage) @@ -532,18 +540,13 @@ def test_search_parameter_creation_bls_default(self): ) # concurrency - concurrency = analyzer._search_parameters["sub"].get_parameter("concurrency") - self.assertEqual(ParameterUsage.RUNTIME, concurrency.usage) - self.assertEqual(ParameterCategory.EXPONENTIAL, concurrency.category) - self.assertEqual( - log2(default.DEFAULT_RUN_CONFIG_MIN_CONCURRENCY), concurrency.min_range - ) - self.assertEqual( - log2(default.DEFAULT_RUN_CONFIG_MAX_CONCURRENCY), concurrency.max_range + concurrency = analyzer._composing_search_parameters["sub"].get_parameter( + "concurrency" ) + self.assertIsNone(concurrency) # instance_group - instance_group = analyzer._search_parameters["add"].get_parameter( + instance_group = analyzer._composing_search_parameters["sub"].get_parameter( "instance_group" ) self.assertEqual(ParameterUsage.MODEL, instance_group.usage)