-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
Signed-off-by: Xavier Dupre <[email protected]>
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
import unittest | ||
|
||
import onnxscript.tools.benchmark.benchmark_helpers as bh | ||
|
||
|
||
class BenchmarkHelperTest(unittest.TestCase): | ||
def test_make_configs(self): | ||
value = { | ||
"warmup": 5, | ||
"model": "llama,phi", | ||
"device": "cpu,cuda", | ||
"config": "medium", | ||
"dump_folder": "", | ||
} | ||
self.assertTrue(bh.multi_run(value)) | ||
configs = bh.make_configs(value) | ||
expected = [ | ||
{ | ||
"warmup": 5, | ||
"model": "llama", | ||
"device": "cpu", | ||
"config": "medium", | ||
"dump_folder": "", | ||
}, | ||
{ | ||
"warmup": 5, | ||
"model": "llama", | ||
"device": "cuda", | ||
"config": "medium", | ||
"dump_folder": "", | ||
}, | ||
{ | ||
"warmup": 5, | ||
"model": "phi", | ||
"device": "cpu", | ||
"config": "medium", | ||
"dump_folder": "", | ||
}, | ||
{ | ||
"warmup": 5, | ||
"model": "phi", | ||
"device": "cuda", | ||
"config": "medium", | ||
"dump_folder": "", | ||
}, | ||
] | ||
self.assertEqual(expected, configs) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main(verbosity=2) | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import multiprocessing | ||
import os | ||
import platform | ||
import re | ||
import subprocess | ||
import sys | ||
|
||
|
||
class BenchmarkError(RuntimeError): | ||
pass | ||
|
||
|
||
def get_machine() -> dict[str, str | int | float | tuple[int, int]]: | ||
"""Returns the machine specification.""" | ||
cpu: dict[str, str | int | float | tuple[int, int]] = dict( | ||
machine=str(platform.machine()), | ||
processor=str(platform.processor()), | ||
version=str(sys.version), | ||
cpu=int(multiprocessing.cpu_count()), | ||
executable=str(sys.executable), | ||
) | ||
try: | ||
import torch.cuda | ||
Check notice Code scanning / lintrunner PYLINT/C0415 Note test
Import outside toplevel (torch.cuda) (import-outside-toplevel)
See import-outside-toplevel. To disable, use # pylint: disable=import-outside-toplevel |
||
except ImportError: | ||
return cpu | ||
|
||
cpu["has_cuda"] = bool(torch.cuda.is_available()) | ||
if cpu["has_cuda"]: | ||
cpu["capability"] = torch.cuda.get_device_capability(0) | ||
cpu["device_name"] = str(torch.cuda.get_device_name(0)) | ||
return cpu | ||
|
||
|
||
def _cmd_line(script_name: str, **kwargs: dict[str, str | int | float]) -> list[str]: | ||
args = [sys.executable, "-m", script_name] | ||
for k, v in kwargs.items(): | ||
args.append(f"--{k}") | ||
args.append(str(v)) | ||
return args | ||
|
||
|
||
def _extract_metrics(text: str) -> dict[str, str]: | ||
reg = re.compile(":(.*?),(.*.?);") | ||
res = reg.findall(text) | ||
if len(res) == 0: | ||
return {} | ||
return dict(res) | ||
|
||
|
||
def _make_prefix(script_name: str, index: int) -> str: | ||
name = os.path.splitext(script_name)[0] | ||
return f"{name}_dort_c{index}_" | ||
|
||
|
||
def run_benchmark( | ||
script_name: str, | ||
configs: list[dict[str, str | int | float]], | ||
verbose: int = 0, | ||
stop_if_exception: bool = True, | ||
dort_dump: bool = False, | ||
) -> list[dict[str, str | int | float | tuple[int, int]]]: | ||
""" | ||
Runs a script multiple times and extract information from the output | ||
following the pattern ``:<metric>,<value>;``. | ||
:param script_name: python script to run | ||
:param configs: list of execution to do | ||
:param stop_if_exception: stop if one experiment failed, otherwise continue | ||
:param verbose: use tqdm to follow the progress | ||
:param dort_dump: dump onnx file if dort is used | ||
:return: values | ||
""" | ||
if verbose: | ||
from tqdm import tqdm | ||
Check notice Code scanning / lintrunner PYLINT/C0415 Note test
Import outside toplevel (tqdm.tqdm) (import-outside-toplevel)
See import-outside-toplevel. To disable, use # pylint: disable=import-outside-toplevel |
||
|
||
loop = tqdm(configs) | ||
else: | ||
loop = configs | ||
|
||
data: list[dict[str, str | int | float | tuple[int, int]]] = [] | ||
for i, config in enumerate(loop): | ||
cmd = _cmd_line(script_name, **config) | ||
|
||
if dort_dump: | ||
os.environ["ONNXRT_DUMP_PATH"] = _make_prefix(script_name, i) | ||
else: | ||
os.environ["ONNXRT_DUMP_PATH"] = "" | ||
if verbose > 3: | ||
print(f"[run_benchmark] cmd={cmd if isinstance(cmd, str) else ' '.join(cmd)}") | ||
p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | ||
Check notice Code scanning / lintrunner PYLINT/R1732 Note test
Consider using 'with' for resource-allocating operations (consider-using-with)
See consider-using-with. To disable, use # pylint: disable=consider-using-with |
||
res = p.communicate() | ||
out, err = res | ||
sout = out.decode("utf-8", errors="ignore") | ||
serr = err.decode("utf-8", errors="ignore") | ||
|
||
if "ONNXRuntimeError" in serr or "ONNXRuntimeError" in sout: | ||
if stop_if_exception: | ||
raise RuntimeError( | ||
f"Unable to continue with config {config} due to the " | ||
f"following error\n{serr}" | ||
f"\n----OUTPUT--\n{sout}" | ||
) | ||
|
||
metrics = _extract_metrics(sout) | ||
if len(metrics) == 0: | ||
if stop_if_exception: | ||
Check notice Code scanning / lintrunner PYLINT/R1720 Note test
Unnecessary "else" after "raise", remove the "else" and de-indent the code inside it (no-else-raise)
See no-else-raise. To disable, use # pylint: disable=no-else-raise |
||
raise BenchmarkError( | ||
f"Unable (2) to continue with config {config}, no metric was " | ||
f"collected.\n--ERROR--\n{serr}\n--OUTPUT--\n{sout}" | ||
) | ||
else: | ||
metrics = {} | ||
metrics.update(config) | ||
metrics["ERROR"] = serr | ||
metrics["OUTPUT"] = sout | ||
metrics["CMD"] = f"[{' '.join(cmd)}]" | ||
data.append(metrics) | ||
Check failure Code scanning / lintrunner MYPY/arg-type Error test
Argument 1 to "append" of "list" has incompatible type "dict[str, str]"; expected "dict[str, str | int | float | tuple[int, int]]"
To disable, use # type: ignore[arg-type]
|
||
if verbose > 5: | ||
print("--------------- ERROR") | ||
print(serr) | ||
if verbose >= 10: | ||
print("--------------- OUTPUT") | ||
print(sout) | ||
|
||
return data | ||