Skip to content

Commit

Permalink
Adding illegal LLM checks w/ unit testing + some minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-braf committed Oct 31, 2023
1 parent f508625 commit 79a0937
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def _create_request_period_list(self) -> List[int]:
return []

if self._model_parameters["request_period"]:
return sorted(self._model_parameters["period"])
return sorted(self._model_parameters["request_period"])
elif self._cli_config.run_config_search_disable:
return [DEFAULT_RUN_CONFIG_MIN_REQUEST_PERIOD]
else:
Expand Down
48 changes: 48 additions & 0 deletions model_analyzer/config/input/config_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def _check_for_illegal_config_settings(
self._check_for_bls_incompatibility(args, yaml_config)
self._check_for_concurrency_rate_request_conflicts(args, yaml_config)
self._check_for_config_search_rate_request_conflicts(args, yaml_config)
self._check_for_llm_incompatibility(args, yaml_config)

def _set_field_values(
self, args: Namespace, yaml_config: Optional[Dict[str, List]]
Expand Down Expand Up @@ -398,6 +399,53 @@ def _check_for_config_search_rate_request_conflicts(
f"\nCannot have both `run-config-search-max-request-rate` and `run-config-search-min/max-concurrency` specified in the config/CLI."
)

def _check_for_llm_incompatibility(
self, args: Namespace, yaml_config: Optional[Dict[str, List]]
) -> None:
if not self._get_config_value("llm_search_enable", args, yaml_config):
return

if (
self._get_config_value("run_config_search_mode", args, yaml_config)
== "quick"
):
raise TritonModelAnalyzerException(
f"\nLLM models are not supported in quick search. Please use brute search mode."
)

self._check_for_illegal_llm_option(
args, yaml_config, "run_config_search_min_model_batch_size"
)
self._check_for_illegal_llm_option(
args, yaml_config, "run_config_search_max_model_batch_size"
)
self._check_for_illegal_llm_option(
args, yaml_config, "run_config_search_min_concurrency"
)
self._check_for_illegal_llm_option(
args, yaml_config, "run_config_search_max_concurrency"
)
self._check_for_illegal_llm_option(
args, yaml_config, "run_config_search_min_request_rate"
)
self._check_for_illegal_llm_option(
args, yaml_config, "run_config_search_max_request_rate"
)
self._check_for_illegal_llm_option(
args, yaml_config, "request_rate_search_enable"
)
self._check_for_illegal_llm_option(args, yaml_config, "concurrency")
self._check_for_illegal_llm_option(args, yaml_config, "latency_budget")
self._check_for_illegal_llm_option(args, yaml_config, "min_throughput")

def _check_for_illegal_llm_option(
self, args: Namespace, yaml_config: Optional[Dict[str, List]], option: str
) -> None:
if self._get_config_value(option, args, yaml_config):
raise TritonModelAnalyzerException(
f"\nLLM models do not support setting the `{option}` option when profiling."
)

def _preprocess_and_verify_arguments(self):
"""
Enforces some rules on the config.
Expand Down
4 changes: 4 additions & 0 deletions model_analyzer/config/input/config_command_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ def set_config_values(self, args):

super().set_config_values(args)

# TODO: Update this when adding support for detailed reporting
def is_llm_model(self) -> bool:
return False

def _preprocess_and_verify_arguments(self):
"""
Enforces some rules on the config.
Expand Down
9 changes: 8 additions & 1 deletion model_analyzer/perf_analyzer/perf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ class PerfAnalyzerConfig:
"collect-metrics",
]

llm_args = ["text-input-length", "max-tokens"]

def __init__(self):
"""
Construct a PerfAnalyzerConfig
Expand Down Expand Up @@ -152,7 +154,12 @@ def allowed_keys(cls):
passed into perf_analyzer
"""

return cls.perf_analyzer_args + cls.input_to_options + cls.input_to_verbose
return (
cls.perf_analyzer_args
+ cls.input_to_options
+ cls.input_to_verbose
+ cls.llm_args
)

@classmethod
def additive_keys(cls):
Expand Down
74 changes: 74 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2340,6 +2340,80 @@ def _test_arg_conflict(
with self.assertRaises(TritonModelAnalyzerException):
self._evaluate_config(args, yaml_content)

def test_llm_mode_rcs(self):
"""
Test RCS options for an LLM model
"""
yaml_content = ""

self._test_llm_mode_case(
yaml_content,
["--run-config-search-mode", "brute"],
is_legal=True,
use_value=False,
use_list=False,
)
self._test_llm_mode_case(
yaml_content,
["--run-config-search-mode", "quick"],
use_value=False,
use_list=False,
)

self._test_llm_mode_case(
yaml_content, ["--run-config-search-min-model-batch-size"]
)
self._test_llm_mode_case(
yaml_content, ["--run-config-search-max-model-batch-size"]
)
self._test_llm_mode_case(yaml_content, ["--run-config-search-min-concurrency"])
self._test_llm_mode_case(yaml_content, ["--run-config-search-max-concurrency"])
self._test_llm_mode_case(yaml_content, ["--run-config-search-min-request-rate"])
self._test_llm_mode_case(yaml_content, ["--run-config-search-max-request-rate"])
self._test_llm_mode_case(
yaml_content,
["--request-rate-search-enable"],
use_value=False,
use_list=False,
)
self._test_llm_mode_case(yaml_content, ["--concurrency"])
self._test_llm_mode_case(yaml_content, ["--latency-budget"])
self._test_llm_mode_case(yaml_content, ["--min-throughput"])

def _test_llm_mode_case(
self,
yaml_content: Optional[Dict[str, List]],
options_string: str,
is_legal: bool = False,
use_value: bool = True,
use_list: bool = True,
) -> None:
"""
Tests that options raise exceptions in LLM mode
"""
args = [
"model-analyzer",
"profile",
"--model-repository",
"cli-repository",
"--profile-models",
"test_llm_modelA",
"--llm-search-enable",
]

args.extend(options_string)

if use_value:
args.append("1")
elif use_list:
args.append(["1", "2", "4"])

if is_legal:
self._evaluate_config(args, yaml_content, subcommand="profile")
else:
with self.assertRaises(TritonModelAnalyzerException):
self._evaluate_config(args, yaml_content, subcommand="profile")


if __name__ == "__main__":
unittest.main()

0 comments on commit 79a0937

Please sign in to comment.