Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Class to hold info about parameters #868

Merged
merged 3 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions model_analyzer/config/generate/config_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/usr/bin/env python3

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, List, Optional


class ParameterType(Enum):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SearchUsageType

MODEL = auto()
RUNTIME = auto()
BUILD = auto()


class ParameterCategory(Enum):
Copy link
Contributor

@dyastremsky dyastremsky May 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SearchParameterType

INTEGER = auto()
EXPONENTIAL = auto()
LIST = auto()


@dataclass
class ConfigParameter:
"""
A dataclass that holds information about a configuration parameter
"""

ptype: ParameterType
category: ParameterCategory

# These are only applicable to INTEGER and EXPONENTIAL categories
min_range: Optional[int] = None
max_range: Optional[int] = None

# This is only applicable to LIST category
enumerated_list: List[Any] = []
100 changes: 100 additions & 0 deletions model_analyzer/config/generate/config_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python3

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional, Tuple

from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException

from .config_parameter import ConfigParameter, ParameterCategory, ParameterType


class ConfigParameters:
"""
Contains information about all configuration parameters the user wants to search
"""

def __init__(self):
self._parameters: Dict[str, ConfigParameter] = {}

def add_parameter(
self,
name: str,
ptype: ParameterType,
category: ParameterCategory,
min_range: Optional[int] = None,
max_range: Optional[int] = None,
enumerated_list: List[Any] = [],
) -> None:
self._check_for_illegal_input(category, min_range, max_range, enumerated_list)

self._parameters[name] = ConfigParameter(
ptype, category, min_range, max_range, enumerated_list
)

def get_parameter(self, name: str) -> ConfigParameter:
return self._parameters[name]

def get_type(self, name: str) -> ParameterType:
return self._parameters[name].ptype

def get_category(self, name: str) -> ParameterCategory:
return self._parameters[name].category

def get_range(self, name: str) -> Tuple[int, int]:
return (self._parameters[name].min_range, self._parameters[name].max_range)

def get_list(self, name: str) -> List[Any]:
return self._parameters[name].enumerated_list
nv-braf marked this conversation as resolved.
Show resolved Hide resolved

def _check_for_illegal_input(
self,
category: ParameterCategory,
min_range: Optional[int],
max_range: Optional[int],
enumerated_list: List[Any],
) -> None:
if category is ParameterCategory.LIST:
self._check_for_illegal_list_input(min_range, max_range, enumerated_list)
else:
if min_range is None or max_range is None:
raise TritonModelAnalyzerException(
f"Both min_range and max_range must be specified"
)

if min_range and max_range:
if min_range > max_range:
raise TritonModelAnalyzerException(
f"min_range cannot be larger than max_range"
)

def _check_for_illegal_list_input(
self,
min_range: Optional[int],
max_range: Optional[int],
enumerated_list: List[Any],
) -> None:
if not enumerated_list:
raise TritonModelAnalyzerException(
f"enumerated_list must be specified for a ParameterCategory.LIST"
)
elif min_range is not None:
raise TritonModelAnalyzerException(
f"min_range cannot be specified for a ParameterCategory.LIST"
)
elif max_range is not None:
raise TritonModelAnalyzerException(
f"max_range cannot be specified for a ParameterCategory.LIST"
)
156 changes: 156 additions & 0 deletions tests/test_config_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
#!/usr/bin/env python3

# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from model_analyzer.config.generate.config_parameters import (
ConfigParameters,
ParameterCategory,
ParameterType,
)
from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException

from .common import test_result_collector as trc


class TestConfigParameters(trc.TestResultCollector):
def setUp(self):
self.config_parameters = ConfigParameters()

self.config_parameters.add_parameter(
name="concurrency",
ptype=ParameterType.RUNTIME,
category=ParameterCategory.EXPONENTIAL,
min_range=0,
max_range=10,
)

self.config_parameters.add_parameter(
name="instance_count",
ptype=ParameterType.MODEL,
category=ParameterCategory.INTEGER,
min_range=1,
max_range=8,
)

self.config_parameters.add_parameter(
name="size",
ptype=ParameterType.BUILD,
category=ParameterCategory.LIST,
enumerated_list=["FP8", "FP16", "FP32"],
)

def test_exponential_parameter(self):
"""
Test exponential parameter, accessing dataclass directly
"""

parameter = self.config_parameters.get_parameter("concurrency")

self.assertEqual(ParameterType.RUNTIME, parameter.ptype)
self.assertEqual(ParameterCategory.EXPONENTIAL, parameter.category)
self.assertEqual(0, parameter.min_range)
self.assertEqual(10, parameter.max_range)

def test_integer_parameter(self):
"""
Test integer parameter, using accessor methods
"""

self.assertEqual(
ParameterType.MODEL,
self.config_parameters.get_type("instance_count"),
)
self.assertEqual(
ParameterCategory.INTEGER,
self.config_parameters.get_category("instance_count"),
)
self.assertEqual((1, 8), self.config_parameters.get_range("instance_count"))

def test_list_parameter(self):
"""
Test list parameter, using accessor methods
"""

self.assertEqual(
ParameterType.BUILD,
self.config_parameters.get_type("size"),
)
self.assertEqual(
ParameterCategory.LIST,
self.config_parameters.get_category("size"),
)
self.assertEqual(
["FP8", "FP16", "FP32"], self.config_parameters.get_list("size")
)

def test_illegal_inputs(self):
"""
Check that an exception is raised for illegal input combos
"""
with self.assertRaises(TritonModelAnalyzerException):
self.config_parameters.add_parameter(
name="concurrency",
ptype=ParameterType.RUNTIME,
category=ParameterCategory.EXPONENTIAL,
max_range=10,
)

with self.assertRaises(TritonModelAnalyzerException):
self.config_parameters.add_parameter(
name="concurrency",
ptype=ParameterType.RUNTIME,
category=ParameterCategory.EXPONENTIAL,
min_range=0,
)

with self.assertRaises(TritonModelAnalyzerException):
self.config_parameters.add_parameter(
name="concurrency",
ptype=ParameterType.RUNTIME,
category=ParameterCategory.EXPONENTIAL,
min_range=10,
max_range=9,
)

with self.assertRaises(TritonModelAnalyzerException):
self.config_parameters.add_parameter(
name="size",
ptype=ParameterType.BUILD,
category=ParameterCategory.LIST,
)

with self.assertRaises(TritonModelAnalyzerException):
self.config_parameters.add_parameter(
name="size",
ptype=ParameterType.BUILD,
category=ParameterCategory.LIST,
enumerated_list=["FP8", "FP16", "FP32"],
min_range=0,
)

with self.assertRaises(TritonModelAnalyzerException):
self.config_parameters.add_parameter(
name="size",
ptype=ParameterType.BUILD,
category=ParameterCategory.LIST,
enumerated_list=["FP8", "FP16", "FP32"],
max_range=10,
)


if __name__ == "__main__":
unittest.main()
Loading