-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Class to hold info about parameters (#868)
* Initial code for ConfigParameters class
* Fixing codeql issue
* Fixes based on review
- Loading branch information
Showing 3 changed files with 304 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, List, Optional
|
||
|
||
class ParameterType(Enum):
    """
    Identifies which kind of parameter a ConfigParameter describes
    """

    # NOTE(review): the exact meaning of each member is not defined here;
    # presumably MODEL = model config, RUNTIME = load/runtime, BUILD = build-time
    MODEL = auto()
    RUNTIME = auto()
    BUILD = auto()
|
||
|
||
class ParameterCategory(Enum):
    """
    Describes how the legal values of a parameter are specified
    """

    # INTEGER and EXPONENTIAL are described by a min/max range
    # NOTE(review): how values are stepped within the range is not defined here
    INTEGER = auto()
    EXPONENTIAL = auto()

    # LIST is described by an explicit list of values
    LIST = auto()
|
||
|
||
@dataclass
class ConfigParameter:
    """
    A dataclass that holds information about a configuration parameter
    """

    ptype: ParameterType
    category: ParameterCategory

    # These are only applicable to INTEGER and EXPONENTIAL categories
    min_range: Optional[int] = None
    max_range: Optional[int] = None

    # This is only applicable to LIST category
    # A default_factory is required here: dataclasses reject a bare mutable
    # default ([]) with a ValueError, and each instance needs its own list
    enumerated_list: List[Any] = field(default_factory=list)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any, List, Optional, Tuple | ||
|
||
from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException | ||
|
||
from .config_parameter import ConfigParameter, ParameterCategory, ParameterType | ||
|
||
|
||
class ConfigParameters:
    """
    Contains information about all configuration parameters the user wants to search
    """

    def __init__(self):
        # Dict is not part of this module's typing import, so it is imported
        # locally to keep the annotation below resolvable by type checkers
        from typing import Dict

        self._parameters: Dict[str, ConfigParameter] = {}

    def add_parameter(
        self,
        name: str,
        ptype: ParameterType,
        category: ParameterCategory,
        min_range: Optional[int] = None,
        max_range: Optional[int] = None,
        enumerated_list: Optional[List[Any]] = None,
    ) -> None:
        """
        Validates and stores a parameter under the given name. A parameter
        that already exists under that name is replaced.

        Raises a TritonModelAnalyzerException if the combination of
        category, range and list is illegal. Nothing is stored in that case.
        """
        # A fresh list per call: a shared `[]` default would be stored in
        # (and shared between) every parameter added without a list
        values: List[Any] = [] if enumerated_list is None else enumerated_list

        self._check_for_illegal_input(category, min_range, max_range, values)

        self._parameters[name] = ConfigParameter(
            ptype, category, min_range, max_range, values
        )

    def get_parameter(self, name: str) -> ConfigParameter:
        """
        Returns the stored ConfigParameter (raises KeyError if name is unknown)
        """
        return self._parameters[name]

    def get_type(self, name: str) -> ParameterType:
        """
        Returns the ParameterType of the named parameter
        """
        return self._parameters[name].ptype

    def get_category(self, name: str) -> ParameterCategory:
        """
        Returns the ParameterCategory of the named parameter
        """
        return self._parameters[name].category

    def get_range(self, name: str) -> Tuple[Optional[int], Optional[int]]:
        """
        Returns (min_range, max_range) of the named parameter.
        Both entries are None for a ParameterCategory.LIST parameter.
        """
        parameter = self._parameters[name]
        return (parameter.min_range, parameter.max_range)

    def get_list(self, name: str) -> List[Any]:
        """
        Returns the enumerated list of the named parameter
        """
        return self._parameters[name].enumerated_list

    def _check_for_illegal_input(
        self,
        category: ParameterCategory,
        min_range: Optional[int],
        max_range: Optional[int],
        enumerated_list: List[Any],
    ) -> None:
        """
        Raises a TritonModelAnalyzerException if the inputs are not a legal
        combination for the given category
        """
        if category is ParameterCategory.LIST:
            self._check_for_illegal_list_input(min_range, max_range, enumerated_list)
            return

        if min_range is None or max_range is None:
            raise TritonModelAnalyzerException(
                "Both min_range and max_range must be specified"
            )

        # Compare the values directly: a truthiness test would skip this
        # check whenever either bound is 0 (e.g. min_range=5, max_range=0)
        if min_range > max_range:
            raise TritonModelAnalyzerException(
                "min_range cannot be larger than max_range"
            )

    def _check_for_illegal_list_input(
        self,
        min_range: Optional[int],
        max_range: Optional[int],
        enumerated_list: List[Any],
    ) -> None:
        """
        Raises a TritonModelAnalyzerException if a LIST parameter has no
        values or has a range specified
        """
        if not enumerated_list:
            raise TritonModelAnalyzerException(
                "enumerated_list must be specified for a ParameterCategory.LIST"
            )
        elif min_range is not None:
            raise TritonModelAnalyzerException(
                "min_range cannot be specified for a ParameterCategory.LIST"
            )
        elif max_range is not None:
            raise TritonModelAnalyzerException(
                "max_range cannot be specified for a ParameterCategory.LIST"
            )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
|
||
from model_analyzer.config.generate.config_parameters import ( | ||
ConfigParameters, | ||
ParameterCategory, | ||
ParameterType, | ||
) | ||
from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException | ||
|
||
from .common import test_result_collector as trc | ||
|
||
|
||
class TestConfigParameters(trc.TestResultCollector):
    """
    Unit tests for the ConfigParameters container
    """

    def setUp(self):
        # One parameter of every category, reused by all the tests below
        parameters = ConfigParameters()

        parameters.add_parameter(
            name="concurrency",
            ptype=ParameterType.RUNTIME,
            category=ParameterCategory.EXPONENTIAL,
            min_range=0,
            max_range=10,
        )
        parameters.add_parameter(
            name="instance_count",
            ptype=ParameterType.MODEL,
            category=ParameterCategory.INTEGER,
            min_range=1,
            max_range=8,
        )
        parameters.add_parameter(
            name="size",
            ptype=ParameterType.BUILD,
            category=ParameterCategory.LIST,
            enumerated_list=["FP8", "FP16", "FP32"],
        )

        self.config_parameters = parameters

    def _assert_add_rejected(self, **add_parameter_kwargs):
        """
        Check that add_parameter raises for the given keyword arguments
        """
        with self.assertRaises(TritonModelAnalyzerException):
            self.config_parameters.add_parameter(**add_parameter_kwargs)

    def test_exponential_parameter(self):
        """
        Test exponential parameter, accessing dataclass directly
        """
        concurrency = self.config_parameters.get_parameter("concurrency")

        self.assertEqual(ParameterType.RUNTIME, concurrency.ptype)
        self.assertEqual(ParameterCategory.EXPONENTIAL, concurrency.category)
        self.assertEqual(0, concurrency.min_range)
        self.assertEqual(10, concurrency.max_range)

    def test_integer_parameter(self):
        """
        Test integer parameter, using accessor methods
        """
        parameters = self.config_parameters

        self.assertEqual(ParameterType.MODEL, parameters.get_type("instance_count"))
        self.assertEqual(
            ParameterCategory.INTEGER, parameters.get_category("instance_count")
        )
        self.assertEqual((1, 8), parameters.get_range("instance_count"))

    def test_list_parameter(self):
        """
        Test list parameter, using accessor methods
        """
        parameters = self.config_parameters

        self.assertEqual(ParameterType.BUILD, parameters.get_type("size"))
        self.assertEqual(ParameterCategory.LIST, parameters.get_category("size"))
        self.assertEqual(["FP8", "FP16", "FP32"], parameters.get_list("size"))

    def test_illegal_inputs(self):
        """
        Check that an exception is raised for illegal input combos
        """
        # Range category: min_range missing
        self._assert_add_rejected(
            name="concurrency",
            ptype=ParameterType.RUNTIME,
            category=ParameterCategory.EXPONENTIAL,
            max_range=10,
        )

        # Range category: max_range missing
        self._assert_add_rejected(
            name="concurrency",
            ptype=ParameterType.RUNTIME,
            category=ParameterCategory.EXPONENTIAL,
            min_range=0,
        )

        # Range category: min_range larger than max_range
        self._assert_add_rejected(
            name="concurrency",
            ptype=ParameterType.RUNTIME,
            category=ParameterCategory.EXPONENTIAL,
            min_range=10,
            max_range=9,
        )

        # LIST category: no list given
        self._assert_add_rejected(
            name="size",
            ptype=ParameterType.BUILD,
            category=ParameterCategory.LIST,
        )

        # LIST category: min_range given
        self._assert_add_rejected(
            name="size",
            ptype=ParameterType.BUILD,
            category=ParameterCategory.LIST,
            enumerated_list=["FP8", "FP16", "FP32"],
            min_range=0,
        )

        # LIST category: max_range given
        self._assert_add_rejected(
            name="size",
            ptype=ParameterType.BUILD,
            category=ParameterCategory.LIST,
            enumerated_list=["FP8", "FP16", "FP32"],
            max_range=10,
        )
|
||
|
||
# Run the tests in this module when it is executed directly as a script
if __name__ == "__main__":
    unittest.main()