-
Notifications
You must be signed in to change notification settings - Fork 76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Class to hold info about parameters #868
Merged
Merged
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from dataclasses import dataclass | ||
from enum import Enum, auto | ||
from typing import Any, List, Optional | ||
|
||
|
||
class ParameterType(Enum):
    """
    The kinds of configuration parameter: MODEL, RUNTIME or BUILD
    """

    # Values come from auto(), so only member identity/name is meaningful.
    MODEL = auto()
    RUNTIME = auto()
    BUILD = auto()
|
||
|
||
class ParameterCategory(Enum):
    """
    How a parameter's candidate values are described: INTEGER, EXPONENTIAL
    or LIST
    """

    # NOTE(review): an inline review comment attached here read
    # "SearchParameterType" - presumably a rename suggestion; confirm
    # against the review thread before acting on it.

    # Values come from auto(), so only member identity/name is meaningful.
    INTEGER = auto()
    EXPONENTIAL = auto()
    LIST = auto()
|
||
|
||
@dataclass
class ConfigParameter:
    """
    A dataclass that holds information about a configuration parameter
    """

    ptype: ParameterType
    category: ParameterCategory

    # These are only applicable to INTEGER and EXPONENTIAL categories
    min_range: Optional[int] = None
    max_range: Optional[int] = None

    # This is only applicable to LIST category
    #
    # A literal `[]` default is rejected by @dataclass at class-creation time
    # (ValueError: mutable default), so the declared default is None and
    # __post_init__ replaces it with a fresh empty list for each instance.
    enumerated_list: Optional[List[Any]] = None

    def __post_init__(self) -> None:
        """
        Replace an omitted (None) enumerated_list with a new empty list
        """
        if self.enumerated_list is None:
            self.enumerated_list = []
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any, Dict, List, Optional, Tuple

from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException

from .config_parameter import ConfigParameter, ParameterCategory, ParameterType
|
||
|
||
class ConfigParameters:
    """
    Contains information about all configuration parameters the user wants to search
    """

    def __init__(self):
        # Maps a parameter name to its ConfigParameter description
        self._parameters: Dict[str, ConfigParameter] = {}

    def add_parameter(
        self,
        name: str,
        ptype: ParameterType,
        category: ParameterCategory,
        min_range: Optional[int] = None,
        max_range: Optional[int] = None,
        enumerated_list: Optional[List[Any]] = None,
    ) -> None:
        """
        Validate and store a parameter under the given name

        An existing entry with the same name is overwritten.

        Parameters
        ----------
        name: str
            Key used to look the parameter up later
        ptype: ParameterType
        category: ParameterCategory
        min_range, max_range: Optional[int]
            Both are required unless category is LIST, in which case
            neither may be given
        enumerated_list: Optional[List[Any]]
            Required (non-empty) when category is LIST. Omitting it is
            treated as an empty list.

        Raises
        ------
        TritonModelAnalyzerException
            If the combination of inputs is illegal for the category
        """
        # A `[]` default would be one list object shared by every parameter
        # added without an explicit list; build a fresh one per call instead.
        if enumerated_list is None:
            enumerated_list = []

        self._check_for_illegal_input(category, min_range, max_range, enumerated_list)

        self._parameters[name] = ConfigParameter(
            ptype, category, min_range, max_range, enumerated_list
        )

    def get_parameter(self, name: str) -> ConfigParameter:
        """
        Return the stored ConfigParameter; raises KeyError for an unknown name
        """
        return self._parameters[name]

    def get_type(self, name: str) -> ParameterType:
        """
        Return the parameter's type; raises KeyError for an unknown name
        """
        return self._parameters[name].ptype

    def get_category(self, name: str) -> ParameterCategory:
        """
        Return the parameter's category; raises KeyError for an unknown name
        """
        return self._parameters[name].category

    def get_range(self, name: str) -> Tuple[Optional[int], Optional[int]]:
        """
        Return (min_range, max_range); raises KeyError for an unknown name

        Both entries are None for a parameter added with the LIST category.
        """
        parameter = self._parameters[name]
        return (parameter.min_range, parameter.max_range)

    def get_list(self, name: str) -> List[Any]:
        """
        Return the parameter's enumerated list; raises KeyError for an
        unknown name
        """
        return self._parameters[name].enumerated_list

    def _check_for_illegal_input(
        self,
        category: ParameterCategory,
        min_range: Optional[int],
        max_range: Optional[int],
        enumerated_list: List[Any],
    ) -> None:
        """
        Raise TritonModelAnalyzerException if the inputs do not fit the category
        """
        if category is ParameterCategory.LIST:
            self._check_for_illegal_list_input(min_range, max_range, enumerated_list)
            return

        if min_range is None or max_range is None:
            raise TritonModelAnalyzerException(
                "Both min_range and max_range must be specified"
            )

        # Both bounds are known to be set here. Compare them directly: a
        # truthiness test would skip this check whenever either bound is 0.
        if min_range > max_range:
            raise TritonModelAnalyzerException(
                "min_range cannot be larger than max_range"
            )

    def _check_for_illegal_list_input(
        self,
        min_range: Optional[int],
        max_range: Optional[int],
        enumerated_list: List[Any],
    ) -> None:
        """
        Raise TritonModelAnalyzerException unless a LIST parameter has a
        non-empty enumerated_list and no range bounds
        """
        if not enumerated_list:
            raise TritonModelAnalyzerException(
                "enumerated_list must be specified for a LIST"
            )
        if min_range is not None:
            raise TritonModelAnalyzerException(
                "min_range cannot be specified for a list"
            )
        if max_range is not None:
            raise TritonModelAnalyzerException(
                "max_range cannot be specified for a list"
            )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
