Skip to content

Commit

Permalink
Fixes based on PR
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-braf committed Jun 10, 2024
1 parent ab7e8d9 commit b43804f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 15 deletions.
7 changes: 6 additions & 1 deletion model_analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
)

self._search_parameters: Dict[str, SearchParameters] = {}
self._composing_search_parameters: Dict[str, SearchParameters] = {}

def profile(
self, client: TritonClient, gpus: List[GPUDevice], mode: str, verbose: bool
Expand Down Expand Up @@ -119,6 +120,7 @@ def profile(
self._create_metrics_manager(client, gpus)
self._create_model_manager(client, gpus)
self._populate_search_parameters()
self._populate_composing_search_parameters()

if self._config.triton_launch_mode == "remote":
self._warn_if_other_models_loaded_on_remote_server(client)
Expand Down Expand Up @@ -428,9 +430,12 @@ def _populate_search_parameters(self):
model.model_config_parameters(),
is_bls_model=bool(self._config.bls_composing_models),
)

def _populate_composing_search_parameters(self):
for model in self._config.bls_composing_models:
self._search_parameters[model.model_name()] = SearchParameters(
self._composing_search_parameters[model.model_name()] = SearchParameters(
self._config,
model.parameters(),
model.model_config_parameters(),
is_composing_model=True,
)
4 changes: 3 additions & 1 deletion model_analyzer/config/generate/search_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ def __init__(
parameters: Dict[str, Any] = {},
model_config_parameters: Dict[str, Any] = {},
is_bls_model: bool = False,
is_composing_model: bool = False,
):
self._config = config
self._parameters = parameters
self._model_config_parameters = model_config_parameters
self._search_parameters: Dict[str, SearchParameter] = {}
self._is_bls_model = is_bls_model
self._is_composing_model = is_composing_model

self._populate_search_parameters()

Expand Down Expand Up @@ -123,7 +125,7 @@ def _populate_search_parameters(self) -> None:
def _populate_parameters(self) -> None:
self._populate_batch_sizes()

if not self._is_bls_model:
if not self._is_composing_model:
self._populate_concurrency()
# TODO: Populate request rate - TMA-1903

Expand Down
29 changes: 16 additions & 13 deletions tests/test_search_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,11 @@ def test_search_parameter_creation_bls_default(self):

analyzer = Analyzer(config, MagicMock(), MagicMock(), MagicMock())
analyzer._populate_search_parameters()
analyzer._populate_composing_search_parameters()

# ADD_SUB
# =====================================================================
# The top level model of a BLS can only search instance count
# The top level model of a BLS does not search max batch size (always 1)

# max_batch_size
max_batch_size = analyzer._search_parameters["add_sub"].get_parameter(
Expand All @@ -497,7 +498,14 @@ def test_search_parameter_creation_bls_default(self):
concurrency = analyzer._search_parameters["add_sub"].get_parameter(
"concurrency"
)
self.assertIsNone(concurrency)
self.assertEqual(ParameterUsage.RUNTIME, concurrency.usage)
self.assertEqual(ParameterCategory.EXPONENTIAL, concurrency.category)
self.assertEqual(
log2(default.DEFAULT_RUN_CONFIG_MIN_CONCURRENCY), concurrency.min_range
)
self.assertEqual(
log2(default.DEFAULT_RUN_CONFIG_MAX_CONCURRENCY), concurrency.max_range
)

# instance_group
instance_group = analyzer._search_parameters["add_sub"].get_parameter(
Expand All @@ -514,10 +522,10 @@ def test_search_parameter_creation_bls_default(self):

# ADD/SUB (composing models)
# =====================================================================
# The top level model of a BLS can only search instance count
# Composing models do not search concurrency

# max_batch_size
max_batch_size = analyzer._search_parameters["add"].get_parameter(
max_batch_size = analyzer._composing_search_parameters["add"].get_parameter(
"max_batch_size"
)
self.assertEqual(ParameterUsage.MODEL, max_batch_size.usage)
Expand All @@ -532,18 +540,13 @@ def test_search_parameter_creation_bls_default(self):
)

# concurrency
concurrency = analyzer._search_parameters["sub"].get_parameter("concurrency")
self.assertEqual(ParameterUsage.RUNTIME, concurrency.usage)
self.assertEqual(ParameterCategory.EXPONENTIAL, concurrency.category)
self.assertEqual(
log2(default.DEFAULT_RUN_CONFIG_MIN_CONCURRENCY), concurrency.min_range
)
self.assertEqual(
log2(default.DEFAULT_RUN_CONFIG_MAX_CONCURRENCY), concurrency.max_range
concurrency = analyzer._composing_search_parameters["sub"].get_parameter(
"concurrency"
)
self.assertIsNone(concurrency)

# instance_group
instance_group = analyzer._search_parameters["add"].get_parameter(
instance_group = analyzer._composing_search_parameters["sub"].get_parameter(
"instance_group"
)
self.assertEqual(ParameterUsage.MODEL, instance_group.usage)
Expand Down

0 comments on commit b43804f

Please sign in to comment.