Adding illegal LLM checks w/ unit testing + some minor cleanup #781

Merged 2 commits on Oct 31, 2023
@@ -328,7 +328,7 @@ def _create_request_period_list(self) -> List[int]:
             return []

         if self._model_parameters["request_period"]:
-            return sorted(self._model_parameters["period"])
+            return sorted(self._model_parameters["request_period"])
         elif self._cli_config.run_config_search_disable:
             return [DEFAULT_RUN_CONFIG_MIN_REQUEST_PERIOD]
         else:
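The one-line fix above corrects a wrong dictionary key: the guard checks `request_period`, but the old code then sorted the nonexistent `period` entry. A minimal sketch of the failure mode, with hypothetical parameter values:

model_parameters = {"request_period": [32, 8, 16]}  # hypothetical contents

# Old code: raises KeyError, since only "request_period" exists.
# sorted(model_parameters["period"])

# Fixed code: returns the configured periods in order.
print(sorted(model_parameters["request_period"]))  # [8, 16, 32]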
model_analyzer/config/input/config_command.py: 48 additions & 0 deletions
@@ -129,6 +129,7 @@ def _check_for_illegal_config_settings(
         self._check_for_bls_incompatibility(args, yaml_config)
         self._check_for_concurrency_rate_request_conflicts(args, yaml_config)
         self._check_for_config_search_rate_request_conflicts(args, yaml_config)
+        self._check_for_llm_incompatibility(args, yaml_config)

     def _set_field_values(
         self, args: Namespace, yaml_config: Optional[Dict[str, List]]

@@ -398,6 +399,53 @@ def _check_for_config_search_rate_request_conflicts(
f"\nCannot have both `run-config-search-max-request-rate` and `run-config-search-min/max-concurrency` specified in the config/CLI."
)

+    def _check_for_llm_incompatibility(
+        self, args: Namespace, yaml_config: Optional[Dict[str, List]]
+    ) -> None:
+        if not self._get_config_value("llm_search_enable", args, yaml_config):
+            return
+
+        if (
+            self._get_config_value("run_config_search_mode", args, yaml_config)
+            == "quick"
+        ):
+            raise TritonModelAnalyzerException(
+                f"\nLLM models are not supported in quick search. Please use brute search mode."
+            )
+
+        self._check_for_illegal_llm_option(
+            args, yaml_config, "run_config_search_min_model_batch_size"
+        )
+        self._check_for_illegal_llm_option(
+            args, yaml_config, "run_config_search_max_model_batch_size"
+        )
+        self._check_for_illegal_llm_option(
+            args, yaml_config, "run_config_search_min_concurrency"
+        )
+        self._check_for_illegal_llm_option(
+            args, yaml_config, "run_config_search_max_concurrency"
+        )
+        self._check_for_illegal_llm_option(
+            args, yaml_config, "run_config_search_min_request_rate"
+        )
+        self._check_for_illegal_llm_option(
+            args, yaml_config, "run_config_search_max_request_rate"
+        )
+        self._check_for_illegal_llm_option(
+            args, yaml_config, "request_rate_search_enable"
+        )
+        self._check_for_illegal_llm_option(args, yaml_config, "concurrency")
+        self._check_for_illegal_llm_option(args, yaml_config, "latency_budget")
+        self._check_for_illegal_llm_option(args, yaml_config, "min_throughput")
+
+    def _check_for_illegal_llm_option(
+        self, args: Namespace, yaml_config: Optional[Dict[str, List]], option: str
+    ) -> None:
+        if self._get_config_value(option, args, yaml_config):
+            raise TritonModelAnalyzerException(
+                f"\nLLM models do not support setting the `{option}` option when profiling."
+            )

     def _preprocess_and_verify_arguments(self):
         """
         Enforces some rules on the config.
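The ten explicit calls above keep each disallowed flag visible on its own diff line; the same guard can also be written as a loop over a tuple of option names. A minimal standalone sketch of that pattern, using a plain dict and ValueError in place of the real config plumbing and TritonModelAnalyzerException:

ILLEGAL_LLM_OPTIONS = (
    "run_config_search_min_model_batch_size",
    "run_config_search_max_model_batch_size",
    "run_config_search_min_concurrency",
    "run_config_search_max_concurrency",
    "run_config_search_min_request_rate",
    "run_config_search_max_request_rate",
    "request_rate_search_enable",
    "concurrency",
    "latency_budget",
    "min_throughput",
)

def check_for_llm_incompatibility(config: dict) -> None:
    # No-op unless LLM search was requested.
    if not config.get("llm_search_enable"):
        return
    # Quick search is unsupported for LLM models.
    if config.get("run_config_search_mode") == "quick":
        raise ValueError("LLM models are not supported in quick search.")
    # Any explicitly set option from the blocklist is illegal.
    for option in ILLEGAL_LLM_OPTIONS:
        if config.get(option):
            raise ValueError(f"LLM models do not support setting `{option}`.")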
model_analyzer/config/input/config_command_report.py: 4 additions & 0 deletions
@@ -209,6 +209,10 @@ def set_config_values(self, args):

         super().set_config_values(args)

+    # TODO TMA-1443: Update this when adding support for detailed reporting
+    def is_llm_model(self) -> bool:
+        return False
+
     def _preprocess_and_verify_arguments(self):
         """
         Enforces some rules on the config.
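A short sketch of the override's intent, using hypothetical stand-in classes (the base-class behavior is assumed, not shown in this diff): report configs answer False so no LLM-specific path runs during report generation until TMA-1443 adds support.

class ConfigCommandSketch:
    # Hypothetical base: profiling configs would consult the CLI flag.
    def is_llm_model(self) -> bool:
        return bool(getattr(self, "llm_search_enable", False))

class ConfigCommandReportSketch(ConfigCommandSketch):
    # Mirrors the PR: reporting opts out of LLM handling entirely.
    def is_llm_model(self) -> bool:
        return False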
model_analyzer/perf_analyzer/perf_config.py: 8 additions & 1 deletion
@@ -101,6 +101,8 @@ class PerfAnalyzerConfig:
"collect-metrics",
]

llm_args = ["text-input-length", "max-tokens"]

def __init__(self):
"""
Construct a PerfAnalyzerConfig
@@ -152,7 +154,12 @@ def allowed_keys(cls):
         passed into perf_analyzer
         """

-        return cls.perf_analyzer_args + cls.input_to_options + cls.input_to_verbose
+        return (
+            cls.perf_analyzer_args
+            + cls.input_to_options
+            + cls.input_to_verbose
+            + cls.llm_args
+        )

     @classmethod
     def additive_keys(cls):
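Adding `llm_args` to `allowed_keys()` is what lets the two new flags pass validation on their way to perf_analyzer. A minimal self-contained sketch of the whitelist pattern with an illustrative setter (the class body is hypothetical, not the real PerfAnalyzerConfig):

class PerfAnalyzerConfigSketch:
    perf_analyzer_args = ["measurement-interval"]  # illustrative subset
    input_to_options = ["model-name"]
    input_to_verbose = ["verbose"]
    llm_args = ["text-input-length", "max-tokens"]

    def __init__(self):
        self._args = {}

    @classmethod
    def allowed_keys(cls):
        return (
            cls.perf_analyzer_args
            + cls.input_to_options
            + cls.input_to_verbose
            + cls.llm_args
        )

    def __setitem__(self, key, value):
        # Reject anything outside the whitelist before it reaches perf_analyzer.
        if key not in self.allowed_keys():
            raise KeyError(f"`{key}` is not a supported perf_analyzer option")
        self._args[key] = value

config = PerfAnalyzerConfigSketch()
config["max-tokens"] = 256  # legal only because llm_args joined allowed_keys()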
tests/test_config.py: 74 additions & 0 deletions
@@ -2340,6 +2340,80 @@ def _test_arg_conflict(
         with self.assertRaises(TritonModelAnalyzerException):
             self._evaluate_config(args, yaml_content)

+    def test_llm_mode_rcs(self):
+        """
+        Test run-config-search options for an LLM model
+        """
+        yaml_content = ""
+
+        self._test_llm_mode_case(
+            yaml_content,
+            ["--run-config-search-mode", "brute"],
+            is_legal=True,
+            use_value=False,
+            use_list=False,
+        )
+        self._test_llm_mode_case(
+            yaml_content,
+            ["--run-config-search-mode", "quick"],
+            use_value=False,
+            use_list=False,
+        )
+
+        self._test_llm_mode_case(
+            yaml_content, ["--run-config-search-min-model-batch-size"]
+        )
+        self._test_llm_mode_case(
+            yaml_content, ["--run-config-search-max-model-batch-size"]
+        )
+        self._test_llm_mode_case(yaml_content, ["--run-config-search-min-concurrency"])
+        self._test_llm_mode_case(yaml_content, ["--run-config-search-max-concurrency"])
+        self._test_llm_mode_case(yaml_content, ["--run-config-search-min-request-rate"])
+        self._test_llm_mode_case(yaml_content, ["--run-config-search-max-request-rate"])
+        self._test_llm_mode_case(
+            yaml_content,
+            ["--request-rate-search-enable"],
+            use_value=False,
+            use_list=False,
+        )
+        self._test_llm_mode_case(yaml_content, ["--concurrency"])
+        self._test_llm_mode_case(yaml_content, ["--latency-budget"])
+        self._test_llm_mode_case(yaml_content, ["--min-throughput"])
+
+    def _test_llm_mode_case(
+        self,
+        yaml_content: str,
+        options_string: List[str],
+        is_legal: bool = False,
+        use_value: bool = True,
+        use_list: bool = True,
+    ) -> None:
+        """
+        Tests that options raise exceptions in LLM mode
+        """
+        args = [
+            "model-analyzer",
+            "profile",
+            "--model-repository",
+            "cli-repository",
+            "--profile-models",
+            "test_llm_modelA",
+            "--llm-search-enable",
+        ]
+
+        args.extend(options_string)
+
+        if use_value:
+            args.append("1")
+        elif use_list:
+            args.extend(["1", "2", "4"])
+
+        if is_legal:
+            self._evaluate_config(args, yaml_content, subcommand="profile")
+        else:
+            with self.assertRaises(TritonModelAnalyzerException):
+                self._evaluate_config(args, yaml_content, subcommand="profile")


 if __name__ == "__main__":
     unittest.main()
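For reference, the argument vector the helper assembles for one illegal case is equivalent to the following (fixture names come from the test itself); evaluating it under the profile subcommand is expected to raise TritonModelAnalyzerException:

argv = [
    "model-analyzer", "profile",
    "--model-repository", "cli-repository",
    "--profile-models", "test_llm_modelA",
    "--llm-search-enable",
    "--concurrency", "1",  # illegal once --llm-search-enable is set
]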