diff --git a/model_analyzer/analyzer.py b/model_analyzer/analyzer.py index 0f4cbdab9..25131396e 100755 --- a/model_analyzer/analyzer.py +++ b/model_analyzer/analyzer.py @@ -17,12 +17,13 @@ import logging import sys from copy import deepcopy -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union from model_analyzer.cli.cli import CLI from model_analyzer.config.generate.base_model_config_generator import ( BaseModelConfigGenerator, ) +from model_analyzer.config.generate.search_parameters import SearchParameters from model_analyzer.constants import LOGGER_NAME, PA_ERROR_LOG_FILENAME from model_analyzer.state.analyzer_state_manager import AnalyzerStateManager from model_analyzer.triton.server.server import TritonServer @@ -82,6 +83,8 @@ def __init__( constraint_manager=self._constraint_manager, ) + self._search_parameters: Dict[str, SearchParameters] = {} + def profile( self, client: TritonClient, gpus: List[GPUDevice], mode: str, verbose: bool ) -> None: @@ -115,6 +118,7 @@ def profile( self._create_metrics_manager(client, gpus) self._create_model_manager(client, gpus) + self._populate_search_parameters() if self._config.triton_launch_mode == "remote": self._warn_if_other_models_loaded_on_remote_server(client) @@ -414,3 +418,9 @@ def _warn_if_other_models_loaded_on_remote_server(self, client): f"A model not being profiled ({model_name}) is loaded on the remote Tritonserver. " "This could impact the profile results." ) + + def _populate_search_parameters(self): + for model in self._config.profile_models: + self._search_parameters[model.model_name()] = SearchParameters( + self._config, model.parameters(), model.model_config_parameters() + ) diff --git a/model_analyzer/config/generate/config_parameters.py b/model_analyzer/config/generate/config_parameters.py deleted file mode 100755 index 8ab3b40c1..000000000 --- a/model_analyzer/config/generate/config_parameters.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Any, List, Optional, Tuple - -from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException - -from .config_parameter import ConfigParameter, ParameterCategory, ParameterType - - -class ConfigParameters: - """ - Contains information about all configuration parameters the user wants to search - """ - - def __init__(self): - self._parameters: Dict[str, ConfigParameter] = {} - - def add_parameter( - self, - name: str, - ptype: ParameterType, - category: ParameterCategory, - min_range: Optional[int] = None, - max_range: Optional[int] = None, - enumerated_list: List[Any] = [], - ) -> None: - self._check_for_illegal_input(category, min_range, max_range, enumerated_list) - - self._parameters[name] = ConfigParameter( - ptype, category, min_range, max_range, enumerated_list - ) - - def get_parameter(self, name: str) -> ConfigParameter: - return self._parameters[name] - - def get_type(self, name: str) -> ParameterType: - return self._parameters[name].ptype - - def get_category(self, name: str) -> ParameterCategory: - return self._parameters[name].category - - def get_range(self, name: str) -> Tuple[int, int]: - return (self._parameters[name].min_range, self._parameters[name].max_range) - - def get_list(self, name: str) -> List[Any]: - return self._parameters[name].enumerated_list - - def _check_for_illegal_input( - self, - category: ParameterCategory, - min_range: Optional[int], - max_range: Optional[int], - enumerated_list: List[Any], - ) -> None: - if category is ParameterCategory.LIST: - self._check_for_illegal_list_input(min_range, max_range, enumerated_list) - else: - if min_range is None or max_range is None: - raise TritonModelAnalyzerException( - f"Both min_range and max_range must be specified" - ) - - if min_range and max_range: - if min_range > max_range: - raise TritonModelAnalyzerException( - f"min_range cannot be larger than max_range" - ) - - def _check_for_illegal_list_input( - self, - min_range: Optional[int], - max_range: Optional[int], - enumerated_list: List[Any], - ) -> None: - if not enumerated_list: - raise TritonModelAnalyzerException( - f"enumerated_list must be specified for a ParameterCategory.LIST" - ) - elif min_range is not None: - raise TritonModelAnalyzerException( - f"min_range cannot be specified for a ParameterCategory.LIST" - ) - elif max_range is not None: - raise TritonModelAnalyzerException( - f"max_range cannot be specified for a ParameterCategory.LIST" - ) diff --git a/model_analyzer/config/generate/config_parameter.py b/model_analyzer/config/generate/search_parameter.py similarity index 84% rename from model_analyzer/config/generate/config_parameter.py rename to model_analyzer/config/generate/search_parameter.py index 52c7db1d2..0c342cc56 100755 --- a/model_analyzer/config/generate/config_parameter.py +++ b/model_analyzer/config/generate/search_parameter.py @@ -19,7 +19,7 @@ from typing import Any, List, Optional -class ParameterType(Enum): +class ParameterUsage(Enum): MODEL = auto() RUNTIME = auto() BUILD = auto() @@ -32,17 +32,17 @@ class ParameterCategory(Enum): @dataclass -class ConfigParameter: +class SearchParameter: """ - A dataclass that holds information about a configuration parameter + A dataclass that holds information about a configuration's search parameter """ - ptype: ParameterType + usage: ParameterUsage category: ParameterCategory + # This is only applicable to LIST category + enumerated_list: Optional[List[Any]] = None + # These are only applicable to INTEGER and EXPONENTIAL categories min_range: 
Optional[int] = None max_range: Optional[int] = None - - # This is only applicable to LIST category - enumerated_list: List[Any] = [] diff --git a/model_analyzer/config/generate/search_parameters.py b/model_analyzer/config/generate/search_parameters.py new file mode 100755 index 000000000..080df444c --- /dev/null +++ b/model_analyzer/config/generate/search_parameters.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 + +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from math import log2 +from typing import Any, Dict, List, Optional, Tuple + +from model_analyzer.config.input.config_command_profile import ConfigCommandProfile +from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException + +from .search_parameter import ParameterCategory, ParameterUsage, SearchParameter + + +class SearchParameters: + """ + Contains information about all configuration parameters the user wants to search + """ + + # These map to the run-config-search fields + # See github.com/triton-inference-server/model_analyzer/blob/main/docs/config.md + exponential_rcs_parameters = ["batch_sizes", "concurrency"] + linear_rcs_parameters = ["instance_group"] + + model_parameters = ["batch_sizes", "instance_group", "max_queue_delay_microseconds"] + runtime_parameters = ["concurrency"] + + def __init__( + self, + config: ConfigCommandProfile = ConfigCommandProfile(), + parameters: Dict[str, Any] = {}, + model_config_parameters: Dict[str, Any] = {}, + ): + self._config = config + self._parameters = parameters + self._model_config_parameters = model_config_parameters + self._search_parameters: Dict[str, SearchParameter] = {} + + self._populate_search_parameters() + + def get_parameter(self, name: str) -> SearchParameter: + return self._search_parameters[name] + + def get_type(self, name: str) -> ParameterUsage: + return self._search_parameters[name].usage + + def get_category(self, name: str) -> ParameterCategory: + return self._search_parameters[name].category + + def get_range(self, name: str) -> Tuple[Optional[int], Optional[int]]: + return ( + self._search_parameters[name].min_range, + self._search_parameters[name].max_range, + ) + + def get_list(self, name: str) -> Optional[List[Any]]: + return self._search_parameters[name].enumerated_list + + def _populate_search_parameters(self) -> None: + if self._parameters: + self._populate_parameters() + + self._populate_model_config_parameters() + + def _populate_parameters(self) -> None: + self._populate_batch_sizes() + self._populate_concurrency() + # TODO: Populate request rate - TMA-1903 + + def _populate_model_config_parameters(self) -> None: + self._populate_instance_group() + self._populate_max_queue_delay_microseconds() + + def _populate_batch_sizes(self) -> None: + if self._parameters["batch_sizes"]: + self._populate_list_parameter( + parameter_name="batch_sizes", + parameter_list=self._parameters["batch_sizes"], + ) + else: + self._populate_rcs_parameter( + parameter_name="batch_sizes", + 
rcs_parameter_min_value=self._config.run_config_search_min_model_batch_size, + rcs_parameter_max_value=self._config.run_config_search_max_model_batch_size, + ) + + def _populate_concurrency(self) -> None: + if self._parameters["concurrency"]: + self._populate_list_parameter( + parameter_name="concurrency", + parameter_list=self._parameters["concurrency"], + ) + else: + self._populate_rcs_parameter( + parameter_name="concurrency", + rcs_parameter_min_value=self._config.run_config_search_min_concurrency, + rcs_parameter_max_value=self._config.run_config_search_max_concurrency, + ) + + def _populate_instance_group(self) -> None: + # Example config format: + # + # model_config_parameters: + # instance_group: + # - kind: KIND_GPU + # count: [1, 2, 3, 4] + + # Need to populate instance_group based on RCS min/max values + # even if no model config parameters are present + if not self._model_config_parameters: + self._populate_rcs_parameter( + parameter_name="instance_group", + rcs_parameter_min_value=self._config.run_config_search_min_instance_count, + rcs_parameter_max_value=self._config.run_config_search_max_instance_count, + ) + elif "instance_group" in self._model_config_parameters.keys(): + parameter_list = self._model_config_parameters["instance_group"][0][0][ + "count" + ] + + self._populate_list_parameter( + parameter_name="instance_group", + parameter_list=parameter_list, + ) + else: + self._populate_rcs_parameter( + parameter_name="instance_group", + rcs_parameter_min_value=self._config.run_config_search_min_instance_count, + rcs_parameter_max_value=self._config.run_config_search_max_instance_count, + ) + + def _populate_max_queue_delay_microseconds(self) -> None: + # Example format + # + # model_config_parameters: + # dynamic_batching: + # max_queue_delay_microseconds: [100, 200, 300] + + # There is no RCS field for max_queue_delay_microseconds + if self._is_max_queue_delay_in_model_config_parameters(): + self._populate_list_parameter( + parameter_name="max_queue_delay_microseconds", + parameter_list=self._model_config_parameters["dynamic_batching"][0][ + "max_queue_delay_microseconds" + ], + ) + + def _is_max_queue_delay_in_model_config_parameters(self) -> bool: + if self._model_config_parameters: + max_queue_delay_present = ( + "dynamic_batching" in self._model_config_parameters.keys() + and ( + "max_queue_delay_microseconds" + in self._model_config_parameters["dynamic_batching"][0] + ) + ) + else: + max_queue_delay_present = False + + return max_queue_delay_present + + def _populate_list_parameter( + self, + parameter_name: str, + parameter_list: List[int], + ) -> None: + usage = self._determine_parameter_usage(parameter_name) + + self._add_search_parameter( + name=parameter_name, + usage=usage, + category=ParameterCategory.LIST, + enumerated_list=parameter_list, + ) + + def _populate_rcs_parameter( + self, + parameter_name: str, + rcs_parameter_min_value: int, + rcs_parameter_max_value: int, + ) -> None: + usage = self._determine_parameter_usage(parameter_name) + category = self._determine_parameter_category(parameter_name) + + if category == ParameterCategory.EXPONENTIAL: + min_range = int(log2(rcs_parameter_min_value)) # type: ignore + max_range = int(log2(rcs_parameter_max_value)) # type: ignore + else: + min_range = rcs_parameter_min_value # type: ignore + max_range = rcs_parameter_max_value # type: ignore + + self._add_search_parameter( + name=parameter_name, + usage=usage, + category=category, + min_range=min_range, + max_range=max_range, + ) + + def 
_determine_parameter_category(self, name: str) -> ParameterCategory:
+        if name in SearchParameters.exponential_rcs_parameters:
+            category = ParameterCategory.EXPONENTIAL
+        elif name in SearchParameters.linear_rcs_parameters:
+            category = ParameterCategory.INTEGER
+        else:
+            raise TritonModelAnalyzerException(
+                f"ParameterCategory not found for {name}"
+            )
+
+        return category
+
+    def _determine_parameter_usage(self, name: str) -> ParameterUsage:
+        if name in SearchParameters.model_parameters:
+            usage = ParameterUsage.MODEL
+        elif name in SearchParameters.runtime_parameters:
+            usage = ParameterUsage.RUNTIME
+        else:
+            raise TritonModelAnalyzerException(
+                f"ParameterUsage not found for {name}"
+            )
+
+        return usage
+
+    def _add_search_parameter(
+        self,
+        name: str,
+        usage: ParameterUsage,
+        category: ParameterCategory,
+        min_range: Optional[int] = None,
+        max_range: Optional[int] = None,
+        enumerated_list: Optional[List[Any]] = None,
+    ) -> None:
+        self._check_for_illegal_input(category, min_range, max_range, enumerated_list)
+
+        self._search_parameters[name] = SearchParameter(
+            usage=usage,
+            category=category,
+            enumerated_list=enumerated_list,
+            min_range=min_range,
+            max_range=max_range,
+        )
+
+    def _check_for_illegal_input(
+        self,
+        category: ParameterCategory,
+        min_range: Optional[int],
+        max_range: Optional[int],
+        enumerated_list: Optional[List[Any]],
+    ) -> None:
+        if category is ParameterCategory.LIST:
+            self._check_for_illegal_list_input(min_range, max_range, enumerated_list)
+        else:
+            if min_range is None or max_range is None:
+                raise TritonModelAnalyzerException(
+                    "Both min_range and max_range must be specified"
+                )
+
+            if min_range > max_range:
+                raise TritonModelAnalyzerException(
+                    "min_range cannot be larger than max_range"
+                )
+
+    def _check_for_illegal_list_input(
+        self,
+        min_range: Optional[int],
+        max_range: Optional[int],
+        enumerated_list: Optional[List[Any]],
+    ) -> None:
+        if not enumerated_list:
+            raise TritonModelAnalyzerException(
+                "enumerated_list must be specified for a ParameterCategory.LIST"
+            )
+        elif min_range is not None:
+            raise TritonModelAnalyzerException(
+                "min_range cannot be specified for a ParameterCategory.LIST"
+            )
+        elif max_range is not None:
+            raise TritonModelAnalyzerException(
+                "max_range cannot be specified for a ParameterCategory.LIST"
+            )
diff --git a/model_analyzer/config/input/config_command.py b/model_analyzer/config/input/config_command.py
index 23e4fc484..97990901c 100755
--- a/model_analyzer/config/input/config_command.py
+++ b/model_analyzer/config/input/config_command.py
@@ -289,10 +289,10 @@ def _check_quick_search_model_config_parameters_combinations(self) -> None:
         if not "profile_models" in config:
             return
 
-        if config["run_config_search_mode"] != "quick":
+        if config["run_config_search_mode"].value() != "quick":
             return
 
-        profile_models = config()["profile_models"].value()
+        profile_models = config["profile_models"].value()
         for model in profile_models:
             model_config_params = deepcopy(model.model_config_parameters())
             if model_config_params:
diff --git a/model_analyzer/config/input/config_command_profile.py b/model_analyzer/config/input/config_command_profile.py
index f639d1c72..f15d26004 100755
--- a/model_analyzer/config/input/config_command_profile.py
+++ b/model_analyzer/config/input/config_command_profile.py
@@ -1529,11 +1529,21 @@ def _autofill_values(self):
 
             # Run parameters
             if not model.parameters():
-                new_model["parameters"] = {
-                    "batch_sizes": self.batch_sizes,
-                    "concurrency": self.concurrency,
-                    "request_rate": self.request_rate,
-                }
+ if self.run_config_search_mode != "optuna": + new_model["parameters"] = { + "batch_sizes": self.batch_sizes, + "concurrency": self.concurrency, + "request_rate": self.request_rate, + } + else: + if self._fields["batch_sizes"].is_set_by_user(): + new_model["parameters"] = {"batch_sizes": self.batch_sizes} + else: + new_model["parameters"] = {"batch_sizes": []} + + new_model["parameters"]["concurrency"] = self.concurrency + new_model["parameters"]["request_rate"] = self.request_rate + else: new_model["parameters"] = {} if "batch_sizes" in model.parameters(): @@ -1541,7 +1551,12 @@ def _autofill_values(self): {"batch_sizes": model.parameters()["batch_sizes"]} ) else: - new_model["parameters"].update({"batch_sizes": self.batch_sizes}) + if self.run_config_search_mode != "optuna": + new_model["parameters"].update( + {"batch_sizes": self.batch_sizes} + ) + else: + new_model["parameters"].update({"batch_sizes": []}) if "concurrency" in model.parameters(): new_model["parameters"].update( diff --git a/tests/test_config_parameters.py b/tests/test_config_parameters.py deleted file mode 100755 index a87c2e0f3..000000000 --- a/tests/test_config_parameters.py +++ /dev/null @@ -1,156 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -from model_analyzer.config.generate.config_parameters import ( - ConfigParameters, - ParameterCategory, - ParameterType, -) -from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException - -from .common import test_result_collector as trc - - -class TestConfigParameters(trc.TestResultCollector): - def setUp(self): - self.config_parameters = ConfigParameters() - - self.config_parameters.add_parameter( - name="concurrency", - ptype=ParameterType.RUNTIME, - category=ParameterCategory.EXPONENTIAL, - min_range=0, - max_range=10, - ) - - self.config_parameters.add_parameter( - name="instance_count", - ptype=ParameterType.MODEL, - category=ParameterCategory.INTEGER, - min_range=1, - max_range=8, - ) - - self.config_parameters.add_parameter( - name="size", - ptype=ParameterType.BUILD, - category=ParameterCategory.LIST, - enumerated_list=["FP8", "FP16", "FP32"], - ) - - def test_exponential_parameter(self): - """ - Test exponential parameter, accessing dataclass directly - """ - - parameter = self.config_parameters.get_parameter("concurrency") - - self.assertEqual(ParameterType.RUNTIME, parameter.ptype) - self.assertEqual(ParameterCategory.EXPONENTIAL, parameter.category) - self.assertEqual(0, parameter.min_range) - self.assertEqual(10, parameter.max_range) - - def test_integer_parameter(self): - """ - Test integer parameter, using accessor methods - """ - - self.assertEqual( - ParameterType.MODEL, - self.config_parameters.get_type("instance_count"), - ) - self.assertEqual( - ParameterCategory.INTEGER, - self.config_parameters.get_category("instance_count"), - ) - self.assertEqual((1, 8), self.config_parameters.get_range("instance_count")) - - def test_list_parameter(self): - """ - Test list parameter, using accessor methods - """ - - self.assertEqual( - ParameterType.BUILD, - self.config_parameters.get_type("size"), - ) - self.assertEqual( - ParameterCategory.LIST, - self.config_parameters.get_category("size"), - ) - self.assertEqual( - ["FP8", "FP16", "FP32"], self.config_parameters.get_list("size") - ) - - def test_illegal_inputs(self): - """ - Check that an exception is raised for illegal input combos - """ - with self.assertRaises(TritonModelAnalyzerException): - self.config_parameters.add_parameter( - name="concurrency", - ptype=ParameterType.RUNTIME, - category=ParameterCategory.EXPONENTIAL, - max_range=10, - ) - - with self.assertRaises(TritonModelAnalyzerException): - self.config_parameters.add_parameter( - name="concurrency", - ptype=ParameterType.RUNTIME, - category=ParameterCategory.EXPONENTIAL, - min_range=0, - ) - - with self.assertRaises(TritonModelAnalyzerException): - self.config_parameters.add_parameter( - name="concurrency", - ptype=ParameterType.RUNTIME, - category=ParameterCategory.EXPONENTIAL, - min_range=10, - max_range=9, - ) - - with self.assertRaises(TritonModelAnalyzerException): - self.config_parameters.add_parameter( - name="size", - ptype=ParameterType.BUILD, - category=ParameterCategory.LIST, - ) - - with self.assertRaises(TritonModelAnalyzerException): - self.config_parameters.add_parameter( - name="size", - ptype=ParameterType.BUILD, - category=ParameterCategory.LIST, - enumerated_list=["FP8", "FP16", "FP32"], - min_range=0, - ) - - with self.assertRaises(TritonModelAnalyzerException): - self.config_parameters.add_parameter( - name="size", - ptype=ParameterType.BUILD, - category=ParameterCategory.LIST, - enumerated_list=["FP8", "FP16", "FP32"], - max_range=10, - ) - - -if __name__ == "__main__": - unittest.main() diff --git 
a/tests/test_search_parameters.py b/tests/test_search_parameters.py new file mode 100755 index 000000000..64c154912 --- /dev/null +++ b/tests/test_search_parameters.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python3 + +# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from math import log2 +from unittest.mock import MagicMock, patch + +import model_analyzer.config.input.config_defaults as default +from model_analyzer.analyzer import Analyzer +from model_analyzer.config.generate.search_parameters import ( + ParameterCategory, + ParameterUsage, + SearchParameters, +) +from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException +from tests.test_config import TestConfig + +from .common import test_result_collector as trc +from .mocks.mock_os import MockOSMethods + + +class TestSearchParameters(trc.TestResultCollector): + def setUp(self): + # Mock path validation + self.mock_os = MockOSMethods( + mock_paths=["model_analyzer.config.input.config_utils"] + ) + self.mock_os.start() + + args = [ + "model-analyzer", + "profile", + "--model-repository", + "cli-repository", + "-f", + "path-to-config-file", + "--run-config-search-mode", + "optuna", + ] + + yaml_content = """ + profile_models: add_sub + """ + + config = TestConfig()._evaluate_config(args=args, yaml_content=yaml_content) + + self.search_parameters = SearchParameters(config) + + self.search_parameters._add_search_parameter( + name="concurrency", + usage=ParameterUsage.RUNTIME, + category=ParameterCategory.EXPONENTIAL, + min_range=0, + max_range=10, + ) + + self.search_parameters._add_search_parameter( + name="instance_count", + usage=ParameterUsage.MODEL, + category=ParameterCategory.INTEGER, + min_range=1, + max_range=8, + ) + + self.search_parameters._add_search_parameter( + name="size", + usage=ParameterUsage.BUILD, + category=ParameterCategory.LIST, + enumerated_list=["FP8", "FP16", "FP32"], + ) + + def tearDown(self): + self.mock_os.stop() + patch.stopall() + + def test_exponential_parameter(self): + """ + Test exponential parameter, accessing dataclass directly + """ + + parameter = self.search_parameters.get_parameter("concurrency") + + self.assertEqual(ParameterUsage.RUNTIME, parameter.usage) + self.assertEqual(ParameterCategory.EXPONENTIAL, parameter.category) + self.assertEqual(0, parameter.min_range) + self.assertEqual(10, parameter.max_range) + + def test_integer_parameter(self): + """ + Test integer parameter, using accessor methods + """ + + self.assertEqual( + ParameterUsage.MODEL, + self.search_parameters.get_type("instance_count"), + ) + self.assertEqual( + ParameterCategory.INTEGER, + self.search_parameters.get_category("instance_count"), + ) + self.assertEqual((1, 8), self.search_parameters.get_range("instance_count")) + + def test_list_parameter(self): + """ + Test list parameter, using accessor methods + """ + + self.assertEqual( + ParameterUsage.BUILD, + self.search_parameters.get_type("size"), + ) + 
self.assertEqual( + ParameterCategory.LIST, + self.search_parameters.get_category("size"), + ) + self.assertEqual( + ["FP8", "FP16", "FP32"], self.search_parameters.get_list("size") + ) + + def test_illegal_inputs(self): + """ + Check that an exception is raised for illegal input combos + """ + with self.assertRaises(TritonModelAnalyzerException): + self.search_parameters._add_search_parameter( + name="concurrency", + usage=ParameterUsage.RUNTIME, + category=ParameterCategory.EXPONENTIAL, + max_range=10, + ) + + with self.assertRaises(TritonModelAnalyzerException): + self.search_parameters._add_search_parameter( + name="concurrency", + usage=ParameterUsage.RUNTIME, + category=ParameterCategory.EXPONENTIAL, + min_range=0, + ) + + with self.assertRaises(TritonModelAnalyzerException): + self.search_parameters._add_search_parameter( + name="concurrency", + usage=ParameterUsage.RUNTIME, + category=ParameterCategory.EXPONENTIAL, + min_range=10, + max_range=9, + ) + + with self.assertRaises(TritonModelAnalyzerException): + self.search_parameters._add_search_parameter( + name="size", + usage=ParameterUsage.BUILD, + category=ParameterCategory.LIST, + ) + + with self.assertRaises(TritonModelAnalyzerException): + self.search_parameters._add_search_parameter( + name="size", + usage=ParameterUsage.BUILD, + category=ParameterCategory.LIST, + enumerated_list=["FP8", "FP16", "FP32"], + min_range=0, + ) + + with self.assertRaises(TritonModelAnalyzerException): + self.search_parameters._add_search_parameter( + name="size", + usage=ParameterUsage.BUILD, + category=ParameterCategory.LIST, + enumerated_list=["FP8", "FP16", "FP32"], + max_range=10, + ) + + def test_search_parameter_creation_default(self): + """ + Test that search parameters are correctly created in default optuna case + """ + + args = [ + "model-analyzer", + "profile", + "--model-repository", + "cli-repository", + "-f", + "path-to-config-file", + "--run-config-search-mode", + "optuna", + ] + + yaml_content = """ + profile_models: add_sub + """ + + config = TestConfig()._evaluate_config(args=args, yaml_content=yaml_content) + + analyzer = Analyzer(config, MagicMock(), MagicMock(), MagicMock()) + analyzer._populate_search_parameters() + + # batch_sizes + batch_sizes = analyzer._search_parameters["add_sub"].get_parameter( + "batch_sizes" + ) + self.assertEqual(ParameterUsage.MODEL, batch_sizes.usage) + self.assertEqual(ParameterCategory.EXPONENTIAL, batch_sizes.category) + self.assertEqual( + log2(default.DEFAULT_RUN_CONFIG_MIN_MODEL_BATCH_SIZE), batch_sizes.min_range + ) + self.assertEqual( + log2(default.DEFAULT_RUN_CONFIG_MAX_MODEL_BATCH_SIZE), batch_sizes.max_range + ) + + # concurrency + concurrency = analyzer._search_parameters["add_sub"].get_parameter( + "concurrency" + ) + self.assertEqual(ParameterUsage.RUNTIME, concurrency.usage) + self.assertEqual(ParameterCategory.EXPONENTIAL, concurrency.category) + self.assertEqual( + log2(default.DEFAULT_RUN_CONFIG_MIN_CONCURRENCY), concurrency.min_range + ) + self.assertEqual( + log2(default.DEFAULT_RUN_CONFIG_MAX_CONCURRENCY), concurrency.max_range + ) + + # instance_group + instance_group = analyzer._search_parameters["add_sub"].get_parameter( + "instance_group" + ) + self.assertEqual(ParameterUsage.MODEL, instance_group.usage) + self.assertEqual(ParameterCategory.INTEGER, instance_group.category) + self.assertEqual( + default.DEFAULT_RUN_CONFIG_MIN_INSTANCE_COUNT, instance_group.min_range + ) + self.assertEqual( + default.DEFAULT_RUN_CONFIG_MAX_INSTANCE_COUNT, instance_group.max_range + ) + + 
def test_search_parameter_creation_multi_model_non_default(self):
+        """
+        Test that search parameters are correctly created in
+        a multi-model non-default optuna case
+        """
+
+        args = [
+            "model-analyzer",
+            "profile",
+            "--model-repository",
+            "cli-repository",
+            "-f",
+            "path-to-config-file",
+            "--run-config-search-mode",
+            "optuna",
+        ]
+
+        yaml_content = """
+        run_config_search_mode: optuna
+        profile_models:
+            add_sub:
+                parameters:
+                    batch_sizes: [16, 32, 64]
+                model_config_parameters:
+                    dynamic_batching:
+                        max_queue_delay_microseconds: [100, 200, 300]
+                    instance_group:
+                        - kind: KIND_GPU
+                          count: [1, 2, 3, 4]
+            mult_div:
+                parameters:
+                    concurrency: [1, 8, 64, 256]
+        """
+
+        config = TestConfig()._evaluate_config(args, yaml_content)
+
+        analyzer = Analyzer(config, MagicMock(), MagicMock(), MagicMock())
+        analyzer._populate_search_parameters()
+
+        # ===================================================================
+        # ADD_SUB
+        # ===================================================================
+
+        # batch_sizes
+        # ===================================================================
+        batch_sizes = analyzer._search_parameters["add_sub"].get_parameter(
+            "batch_sizes"
+        )
+        self.assertEqual(ParameterUsage.MODEL, batch_sizes.usage)
+        self.assertEqual(ParameterCategory.LIST, batch_sizes.category)
+        self.assertEqual([16, 32, 64], batch_sizes.enumerated_list)
+
+        # concurrency
+        # ===================================================================
+        concurrency = analyzer._search_parameters["add_sub"].get_parameter(
+            "concurrency"
+        )
+        self.assertEqual(ParameterUsage.RUNTIME, concurrency.usage)
+        self.assertEqual(ParameterCategory.EXPONENTIAL, concurrency.category)
+        self.assertEqual(
+            log2(default.DEFAULT_RUN_CONFIG_MIN_CONCURRENCY), concurrency.min_range
+        )
+        self.assertEqual(
+            log2(default.DEFAULT_RUN_CONFIG_MAX_CONCURRENCY), concurrency.max_range
+        )
+
+        # instance_group
+        # ===================================================================
+        instance_group = analyzer._search_parameters["add_sub"].get_parameter(
+            "instance_group"
+        )
+        self.assertEqual(ParameterUsage.MODEL, instance_group.usage)
+        self.assertEqual(ParameterCategory.LIST, instance_group.category)
+        self.assertEqual([1, 2, 3, 4], instance_group.enumerated_list)
+
+        # max_queue_delay_microseconds
+        # ===================================================================
+        max_queue_delay = analyzer._search_parameters["add_sub"].get_parameter(
+            "max_queue_delay_microseconds"
+        )
+        self.assertEqual(ParameterUsage.MODEL, max_queue_delay.usage)
+        self.assertEqual(ParameterCategory.LIST, max_queue_delay.category)
+        self.assertEqual([100, 200, 300], max_queue_delay.enumerated_list)
+
+        # ===================================================================
+        # MULT_DIV
+        # ===================================================================
+
+        # batch_sizes
+        # ===================================================================
+        batch_sizes = analyzer._search_parameters["mult_div"].get_parameter(
+            "batch_sizes"
+        )
+        self.assertEqual(ParameterUsage.MODEL, batch_sizes.usage)
+        self.assertEqual(ParameterCategory.EXPONENTIAL, batch_sizes.category)
+        self.assertEqual(
+            log2(default.DEFAULT_RUN_CONFIG_MIN_MODEL_BATCH_SIZE), batch_sizes.min_range
+        )
+        self.assertEqual(
+            log2(default.DEFAULT_RUN_CONFIG_MAX_MODEL_BATCH_SIZE), batch_sizes.max_range
+        )
+
+        # concurrency
+        # ===================================================================
+        concurrency = analyzer._search_parameters["mult_div"].get_parameter(
+            "concurrency"
+        )
+        self.assertEqual(ParameterUsage.RUNTIME, concurrency.usage)
+        
self.assertEqual(ParameterCategory.LIST, concurrency.category) + self.assertEqual([1, 8, 64, 256], concurrency.enumerated_list) + + # instance_group + # =================================================================== + instance_group = analyzer._search_parameters["mult_div"].get_parameter( + "instance_group" + ) + self.assertEqual(ParameterUsage.MODEL, instance_group.usage) + self.assertEqual(ParameterCategory.INTEGER, instance_group.category) + self.assertEqual( + default.DEFAULT_RUN_CONFIG_MIN_INSTANCE_COUNT, instance_group.min_range + ) + self.assertEqual( + default.DEFAULT_RUN_CONFIG_MAX_INSTANCE_COUNT, instance_group.max_range + ) + + +if __name__ == "__main__": + unittest.main()
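
Usage sketch: the snippet below shows how a search mode could consume the new SearchParameters accessors, including the log2 encoding that _populate_rcs_parameter applies to EXPONENTIAL parameters. This is a minimal example under assumptions, not part of the diff: the bare ConfigCommandProfile() and the hand-built parameters dict stand in for the values that _autofill_values() and Analyzer._populate_search_parameters() would normally supply, and the swept value lists are illustrative only.

    from model_analyzer.config.generate.search_parameters import (
        ParameterCategory,
        SearchParameters,
    )
    from model_analyzer.config.input.config_command_profile import ConfigCommandProfile

    # Empty per-parameter lists mimic what _autofill_values() produces in optuna
    # mode when the user sets nothing, which routes batch_sizes and concurrency
    # through the run-config-search (RCS) min/max fallbacks.
    search_parameters = SearchParameters(
        config=ConfigCommandProfile(),
        parameters={"batch_sizes": [], "concurrency": []},
        model_config_parameters={},
    )

    # EXPONENTIAL parameters store log2-encoded bounds: a searcher samples an
    # integer exponent n in [min_range, max_range] and profiles with 2**n.
    category = search_parameters.get_category("concurrency")
    assert category is ParameterCategory.EXPONENTIAL
    min_exp, max_exp = search_parameters.get_range("concurrency")
    concurrencies = [2**n for n in range(min_exp, max_exp + 1)]

    # INTEGER parameters such as instance_group keep their raw bounds and are
    # swept linearly instead.
    min_count, max_count = search_parameters.get_range("instance_group")
    instance_counts = list(range(min_count, max_count + 1))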