From 8cc8a2cc1a4d38150e166d3a75f1c50d1356f7d2 Mon Sep 17 00:00:00 2001 From: Brian Raf Date: Mon, 29 Apr 2024 17:30:36 +0000 Subject: [PATCH] Initial code for ConfigParameters class --- .../config/generate/config_parameter.py | 48 ++++++ .../config/generate/config_parameters.py | 100 +++++++++++ tests/test_config_parameters.py | 157 ++++++++++++++++++ 3 files changed, 305 insertions(+) create mode 100755 model_analyzer/config/generate/config_parameter.py create mode 100755 model_analyzer/config/generate/config_parameters.py create mode 100755 tests/test_config_parameters.py diff --git a/model_analyzer/config/generate/config_parameter.py b/model_analyzer/config/generate/config_parameter.py new file mode 100755 index 000000000..52c7db1d2 --- /dev/null +++ b/model_analyzer/config/generate/config_parameter.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, List, Optional


class ParameterType(Enum):
    """
    The kind of a configuration parameter (stored as ConfigParameter.ptype)
    """

    MODEL = auto()
    RUNTIME = auto()
    BUILD = auto()


class ParameterCategory(Enum):
    """
    How the values of a configuration parameter are described:
    INTEGER and EXPONENTIAL use a min/max range, LIST uses an enumerated list
    """

    INTEGER = auto()
    EXPONENTIAL = auto()
    LIST = auto()


@dataclass
class ConfigParameter:
    """
    A dataclass that holds information about a configuration parameter
    """

    ptype: ParameterType
    category: ParameterCategory

    # These are only applicable to INTEGER and EXPONENTIAL categories
    min_range: Optional[int] = None
    max_range: Optional[int] = None

    # This is only applicable to LIST category.
    # A literal `[]` default is rejected by @dataclass (ValueError when the
    # class is created), so every instance gets its own fresh list instead.
    enumerated_list: List[Any] = field(default_factory=list)
from typing import Any, Dict, List, Optional, Tuple

from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException

from .config_parameter import ConfigParameter, ParameterCategory, ParameterType


class ConfigParameters:
    """
    Contains information about all configuration parameters the user wants to search
    """

    def __init__(self):
        # Maps a parameter name to its description
        self._parameters: Dict[str, ConfigParameter] = {}

    def add_parameter(
        self,
        name: str,
        ptype: ParameterType,
        category: ParameterCategory,
        min_range: Optional[int] = None,
        max_range: Optional[int] = None,
        enumerated_list: Optional[List[Any]] = None,
    ):
        """
        Validate and store a parameter under `name`, replacing any
        parameter previously stored under the same name

        Raises
        ------
        TritonModelAnalyzerException
            If the range/list arguments are not legal for `category`
        """
        # A `[]` default would be one list object shared by every call;
        # create a fresh list per call when the caller omits it.
        if enumerated_list is None:
            enumerated_list = []

        self._check_for_illegal_input(category, min_range, max_range, enumerated_list)

        self._parameters[name] = ConfigParameter(
            ptype=ptype,
            category=category,
            min_range=min_range,
            max_range=max_range,
            enumerated_list=enumerated_list,
        )

    def get_parameter(self, name: str) -> ConfigParameter:
        """
        Return the stored parameter; raises KeyError if `name` is unknown
        """
        return self._parameters[name]

    def get_type(self, name: str) -> ParameterType:
        """
        Return the parameter's type; raises KeyError if `name` is unknown
        """
        return self._parameters[name].ptype

    def get_category(self, name: str) -> ParameterCategory:
        """
        Return the parameter's category; raises KeyError if `name` is unknown
        """
        return self._parameters[name].category

    def get_range(self, name: str) -> Tuple[Optional[int], Optional[int]]:
        """
        Return (min_range, max_range); both are None for a LIST parameter.
        Raises KeyError if `name` is unknown
        """
        parameter = self._parameters[name]
        return (parameter.min_range, parameter.max_range)

    def get_list(self, name: str) -> List[Any]:
        """
        Return the parameter's enumerated list (empty unless one was given).
        Raises KeyError if `name` is unknown
        """
        return self._parameters[name].enumerated_list

    def _check_for_illegal_input(
        self,
        category: ParameterCategory,
        min_range: Optional[int],
        max_range: Optional[int],
        enumerated_list: Optional[List[Any]],
    ) -> None:
        """
        Raise TritonModelAnalyzerException if the arguments are not a legal
        combination for the given category
        """
        if category is ParameterCategory.LIST:
            self._check_for_illegal_list_input(min_range, max_range, enumerated_list)
            return

        if min_range is None or max_range is None:
            raise TritonModelAnalyzerException(
                "Both min_range and max_range must be specified"
            )

        # Both bounds are known to be set here, so compare them directly.
        # Guarding this with a truthiness test would skip the check whenever
        # either bound is 0 (e.g. min_range=5, max_range=0 would be accepted).
        if min_range > max_range:
            raise TritonModelAnalyzerException(
                "min_range cannot be larger than max_range"
            )

    def _check_for_illegal_list_input(
        self,
        min_range: Optional[int],
        max_range: Optional[int],
        enumerated_list: Optional[List[Any]],
    ) -> None:
        """
        Raise TritonModelAnalyzerException unless a LIST parameter has a
        non-empty enumerated_list and no range bounds
        """
        if not enumerated_list:
            raise TritonModelAnalyzerException(
                "enumerated_list must be specified for a LIST"
            )
        elif min_range is not None:
            raise TritonModelAnalyzerException(
                "min_range cannot be specified for a list"
            )
        elif max_range is not None:
            raise TritonModelAnalyzerException(
                "max_range cannot be specified for a list"
            )
