[bench] Add code to run multiple command lines and export the result in a csv file #1641

Merged
merged 17 commits on Jul 3, 2024
6 changes: 6 additions & 0 deletions onnxscript/tools/benchmark/__init__.py
@@ -5,13 +5,19 @@
from onnxscript.tools.benchmark.benchmark_helpers import (
common_export,
get_parsed_args,
make_configs,
make_dataframe_from_benchmark_data,
multi_run,
run_inference,
run_onnx_inference,
)

__all__ = [
"get_parsed_args",
"common_export",
"make_configs",
"multi_run",
"make_dataframe_from_benchmark_data",
"run_inference",
"run_onnx_inference",
]
85 changes: 84 additions & 1 deletion onnxscript/tools/benchmark/benchmark_helpers.py
@@ -5,6 +5,7 @@
from __future__ import annotations

import argparse
import itertools
import multiprocessing
import os
import platform
@@ -195,6 +196,51 @@
return data


def measure_discrepancies(
expected: list[tuple[Any, ...]], outputs: list[tuple[Any, ...]]
) -> tuple[float, float]:
"""
Computes the discrepancies.

Args:
expected: list of outputs coming from a torch model
outputs: list of outputs coming from an onnx model

Returns:
maximum absolute error, maximum relative error
"""

def _flatten(outputs):
flat = []

for tensor in outputs:
if isinstance(tensor, tuple):
flat.extend(_flatten(tensor))

else:
flat.append(tensor)
return tuple(flat)


abs_errs = []
rel_errs = []

for torch_outputs_mixed_types, onnx_outputs in zip(expected, outputs):
torch_outputs = _flatten(torch_outputs_mixed_types)
assert len(torch_outputs) == len(

onnx_outputs
), f"Length mismatch {len(torch_outputs)} != {len(onnx_outputs)}"
for torch_tensor, onnx_tensor in zip(torch_outputs, onnx_outputs):
assert (

torch_tensor.dtype == onnx_tensor.dtype
), f"Type mismatch {torch_tensor.dtype} != {onnx_tensor.dtype}"
assert (

torch_tensor.shape == onnx_tensor.shape
), f"Type mismatch {torch_tensor.shape} != {onnx_tensor.shape}"
diff = torch_tensor - onnx_tensor
abs_err = float(diff.abs().max())
rel_err = float((diff.abs() / torch_tensor).max())
abs_errs.append(abs_err)
rel_errs.append(rel_err)
return max(abs_errs), max(rel_errs)

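A quick, hedged illustration (not part of the diff): measure_discrepancies flattens the nested torch outputs, while the onnx outputs are expected to be flat already. The tensor values below are made up.

import torch

from onnxscript.tools.benchmark.benchmark_helpers import measure_discrepancies

# One batch: the torch output is a nested tuple, the onnx output is flat.
expected = [(torch.tensor([1.0, 2.0]), (torch.tensor([3.0]),))]
got = [(torch.tensor([1.001, 2.0]), torch.tensor([3.0]))]

abs_err, rel_err = measure_discrepancies(expected, got)
print(abs_err, rel_err)  # both are about 1e-3 for these values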


def common_export(
model: Any,
inputs: Sequence[Any],
@@ -620,6 +666,7 @@
repeat: int = 5,
verbose: int = 0,
ort_optimize: bool = True,
torch_model: Any | None = None,
) -> dict[str, Any]:
"""
Runs the same inference multiple times with onnxruntime.
@@ -631,6 +678,7 @@
repeat: number of iterations to repeat
verbose: verbosity
ort_optimize: enable, disable onnxruntime optimizations
torch_model: if provided, measures the discrepancies between the torch outputs and the onnx outputs

Returns:
statistics
@@ -667,16 +715,26 @@
print(f"[run_inference] created session in {end}")
print(f"[run_inference] start {warmup} warmup iterations")

if torch_model:
expected = [
torch_model(*example_inputs[i % len(example_inputs)]) for i in range(warmup)
]

got = []

iterations = []
begin = time.perf_counter()
for i in range(warmup):
t0 = time.perf_counter()
wrapped_session.run_dlpack(*example_inputs[i % len(example_inputs)])
got.append(wrapped_session.run_dlpack(*example_inputs[i % len(example_inputs)]))

iterations.append(time.perf_counter() - t0)
end = time.perf_counter() - begin
stats["warmup"] = warmup
stats["warmup_time"] = end / warmup
stats["warmup_iter"] = iterations
if torch_model:
abs_err, rel_err = measure_discrepancies(expected, got)
stats["discrepancies_abs"] = abs_err
stats["discrepancies_rel"] = rel_err


if verbose:
print(f"[run_inference] warmup done in {time.perf_counter() - begin}")
@@ -697,3 +755,28 @@
print(f"[run_inference] measure done in {time.perf_counter() - begin}")

return stats


def multi_run(kwargs: dict[str, Any]) -> bool:
"""Checks if multiple values were sent for one argument."""
return any(isinstance(v, str) and "," in v for v in kwargs.values())


def make_configs(kwargs: dict[str, Any]) -> list[dict[str, Any]]:
"""Creates all the configurations based on the command line arguments."""
print(kwargs)
args = []
for k, v in kwargs.items():
if isinstance(v, str):
args.append([(k, s) for s in v.split(",")])
else:
args.append([(k, v)])
configs = list(itertools.product(*args))
return [dict(c) for c in configs]


def make_dataframe_from_benchmark_data(data: list[dict]) -> Any:
"""Creates a dataframe from the received data."""
import pandas


return pandas.DataFrame(data)

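A minimal sketch of the csv export the PR title refers to, assuming pandas is installed; the benchmark rows below are invented and would normally come from run_benchmark in benchmark_run.py.

import onnxscript.tools.benchmark.benchmark_helpers as bh

# Hypothetical rows: one dict of metrics per configuration.
data = [
    {"model": "llama", "device": "cpu", "warmup_time": 0.12},
    {"model": "llama", "device": "cuda", "warmup_time": 0.03},
]
df = bh.make_dataframe_from_benchmark_data(data)
df.to_csv("benchmark.csv", index=False)  # one row per configuration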
53 changes: 53 additions & 0 deletions onnxscript/tools/benchmark/benchmark_helpers_test.py
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest

import onnxscript.tools.benchmark.benchmark_helpers as bh


class BenchmarkHelperTest(unittest.TestCase):
def test_make_configs(self):
value = {
"warmup": 5,
"model": "llama,phi",
"device": "cpu,cuda",
"config": "medium",
"dump_folder": "",
}
self.assertTrue(bh.multi_run(value))
configs = bh.make_configs(value)
expected = [
{
"warmup": 5,
"model": "llama",
"device": "cpu",
"config": "medium",
"dump_folder": "",
},
{
"warmup": 5,
"model": "llama",
"device": "cuda",
"config": "medium",
"dump_folder": "",
},
{
"warmup": 5,
"model": "phi",
"device": "cpu",
"config": "medium",
"dump_folder": "",
},
{
"warmup": 5,
"model": "phi",
"device": "cuda",
"config": "medium",
"dump_folder": "",
},
]
self.assertEqual(expected, configs)


if __name__ == "__main__":
unittest.main(verbosity=2)

140 changes: 140 additions & 0 deletions onnxscript/tools/benchmark/benchmark_run.py
@@ -0,0 +1,140 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=consider-using-with,import-outside-toplevel
from __future__ import annotations

import multiprocessing
import os
import platform
import re
import subprocess
import sys


class BenchmarkError(RuntimeError):
pass


def get_machine() -> dict[str, str | int | float | tuple[int, int]]:
"""Returns the machine specification."""
cpu: dict[str, str | int | float | tuple[int, int]] = dict(

machine=str(platform.machine()),
processor=str(platform.processor()),
version=str(sys.version),
cpu=int(multiprocessing.cpu_count()),
executable=str(sys.executable),
)
try:
import torch.cuda
except ImportError:
return cpu


cpu["has_cuda"] = bool(torch.cuda.is_available())

if cpu["has_cuda"]:
cpu["capability"] = torch.cuda.get_device_capability(0)
cpu["device_name"] = str(torch.cuda.get_device_name(0))
return cpu



def _cmd_line(script_name: str, **kwargs: dict[str, str | int | float]) -> list[str]:
args = [sys.executable, "-m", script_name]

for k, v in kwargs.items():
args.append(f"--{k}")
args.append(str(v))
return args

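For reference, a hedged example of the command line _cmd_line builds; the script name and options are hypothetical.

# _cmd_line("some_benchmark_script", model="llama", device="cpu")
# -> [sys.executable, "-m", "some_benchmark_script",
#     "--model", "llama", "--device", "cpu"]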


def _extract_metrics(text: str) -> dict[str, str]:
reg = re.compile(":(.*?),(.*.?);")
res = reg.findall(text)

if len(res) == 0:
return {}
return dict(res)

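A short, hedged illustration of the ``:<metric>,<value>;`` convention parsed above. The regular expression does not match across newlines, so each metric is expected on its own line; the metric names below mirror the stats computed in benchmark_helpers.py, but the exact output is up to the benchmarked script.

text = ":warmup_time,0.12;\n:discrepancies_abs,1e-05;"
print(_extract_metrics(text))
# {'warmup_time': '0.12', 'discrepancies_abs': '1e-05'}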


def _make_prefix(script_name: str, index: int) -> str:
name = os.path.splitext(script_name)[0]
return f"{name}_dort_c{index}_"



def run_benchmark(
script_name: str,
configs: list[dict[str, str | int | float]],
verbose: int = 0,
stop_if_exception: bool = True,
dort_dump: bool = False,
) -> list[dict[str, str | int | float | tuple[int, int]]]:
"""
Runs a script multiple times and extracts information from the output
following the pattern ``:<metric>,<value>;``.

:param script_name: python script to run
:param configs: list of configurations to run
:param stop_if_exception: stop if one experiment failed, otherwise continue
:param verbose: use tqdm to follow the progress
:param dort_dump: dump onnx file if dort is used
:return: values
"""
if verbose:
try:
from tqdm import tqdm


loop = tqdm(configs)
except ImportError:
loop = configs

else:
loop = configs


data: list[dict[str, str | int | float | tuple[int, int]]] = []

for i, config in enumerate(loop):
cmd = _cmd_line(script_name, **config)


if dort_dump:
os.environ["ONNXRT_DUMP_PATH"] = _make_prefix(script_name, i)

else:
os.environ["ONNXRT_DUMP_PATH"] = ""

if verbose > 3:
print(f"[run_benchmark] cmd={cmd if isinstance(cmd, str) else ' '.join(cmd)}")


p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
try:
res = p.communicate(timeout=30)
out, err = res
serr = err.decode("utf-8", errors="ignore")
except subprocess.TimeoutExpired as e:
p.kill()
res = p.communicate()
out, err = res
serr = f"{e}\n:timeout,1;{err.decode('utf-8', errors='ignore')}"
sout = out.decode("utf-8", errors="ignore")


if "ONNXRuntimeError" in serr or "ONNXRuntimeError" in sout:
if stop_if_exception: # pylint: disable=no-else-raise
raise RuntimeError(

f"Unable to continue with config {config} due to the "
f"following error\n{serr}"
f"\n----OUTPUT--\n{sout}"
)

metrics = _extract_metrics(sout)

if len(metrics) == 0:
if stop_if_exception: # pylint: disable=no-else-raise
raise BenchmarkError(

f"Unable (2) to continue with config {config}, no metric was "
f"collected.\n--ERROR--\n{serr}\n--OUTPUT--\n{sout}"
)
else:
metrics = {}
metrics.update(config)
metrics["ERROR"] = serr
metrics["OUTPUT"] = sout
metrics["CMD"] = f"[{' '.join(cmd)}]"
data.append(metrics) # type: ignore[arg-type]

if verbose > 5:
print("--------------- ERROR")
print(serr)

if verbose >= 10:
print("--------------- OUTPUT")
print(sout)


return data

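Putting the pieces together, a hedged end-to-end sketch of the workflow this PR enables. Only make_configs, multi_run, run_benchmark and make_dataframe_from_benchmark_data come from the diff above; the benchmarked script name and its options are hypothetical.

from onnxscript.tools.benchmark.benchmark_helpers import (
    make_configs,
    make_dataframe_from_benchmark_data,
    multi_run,
)
from onnxscript.tools.benchmark.benchmark_run import run_benchmark

# Comma-separated values expand into one configuration per combination.
kwargs = {"model": "llama,phi", "device": "cpu", "warmup": 5}
configs = make_configs(kwargs) if multi_run(kwargs) else [kwargs]

# Each configuration becomes one "python -m <script>" invocation whose
# stdout is scanned for ":<metric>,<value>;" entries.
data = run_benchmark("some_benchmark_script", configs, verbose=1)

# Export the collected metrics to a csv file.
make_dataframe_from_benchmark_data(data).to_csv("benchmark.csv", index=False)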