Skip to content

Commit

Permalink
Refactoring model parameter setting
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-braf committed Oct 13, 2023
1 parent a27e31e commit 0c69e03
Showing 1 changed file with 30 additions and 60 deletions.
90 changes: 30 additions & 60 deletions model_analyzer/config/input/config_command_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import argparse
import logging
import os
from typing import Dict

import numba.cuda
import psutil
Expand Down Expand Up @@ -1502,66 +1503,27 @@ def _autofill_values(self):
}
else:
new_model["parameters"] = {}
if "batch_sizes" in model.parameters():
new_model["parameters"].update(
{"batch_sizes": model.parameters()["batch_sizes"]}
)
else:
new_model["parameters"].update({"batch_sizes": self.batch_sizes})

if "concurrency" in model.parameters():
new_model["parameters"].update(
{"concurrency": model.parameters()["concurrency"]}
)
else:
new_model["parameters"].update({"concurrency": self.concurrency})

if "periodic_concurrency" in model.parameters():
new_model["parameters"].update(
{
"periodic_concurrency": model.parameters()[
"periodic_concurrency"
]
}
)
else:
new_model["parameters"].update(
{"periodic_concurrency": self.periodic_concurrency}
)

if "request_rate" in model.parameters():
new_model["parameters"].update(
{"request_rate": model.parameters()["request_rate"]}
)
else:
new_model["parameters"].update({"request_rate": self.request_rate})

if "request_period" in model.parameters():
new_model["parameters"].update(
{"request_period": model.parameters()["request_period"]}
)
else:
new_model["parameters"].update(
{"request_period": self.request_rate}
)

if "text_input_length" in model.parameters():
new_model["parameters"].update(
{"text_input_length": model.parameters()["text_input_length"]}
)
else:
new_model["parameters"].update(
{"text_input_length": self.text_input_length}
)

if "max_token_count" in model.parameters():
new_model["max_token_count"].update(
{"max_token_count": model.parameters()["max_token_count"]}
)
else:
new_model["parameters"].update(
{"max_token_count": self.text_input_length}
)
new_model["parameters"].update(
self._set_model_parameter(model, "batch_sizes")
)
new_model["parameters"].update(
self._set_model_parameter(model, "concurrency")
)
new_model["parameters"].update(
self._set_model_parameter(model, "periodic_concurrency")
)
new_model["parameters"].update(
self._set_model_parameter(model, "request_rate")
)
new_model["parameters"].update(
self._set_model_parameter(model, "request_period")
)
new_model["parameters"].update(
self._set_model_parameter(model, "max_token_count")
)
new_model["parameters"].update(
self._set_model_parameter(model, "text_input_length")
)

if (
new_model["parameters"]["request_rate"]
Expand Down Expand Up @@ -1604,6 +1566,14 @@ def _autofill_values(self):
new_profile_models[model.model_name()] = new_model
self._fields["profile_models"].set_value(new_profile_models)

def _set_model_parameter(
self, model: ConfigModelProfileSpec, parameter_name: str
) -> Dict:
if parameter_name in model.parameters():
return {parameter_name: model.parameters()[parameter_name]}
else:
return {parameter_name: getattr(self, parameter_name)}

def _using_request_rate(self) -> bool:
if self.request_rate or self.request_rate_search_enable:
return True
Expand Down

0 comments on commit 0c69e03

Please sign in to comment.