import unittest
from unittest.mock import MagicMock, patch

from model_analyzer.config.generate.config_parameters import (
    ConfigParameters,
    ParameterCategory,
    ParameterType,
)
from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException

from .common import test_result_collector as trc


class TestConfigParameters(trc.TestResultCollector):
    def setUp(self):
        # One parameter of each category, registered in a fixed order
        fixtures = [
            dict(
                name="concurrency",
                ptype=ParameterType.RUNTIME,
                category=ParameterCategory.EXPONENTIAL,
                min_range=0,
                max_range=10,
            ),
            dict(
                name="instance_count",
                ptype=ParameterType.MODEL,
                category=ParameterCategory.INTEGER,
                min_range=1,
                max_range=8,
            ),
            dict(
                name="size",
                ptype=ParameterType.BUILD,
                category=ParameterCategory.LIST,
                enumerated_list=["FP8", "FP16", "FP32"],
            ),
        ]

        self.config_parameters = ConfigParameters()
        for fixture in fixtures:
            self.config_parameters.add_parameter(**fixture)

    def test_exponential_parameter(self):
        """
        Exponential parameter: read the fields straight off the dataclass
        """
        concurrency = self.config_parameters.get_parameter("concurrency")

        self.assertEqual(ParameterType.RUNTIME, concurrency.ptype)
        self.assertEqual(ParameterCategory.EXPONENTIAL, concurrency.category)
        self.assertEqual(0, concurrency.min_range)
        self.assertEqual(10, concurrency.max_range)

    def test_integer_parameter(self):
        """
        Integer parameter: read everything through the accessor methods
        """
        params = self.config_parameters
        key = "instance_count"

        self.assertEqual(ParameterType.MODEL, params.get_type(key))
        self.assertEqual(ParameterCategory.INTEGER, params.get_category(key))
        self.assertEqual((1, 8), params.get_range(key))

    def test_list_parameter(self):
        """
        List parameter: read everything through the accessor methods
        """
        params = self.config_parameters
        key = "size"

        self.assertEqual(ParameterType.BUILD, params.get_type(key))
        self.assertEqual(ParameterCategory.LIST, params.get_category(key))
        self.assertEqual(["FP8", "FP16", "FP32"], params.get_list(key))

    def test_illegal_inputs(self):
        """
        Every illegal argument combination must raise an exception
        """
        exponential = dict(
            name="concurrency",
            ptype=ParameterType.RUNTIME,
            category=ParameterCategory.EXPONENTIAL,
        )
        enumerated = dict(
            name="size",
            ptype=ParameterType.BUILD,
            category=ParameterCategory.LIST,
        )
        values = ["FP8", "FP16", "FP32"]

        illegal_combos = [
            # Range categories need both bounds, in the right order
            dict(exponential, max_range=10),
            dict(exponential, min_range=0),
            dict(exponential, min_range=10, max_range=9),
            # LIST needs a list and must not carry range bounds
            dict(enumerated),
            dict(enumerated, enumerated_list=values, min_range=0),
            dict(enumerated, enumerated_list=values, max_range=10),
        ]

        for combo in illegal_combos:
            with self.assertRaises(TritonModelAnalyzerException):
                self.config_parameters.add_parameter(**combo)


if __name__ == "__main__":
    unittest.main()