Skip to content

Commit

Permalink
Adding BLS composing models to SearchParameters
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-braf committed Jun 10, 2024
1 parent 4e96d21 commit ab7e8d9
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 4 deletions.
11 changes: 10 additions & 1 deletion model_analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,5 +423,14 @@ def _warn_if_other_models_loaded_on_remote_server(self, client):
def _populate_search_parameters(self):
for model in self._config.profile_models:
self._search_parameters[model.model_name()] = SearchParameters(
self._config, model.parameters(), model.model_config_parameters()
self._config,
model.parameters(),
model.model_config_parameters(),
is_bls_model=bool(self._config.bls_composing_models),
)
for model in self._config.bls_composing_models:
self._search_parameters[model.model_name()] = SearchParameters(
self._config,
model.parameters(),
model.model_config_parameters(),
)
7 changes: 7 additions & 0 deletions model_analyzer/config/generate/optuna_run_config_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ def __init__(
# TODO: TMA-1927: Add support for multi-model
self._search_parameters = search_parameters[models[0].model_name()]

# TODO: need to add in ensemble support
self._composing_search_parameters = {}
for composing_model in self._config.bls_composing_models:
self._composing_search_parameters[
composing_model.model_name()
] = search_parameters[composing_model.model_name()]

self._model_variant_name_manager = model_variant_name_manager

self._triton_env = BruteRunConfigGenerator.determine_triton_server_env(models)
Expand Down
12 changes: 9 additions & 3 deletions model_analyzer/config/generate/search_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ def __init__(
config: ConfigCommandProfile = ConfigCommandProfile(),
parameters: Dict[str, Any] = {},
model_config_parameters: Dict[str, Any] = {},
is_bls_model: bool = False,
):
self._config = config
self._parameters = parameters
self._model_config_parameters = model_config_parameters
self._search_parameters: Dict[str, SearchParameter] = {}
self._is_bls_model = is_bls_model

self._populate_search_parameters()

Expand Down Expand Up @@ -120,11 +122,15 @@ def _populate_search_parameters(self) -> None:

def _populate_parameters(self) -> None:
self._populate_batch_sizes()
self._populate_concurrency()
# TODO: Populate request rate - TMA-1903

if not self._is_bls_model:
self._populate_concurrency()
# TODO: Populate request rate - TMA-1903

def _populate_model_config_parameters(self) -> None:
self._populate_max_batch_size()
if not self._is_bls_model:
self._populate_max_batch_size()

self._populate_instance_group()
self._populate_max_queue_delay_microseconds()

Expand Down
6 changes: 6 additions & 0 deletions model_analyzer/config/input/config_command_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import argparse
import logging
import os
from copy import deepcopy

import numba.cuda
import psutil
Expand Down Expand Up @@ -1691,6 +1692,11 @@ def _autofill_values(self):
new_model["model_config_parameters"] = model.model_config_parameters()

new_profile_models[model.model_name()] = new_model

# deepcopy is necessary, else it gets overwritten when updating profile_models
self._fields["bls_composing_models"] = deepcopy(
self._fields["bls_composing_models"]
)
self._fields["profile_models"].set_value(new_profile_models)

def _using_request_rate(self) -> bool:
Expand Down
98 changes: 98 additions & 0 deletions tests/test_search_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,104 @@ def test_total_possible_configurations(self):
# batch_sizes (8) * instance group (8) * concurrency (11) * size (3)
self.assertEqual(8 * 8 * 11 * 3, total_num_of_possible_configurations)

def test_search_parameter_creation_bls_default(self):
"""
Test that search parameters are correctly created in default BLS optuna case
"""

args = [
"model-analyzer",
"profile",
"--model-repository",
"cli-repository",
"-f",
"path-to-config-file",
"--run-config-search-mode",
"optuna",
]

yaml_content = """
profile_models: add_sub
bls_composing_models: add,sub
"""

config = TestConfig()._evaluate_config(args=args, yaml_content=yaml_content)

analyzer = Analyzer(config, MagicMock(), MagicMock(), MagicMock())
analyzer._populate_search_parameters()

# ADD_SUB
# =====================================================================
# The top level model of a BLS can only search instance count

# max_batch_size
max_batch_size = analyzer._search_parameters["add_sub"].get_parameter(
"max_batch_size"
)
self.assertIsNone(max_batch_size)

# concurrency
concurrency = analyzer._search_parameters["add_sub"].get_parameter(
"concurrency"
)
self.assertIsNone(concurrency)

# instance_group
instance_group = analyzer._search_parameters["add_sub"].get_parameter(
"instance_group"
)
self.assertEqual(ParameterUsage.MODEL, instance_group.usage)
self.assertEqual(ParameterCategory.INTEGER, instance_group.category)
self.assertEqual(
default.DEFAULT_RUN_CONFIG_MIN_INSTANCE_COUNT, instance_group.min_range
)
self.assertEqual(
default.DEFAULT_RUN_CONFIG_MAX_INSTANCE_COUNT, instance_group.max_range
)

# ADD/SUB (composing models)
# =====================================================================
# The top level model of a BLS can only search instance count

# max_batch_size
max_batch_size = analyzer._search_parameters["add"].get_parameter(
"max_batch_size"
)
self.assertEqual(ParameterUsage.MODEL, max_batch_size.usage)
self.assertEqual(ParameterCategory.EXPONENTIAL, max_batch_size.category)
self.assertEqual(
log2(default.DEFAULT_RUN_CONFIG_MIN_MODEL_BATCH_SIZE),
max_batch_size.min_range,
)
self.assertEqual(
log2(default.DEFAULT_RUN_CONFIG_MAX_MODEL_BATCH_SIZE),
max_batch_size.max_range,
)

# concurrency
concurrency = analyzer._search_parameters["sub"].get_parameter("concurrency")
self.assertEqual(ParameterUsage.RUNTIME, concurrency.usage)
self.assertEqual(ParameterCategory.EXPONENTIAL, concurrency.category)
self.assertEqual(
log2(default.DEFAULT_RUN_CONFIG_MIN_CONCURRENCY), concurrency.min_range
)
self.assertEqual(
log2(default.DEFAULT_RUN_CONFIG_MAX_CONCURRENCY), concurrency.max_range
)

# instance_group
instance_group = analyzer._search_parameters["add"].get_parameter(
"instance_group"
)
self.assertEqual(ParameterUsage.MODEL, instance_group.usage)
self.assertEqual(ParameterCategory.INTEGER, instance_group.category)
self.assertEqual(
default.DEFAULT_RUN_CONFIG_MIN_INSTANCE_COUNT, instance_group.min_range
)
self.assertEqual(
default.DEFAULT_RUN_CONFIG_MAX_INSTANCE_COUNT, instance_group.max_range
)


if __name__ == "__main__":
unittest.main()

0 comments on commit ab7e8d9

Please sign in to comment.