Commit 509a4b5

Base Optuna class plus unit testing
nv-braf committed May 15, 2024
1 parent 2b40c67 commit 509a4b5
Showing 3 changed files with 448 additions and 0 deletions.
307 changes: 307 additions & 0 deletions model_analyzer/config/generate/optuna_run_config_generator.py
@@ -0,0 +1,307 @@
#!/usr/bin/env python3

# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from sys import maxsize

Check notice (Code scanning / CodeQL): Unused import. Import of 'maxsize' is not used.
from typing import Any, Dict, Generator, List, Optional, Tuple, Union

Check notice (Code scanning / CodeQL): Unused import. Import of 'Tuple' is not used.

import optuna

from model_analyzer.config.generate.base_model_config_generator import (
BaseModelConfigGenerator,
)
from model_analyzer.config.generate.brute_run_config_generator import (
BruteRunConfigGenerator,
)
from model_analyzer.config.generate.coordinate import Coordinate

Check notice (Code scanning / CodeQL): Unused import. Import of 'Coordinate' is not used.
from model_analyzer.config.generate.coordinate_data import CoordinateData

Check notice (Code scanning / CodeQL): Unused import. Import of 'CoordinateData' is not used.
from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
from model_analyzer.config.generate.model_variant_name_manager import (
ModelVariantNameManager,
)
from model_analyzer.config.generate.search_parameters import SearchParameters
from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
from model_analyzer.config.input.config_defaults import DEFAULT_BATCH_SIZES
from model_analyzer.config.run.model_run_config import ModelRunConfig
from model_analyzer.config.run.run_config import RunConfig
from model_analyzer.constants import LOGGER_NAME
from model_analyzer.device.gpu_device import GPUDevice
from model_analyzer.perf_analyzer.perf_config import PerfAnalyzerConfig
from model_analyzer.record.metrics_manager import MetricsManager
from model_analyzer.result.run_config_measurement import RunConfigMeasurement
from model_analyzer.triton.client.client import TritonClient
from model_analyzer.triton.model.model_config import ModelConfig
from model_analyzer.triton.model.model_config_variant import ModelConfigVariant

from .config_generator_interface import ConfigGeneratorInterface
from .generator_utils import GeneratorUtils

Check notice (Code scanning / CodeQL): Unused import. Import of 'GeneratorUtils' is not used.

logger = logging.getLogger(LOGGER_NAME)
from copy import deepcopy

Check notice (Code scanning / CodeQL): Unused import. Import of 'deepcopy' is not used.


class OptunaRunConfigGenerator(ConfigGeneratorInterface):
"""
Uses Optuna's hyperparameter search (TPE sampling) to create RunConfigs
"""

def __init__(
self,
config: ConfigCommandProfile,
gpus: List[GPUDevice],
models: List[ModelProfileSpec],
client: TritonClient,
model_variant_name_manager: ModelVariantNameManager,
search_parameters: SearchParameters,
metrics_manager: MetricsManager,
seed: Optional[int] = 0,
):
"""
Parameters
----------
config: ConfigCommandProfile
Profile configuration information
gpus: List of GPUDevices
models: List of ModelProfileSpec
List of models to profile
client: TritonClient
model_variant_name_manager: ModelVariantNameManager
search_parameters: SearchParameters
The object that handles the user's configuration search parameters
"""
self._config = config
self._client = client
self._gpus = gpus
self._models = models
self._search_parameters = search_parameters
self._metrics_manager = metrics_manager

self._model_variant_name_manager = model_variant_name_manager

self._triton_env = BruteRunConfigGenerator.determine_triton_server_env(models)

self._num_models = len(models)
self._last_measurement: Optional[RunConfigMeasurement] = None

self._c_api_mode = config.triton_launch_mode == "c_api"

self._done = False

if seed is not None:
self._sampler = optuna.samplers.TPESampler(seed=seed)
else:
self._sampler = optuna.samplers.TPESampler()

self._study = optuna.create_study(
study_name=self._models[0].model_name(),
direction="maximize",
sampler=self._sampler,
)

def _is_done(self) -> bool:
return self._done

def set_last_results(
self, measurements: List[Optional[RunConfigMeasurement]]
) -> None:
# TODO: TMA-1927: Add support for multi-model
if measurements[0] is not None:
self._last_measurement = measurements[0]
else:
self._last_measurement = None

def get_configs(self) -> Generator[RunConfig, None, None]:
"""
Returns
-------
RunConfig
The next RunConfig generated by this class
"""
default_run_config = self._create_default_run_config()
yield default_run_config
self._default_measurement = self._last_measurement

# TODO: TMA-1884: Need a default + config option for trial number
n_trials = 20
# TODO: TMA-1885: Need an early exit strategy
for _ in range(n_trials):
trial = self._study.ask()
self._create_trial_objectives(trial)
run_config = self._create_objective_based_run_config()
yield run_config
score = self._calculate_score()
self._study.tell(trial, score)

def _create_trial_objectives(self, trial) -> None:
# TODO: TMA-1925: Use SearchParameters here
self._instance_count = trial.suggest_int("instance_count", 1, 8)
self._batch_size = int(2 ** trial.suggest_int("batch_size", 1, 10))

# TODO: TMA-1884: Need an option to choose between the concurrency formula and Optuna searching
self._concurrency = 2 * self._instance_count * self._batch_size
if self._concurrency > 1024:
self._concurrency = 1024

def _create_objective_based_run_config(self) -> RunConfig:
param_combo = self._create_parameter_combo()

# TODO: TMA-1927: Add support for multi-model
run_config = RunConfig(self._triton_env)

model_config_variant = BaseModelConfigGenerator.make_model_config_variant(
param_combo=param_combo,
model=self._models[0],
model_variant_name_manager=self._model_variant_name_manager,
c_api_mode=self._c_api_mode,
)

# TODO: TMA-1927: Add support for multi-model
model_run_config = self._create_model_run_config(
model=self._models[0],
model_config_variant=model_config_variant,
)

run_config.add_model_run_config(model_run_config=model_run_config)

return run_config

def _create_parameter_combo(self) -> Dict[str, Any]:
# TODO: TMA-1925: Use SearchParameters here
param_combo: Dict["str", Any] = {}
param_combo["dynamic_batching"] = {}

# TODO: TMA-1927: Add support for multi-model
kind = "KIND_CPU" if self._models[0].cpu_only() else "KIND_GPU"
param_combo["instance_group"] = [
{
"count": self._instance_count,
"kind": kind,
}
]

param_combo["max_batch_size"] = self._batch_size

return param_combo

def _calculate_score(self) -> float:
if self._last_measurement:
score = self._default_measurement.compare_measurements( # type: ignore
self._last_measurement
)
else:
# TODO: TMA-1927: Figure out the correct value for this (and make it a constant)
score = -1

return score

def _create_default_run_config(self) -> RunConfig:
default_run_config = RunConfig(self._triton_env)
# TODO: TMA-1927: Add support for multi-model
default_model_run_config = self._create_default_model_run_config(
self._models[0]
)
default_run_config.add_model_run_config(default_model_run_config)

return default_run_config

def _create_default_model_run_config(
self, model: ModelProfileSpec
) -> ModelRunConfig:
default_model_config_variant = (
BaseModelConfigGenerator.make_model_config_variant(
param_combo={},
model=model,
model_variant_name_manager=self._model_variant_name_manager,
c_api_mode=self._c_api_mode,
)
)

default_perf_analyzer_config = self._create_default_perf_analyzer_config(
model, default_model_config_variant.model_config
)

default_model_run_config = ModelRunConfig(
model.model_name(),
default_model_config_variant,
default_perf_analyzer_config,
)

return default_model_run_config

def _create_default_perf_analyzer_config(
self, model: ModelProfileSpec, model_config: ModelConfig
) -> PerfAnalyzerConfig:
default_perf_analyzer_config = PerfAnalyzerConfig()
default_perf_analyzer_config.update_config_from_profile_config(
model_config.get_field("name"), self._config
)

default_concurrency = self._calculate_default_concurrency(model_config)

perf_config_params = {
"batch-size": DEFAULT_BATCH_SIZES,
"concurrency-range": default_concurrency,
}
default_perf_analyzer_config.update_config(perf_config_params)

default_perf_analyzer_config.update_config(model.perf_analyzer_flags())

return default_perf_analyzer_config

def _calculate_default_concurrency(self, model_config: ModelConfig) -> int:
default_max_batch_size = model_config.max_batch_size()
default_instance_count = model_config.instance_group_count(
system_gpu_count=len(self._gpus)
)
default_concurrency = 2 * default_max_batch_size * default_instance_count

return default_concurrency

def _create_model_run_config(
self,
model: ModelProfileSpec,
model_config_variant: ModelConfigVariant,
) -> ModelRunConfig:
perf_analyzer_config = self._create_perf_analyzer_config(
model.model_name(), model, self._concurrency
)
model_run_config = ModelRunConfig(
model.model_name(), model_config_variant, perf_analyzer_config
)

return model_run_config

def _create_perf_analyzer_config(
self,
model_name: str,
model: ModelProfileSpec,
concurrency: int,
) -> PerfAnalyzerConfig:
perf_analyzer_config = PerfAnalyzerConfig()

perf_analyzer_config.update_config_from_profile_config(model_name, self._config)

perf_config_params = {"batch-size": 1, "concurrency-range": concurrency}
perf_analyzer_config.update_config(perf_config_params)

perf_analyzer_config.update_config(model.perf_analyzer_flags())
return perf_analyzer_config

def _print_debug_logs(
self, measurements: List[Union[RunConfigMeasurement, None]]
) -> None:
# TODO: TMA-1928
NotImplemented

Check notice (Code scanning / CodeQL): This statement has no effect.
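
For context on the loop in get_configs(), below is a minimal standalone sketch of the Optuna ask/tell interface this generator is built around. It mirrors the search ranges from _create_trial_objectives() (instance_count 1 to 8, batch_size as a power of two up to 1024) and the hard-coded 20 trials, but the score is a stand-in: in the generator the score comes from comparing the profiled RunConfigMeasurement against the default configuration's measurement. This is an illustration only, not part of the commit.

```python
# Standalone sketch of the Optuna ask/tell loop used by
# OptunaRunConfigGenerator.get_configs(). The score is a placeholder
# for the real profiling measurement.
import optuna

sampler = optuna.samplers.TPESampler(seed=0)
study = optuna.create_study(
    study_name="example_model", direction="maximize", sampler=sampler
)

for _ in range(20):  # TMA-1884: trial count is currently hard-coded
    trial = study.ask()

    # Same search space as _create_trial_objectives()
    instance_count = trial.suggest_int("instance_count", 1, 8)
    max_batch_size = int(2 ** trial.suggest_int("batch_size", 1, 10))
    concurrency = min(2 * instance_count * max_batch_size, 1024)

    # Placeholder objective; the generator instead compares the measured
    # RunConfigMeasurement against the default configuration's result.
    score = float(instance_count * max_batch_size)

    study.tell(trial, score)

print(study.best_params)
```

Running the sketch prints study.best_params, the parameter combination with the highest placeholder score, which is the role the profiled measurements play in the generator.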
1 change: 1 addition & 0 deletions requirements.txt
@@ -20,6 +20,7 @@ grpcio<=1.59.3 # 1.60.0 has broken L0_ssl_grpc. The client gets corrupted and ca
httplib2>=0.19.0
matplotlib>=3.3.4
numba>=0.51.2
optuna>=3.6.1
pdfkit>=0.6.1
prometheus_client>=0.9.0
protobuf