from unittest.mock import MagicMock, patch | ||
|
||
|
||
from model_analyzer.config.generate.config_parameters import ( | ||
ConfigParameters, | ||
ParameterCategory, | ||
ParameterType, | ||
) | ||
from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException | ||
|
||
from .common import test_result_collector as trc | ||
|
||
|
||
class TestConfigParameters(trc.TestResultCollector):
    """
    Unit tests for ConfigParameters: stored values and input validation
    """

    def setUp(self):
        # One parameter per category, so each test can read back a known entry
        self.config_parameters = ConfigParameters()

        self.config_parameters.add_parameter(
            name="concurrency",
            ptype=ParameterType.RUNTIME,
            category=ParameterCategory.EXPONENTIAL,
            min_range=0,
            max_range=10,
        )

        self.config_parameters.add_parameter(
            name="instance_count",
            ptype=ParameterType.MODEL,
            category=ParameterCategory.INTEGER,
            min_range=1,
            max_range=8,
        )

        self.config_parameters.add_parameter(
            name="size",
            ptype=ParameterType.BUILD,
            category=ParameterCategory.LIST,
            enumerated_list=["FP8", "FP16", "FP32"],
        )

    def test_exponential_parameter(self):
        """
        Test exponential parameter, accessing dataclass directly
        """

        parameter = self.config_parameters.get_parameter("concurrency")

        self.assertEqual(ParameterType.RUNTIME, parameter.ptype)
        self.assertEqual(ParameterCategory.EXPONENTIAL, parameter.category)
        self.assertEqual(0, parameter.min_range)
        self.assertEqual(10, parameter.max_range)

    def test_integer_parameter(self):
        """
        Test integer parameter, using accessor methods
        """

        self.assertEqual(
            ParameterType.MODEL,
            self.config_parameters.get_type("instance_count"),
        )
        self.assertEqual(
            ParameterCategory.INTEGER,
            self.config_parameters.get_category("instance_count"),
        )
        self.assertEqual((1, 8), self.config_parameters.get_range("instance_count"))

    def test_list_parameter(self):
        """
        Test list parameter, using accessor methods
        """

        self.assertEqual(
            ParameterType.BUILD,
            self.config_parameters.get_type("size"),
        )
        self.assertEqual(
            ParameterCategory.LIST,
            self.config_parameters.get_category("size"),
        )
        self.assertEqual(
            ["FP8", "FP16", "FP32"], self.config_parameters.get_list("size")
        )

    def test_illegal_inputs(self):
        """
        Check that an exception is raised for illegal input combos
        """
        # Range category without min_range
        with self.assertRaises(TritonModelAnalyzerException):
            self.config_parameters.add_parameter(
                name="concurrency",
                ptype=ParameterType.RUNTIME,
                category=ParameterCategory.EXPONENTIAL,
                max_range=10,
            )

        # Range category without max_range
        with self.assertRaises(TritonModelAnalyzerException):
            self.config_parameters.add_parameter(
                name="concurrency",
                ptype=ParameterType.RUNTIME,
                category=ParameterCategory.EXPONENTIAL,
                min_range=0,
            )

        # min_range larger than max_range
        with self.assertRaises(TritonModelAnalyzerException):
            self.config_parameters.add_parameter(
                name="concurrency",
                ptype=ParameterType.RUNTIME,
                category=ParameterCategory.EXPONENTIAL,
                min_range=10,
                max_range=9,
            )

        # LIST category without an enumerated_list
        with self.assertRaises(TritonModelAnalyzerException):
            self.config_parameters.add_parameter(
                name="size",
                ptype=ParameterType.BUILD,
                category=ParameterCategory.LIST,
            )

        # LIST category combined with min_range
        with self.assertRaises(TritonModelAnalyzerException):
            self.config_parameters.add_parameter(
                name="size",
                ptype=ParameterType.BUILD,
                category=ParameterCategory.LIST,
                enumerated_list=["FP8", "FP16", "FP32"],
                min_range=0,
            )

        # LIST category combined with max_range
        with self.assertRaises(TritonModelAnalyzerException):
            self.config_parameters.add_parameter(
                name="size",
                ptype=ParameterType.BUILD,
                category=ParameterCategory.LIST,
                enumerated_list=["FP8", "FP16", "FP32"],
                max_range=10,
            )
|
||
|
||
if __name__ == "__main__":
    # Run this module's test cases when the file is executed as a script.
    unittest.main()
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SearchUsageType