From adef23fb879d372d8e22411d1a7f32cc437c6c57 Mon Sep 17 00:00:00 2001 From: braf Date: Tue, 31 Oct 2023 21:11:50 +0000 Subject: [PATCH] General cleanup --- model_analyzer/analyzer.py | 5 ++++- .../config/generate/perf_analyzer_config_generator.py | 6 +++++- model_analyzer/config/input/config_defaults.py | 2 +- model_analyzer/record/metrics_manager.py | 5 +++++ tests/test_model_manager.py | 4 +++- 5 files changed, 18 insertions(+), 4 deletions(-) diff --git a/model_analyzer/analyzer.py b/model_analyzer/analyzer.py index 750c2a8ba..9f6ce25d7 100755 --- a/model_analyzer/analyzer.py +++ b/model_analyzer/analyzer.py @@ -136,7 +136,10 @@ def profile( if not self._config.skip_summary_reports: self._create_summary_tables(verbose) - self._create_summary_reports(mode) + + # TODO: need to figure out summary reporting for LLMs + if not self._config.is_llm_model(): + self._create_summary_reports(mode) # FIXME: need to figure out detailed reporting for LLMs if not self._config.is_llm_model(): diff --git a/model_analyzer/config/generate/perf_analyzer_config_generator.py b/model_analyzer/config/generate/perf_analyzer_config_generator.py index 7f134813a..10c86e610 100755 --- a/model_analyzer/config/generate/perf_analyzer_config_generator.py +++ b/model_analyzer/config/generate/perf_analyzer_config_generator.py @@ -16,6 +16,7 @@ import json import logging +import os from itertools import repeat from typing import Any, Dict, Generator, List, Optional, Tuple @@ -328,7 +329,7 @@ def _create_request_period_list(self) -> List[int]: return [] if self._model_parameters["request_period"]: - return sorted(self._model_parameters["period"]) + return sorted(self._model_parameters["request_period"]) elif self._cli_config.run_config_search_disable: return [DEFAULT_RUN_CONFIG_MIN_REQUEST_PERIOD] else: @@ -448,6 +449,9 @@ def _modify_text_in_input_dict(self, text_input_length: int) -> Dict: def _write_modified_input_dict_to_file( self, modified_input_dict: Dict, input_json_filename: str ) -> None: + if not os.path.exists(DEFAULT_INPUT_JSON_PATH): + os.makedirs(DEFAULT_INPUT_JSON_PATH) + with open(input_json_filename, "w") as f: json.dump(modified_input_dict, f) diff --git a/model_analyzer/config/input/config_defaults.py b/model_analyzer/config/input/config_defaults.py index fd48812b8..fb0b62ee8 100755 --- a/model_analyzer/config/input/config_defaults.py +++ b/model_analyzer/config/input/config_defaults.py @@ -38,7 +38,7 @@ DEFAULT_SKIP_SUMMARY_REPORTS = False DEFAULT_SKIP_DETAILED_REPORTS = False DEFAULT_OUTPUT_MODEL_REPOSITORY = os.path.join(os.getcwd(), "output_model_repository") -DEFAULT_INPUT_JSON_PATH = os.getcwd() +DEFAULT_INPUT_JSON_PATH = os.path.join(os.getcwd(), "input_json_dir") DEFAULT_OVERRIDE_OUTPUT_REPOSITORY_FLAG = False DEFAULT_BATCH_SIZES = 1 DEFAULT_MAX_RETRIES = 50 diff --git a/model_analyzer/record/metrics_manager.py b/model_analyzer/record/metrics_manager.py index e703e19a2..10459a76f 100755 --- a/model_analyzer/record/metrics_manager.py +++ b/model_analyzer/record/metrics_manager.py @@ -16,6 +16,7 @@ import logging import os +import shutil import time from collections import defaultdict from typing import Dict, List, Optional, Tuple @@ -27,6 +28,7 @@ from model_analyzer.config.generate.base_model_config_generator import ( BaseModelConfigGenerator, ) +from model_analyzer.config.input.config_defaults import DEFAULT_INPUT_JSON_PATH from model_analyzer.config.run.run_config import RunConfig from model_analyzer.constants import LOGGER_NAME, PA_ERROR_LOG_FILENAME from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException @@ -309,6 +311,9 @@ def profile_models(self, run_config: RunConfig) -> Optional[RunConfigMeasurement def finalize(self): self._server.stop() + if os.path.exists(DEFAULT_INPUT_JSON_PATH): + shutil.rmtree(DEFAULT_INPUT_JSON_PATH) + def _create_model_variants(self, run_config: RunConfig) -> None: """ Creates and fills all model variant directories diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py index 0370c6c77..fcbd8eee8 100755 --- a/tests/test_model_manager.py +++ b/tests/test_model_manager.py @@ -1294,7 +1294,9 @@ def _test_model_manager(self, yaml_content, expected_ranges, args=None): MagicMock(), ) - model_manager.run_models([config.profile_models[0]]) + with patch("shutil.rmtree"): + model_manager.run_models([config.profile_models[0]]) + self.mock_model_config.stop() self._check_results(model_manager, expected_ranges)