diff --git a/onnxscript/tools/benchmark/__init__.py b/onnxscript/tools/benchmark/__init__.py index ccc9d81ed..8f1b6f4d3 100644 --- a/onnxscript/tools/benchmark/__init__.py +++ b/onnxscript/tools/benchmark/__init__.py @@ -5,6 +5,9 @@ from onnxscript.tools.benchmark.benchmark_helpers import ( common_export, get_parsed_args, + make_configs, + make_dataframe_from_benchmark_data, + multi_run, run_inference, run_onnx_inference, ) @@ -12,6 +15,9 @@ __all__ = [ "get_parsed_args", "common_export", + "make_configs", + "multi_run", + "make_dataframe_from_benchmark_data", "run_inference", "run_onnx_inference", ] diff --git a/onnxscript/tools/benchmark/benchmark_helpers.py b/onnxscript/tools/benchmark/benchmark_helpers.py index 36d9084fa..e796a8808 100644 --- a/onnxscript/tools/benchmark/benchmark_helpers.py +++ b/onnxscript/tools/benchmark/benchmark_helpers.py @@ -5,6 +5,7 @@ from __future__ import annotations import argparse +import itertools import multiprocessing import os import platform @@ -195,6 +196,52 @@ def run_benchmark( return data +def measure_discrepancies( + expected: list[tuple[Any, ...]], + outputs: list[tuple[Any, ...]], +) -> tuple[float, float]: + """ + Computes the discrepancies. + + Args: + expected: list of outputs coming from a torch model + outputs: list of outputs coming from an onnx model + + Returns: + max absolute errors, max relative errors + """ + + def _flatten(outputs): + flat = [] + for tensor in outputs: + if isinstance(tensor, tuple): + flat.extend(_flatten(tensor)) + else: + flat.append(tensor) + return tuple(flat) + + abs_errs = [] + rel_errs = [] + for torch_outputs_mixed_types, onnx_outputs in zip(expected, outputs): + torch_outputs = _flatten(torch_outputs_mixed_types) + assert len(torch_outputs) == len( + onnx_outputs + ), f"Length mismatch {len(torch_outputs)} != {len(onnx_outputs)}" + for torch_tensor, onnx_tensor in zip(torch_outputs, onnx_outputs): + assert ( + torch_tensor.dtype == onnx_tensor.dtype + ), f"Type mismatch {torch_tensor.dtype} != {onnx_tensor.dtype}" + assert ( + torch_tensor.shape == onnx_tensor.shape + ), f"Type mismatch {torch_tensor.shape} != {onnx_tensor.shape}" + diff = torch_tensor - onnx_tensor + abs_err = float(diff.abs().max()) + rel_err = float((diff.abs() / torch_tensor).max()) + abs_errs.append(abs_err) + rel_errs.append(rel_err) + return max(abs_errs), max(rel_errs) + + def common_export( model: Any, inputs: Sequence[Any], @@ -620,6 +667,7 @@ def run_onnx_inference( repeat: int = 5, verbose: int = 0, ort_optimize: bool = True, + torch_model: Any | None = None, ) -> dict[str, Any]: """ Runs multiple times the same inference with onnxruntime. 
@@ -631,6 +679,7 @@ def run_onnx_inference( repeat: number of iterations to repeat verbose: verbosity ort_optimize: enable, disable onnxruntime optimizations + torch_model: if not empty, measure the discrepancies Returns: statistcs @@ -667,16 +716,26 @@ def run_onnx_inference( print(f"[run_inference] created session in {end}") print(f"[run_inference] start {warmup} warmup iterations") + if torch_model: + expected = [ + torch_model(*example_inputs[i % len(example_inputs)]) for i in range(warmup) + ] + + got = [] iterations = [] begin = time.perf_counter() for i in range(warmup): t0 = time.perf_counter() - wrapped_session.run_dlpack(*example_inputs[i % len(example_inputs)]) + got.append(wrapped_session.run_dlpack(*example_inputs[i % len(example_inputs)])) iterations.append(time.perf_counter() - t0) end = time.perf_counter() - begin stats["warmup"] = warmup stats["warmup_time"] = end / warmup stats["warmup_iter"] = iterations + if torch_model: + abs_err, rel_err = measure_discrepancies(expected, got) + stats["discrepancies_abs"] = abs_err + stats["discrepancies_rel"] = rel_err if verbose: print(f"[run_inference] warmup done in {time.perf_counter() - begin}") @@ -697,3 +756,28 @@ def run_onnx_inference( print(f"[run_inference] measure done in {time.perf_counter() - begin}") return stats + + +def multi_run(kwargs: dict[str, Any]) -> bool: + """Checks if multiple values were sent for one argument.""" + return any(isinstance(v, str) and "," in v for v in kwargs.values()) + + +def make_configs(kwargs: dict[str, Any]) -> list[dict[str, Any]]: + """Creates all the configurations based on the command line arguments.""" + print(kwargs) + args = [] + for k, v in kwargs.items(): + if isinstance(v, str): + args.append([(k, s) for s in v.split(",")]) + else: + args.append([(k, v)]) + configs = list(itertools.product(*args)) + return [dict(c) for c in configs] + + +def make_dataframe_from_benchmark_data(data: list[dict]) -> Any: + """Creates a dataframe from the received data.""" + import pandas + + return pandas.DataFrame(data) diff --git a/onnxscript/tools/benchmark/benchmark_helpers_test.py b/onnxscript/tools/benchmark/benchmark_helpers_test.py new file mode 100644 index 000000000..ec88ffd9e --- /dev/null +++ b/onnxscript/tools/benchmark/benchmark_helpers_test.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import onnxscript.tools.benchmark.benchmark_helpers as bh + + +class BenchmarkHelperTest(unittest.TestCase): + def test_make_configs(self): + value = { + "warmup": 5, + "model": "llama,phi", + "device": "cpu,cuda", + "config": "medium", + "dump_folder": "", + } + self.assertTrue(bh.multi_run(value)) + configs = bh.make_configs(value) + expected = [ + { + "warmup": 5, + "model": "llama", + "device": "cpu", + "config": "medium", + "dump_folder": "", + }, + { + "warmup": 5, + "model": "llama", + "device": "cuda", + "config": "medium", + "dump_folder": "", + }, + { + "warmup": 5, + "model": "phi", + "device": "cpu", + "config": "medium", + "dump_folder": "", + }, + { + "warmup": 5, + "model": "phi", + "device": "cuda", + "config": "medium", + "dump_folder": "", + }, + ] + self.assertEqual(expected, configs) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxscript/tools/benchmark/benchmark_run.py b/onnxscript/tools/benchmark/benchmark_run.py new file mode 100644 index 000000000..abae04b4c --- /dev/null +++ b/onnxscript/tools/benchmark/benchmark_run.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT License. +# pylint: disable=consider-using-with,import-outside-toplevel +from __future__ import annotations + +import multiprocessing +import os +import platform +import re +import subprocess +import sys + + +class BenchmarkError(RuntimeError): + pass + + +def get_machine() -> dict[str, str | int | float | tuple[int, int]]: + """Returns the machine specification.""" + config: dict[str, str | int | float | tuple[int, int]] = dict( + machine=str(platform.machine()), + processor=str(platform.processor()), + version=str(sys.version), + config=int(multiprocessing.cpu_count()), + executable=str(sys.executable), + ) + try: + import torch.cuda + except ImportError: + return config + + config["has_cuda"] = bool(torch.cuda.is_available()) + if config["has_cuda"]: + config["capability"] = torch.cuda.get_device_capability(0) + config["device_name"] = str(torch.cuda.get_device_name(0)) + return config + + +def _cmd_line(script_name: str, **kwargs: dict[str, str | int | float]) -> list[str]: + args = [sys.executable, "-m", script_name] + for k, v in kwargs.items(): + args.append(f"--{k}") + args.append(str(v)) + return args + + +def _extract_metrics(text: str) -> dict[str, str]: + reg = re.compile(":(.*?),(.*.?);") + res = reg.findall(text) + if len(res) == 0: + return {} + return dict(res) + + +def _make_prefix(script_name: str, index: int) -> str: + name = os.path.splitext(script_name)[0] + return f"{name}_dort_c{index}_" + + +def run_benchmark( + script_name: str, + configs: list[dict[str, str | int | float]], + verbose: int = 0, + stop_if_exception: bool = True, + dort_dump: bool = False, +) -> list[dict[str, str | int | float | tuple[int, int]]]: + """ + Runs a script multiple times and extract information from the output + following the pattern ``:,;``. 
+ + :param script_name: python script to run + :param configs: list of execution to do + :param stop_if_exception: stop if one experiment failed, otherwise continue + :param verbose: use tqdm to follow the progress + :param dort_dump: dump onnx file if dort is used + :return: values + """ + if verbose: + try: + from tqdm import tqdm + + loop = tqdm(configs) + except ImportError: + loop = configs + else: + loop = configs + + data: list[dict[str, str | int | float | tuple[int, int]]] = [] + for i, config in enumerate(loop): + cmd = _cmd_line(script_name, **config) + + if dort_dump: + os.environ["ONNXRT_DUMP_PATH"] = _make_prefix(script_name, i) + else: + os.environ["ONNXRT_DUMP_PATH"] = "" + if verbose > 3: + print(f"[run_benchmark] cmd={cmd if isinstance(cmd, str) else ' '.join(cmd)}") + + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + try: + res = p.communicate(timeout=30) + out, err = res + serr = err.decode("utf-8", errors="ignore") + except subprocess.TimeoutExpired as e: + p.kill() + res = p.communicate() + out, err = res + serr = f"{e}\n:timeout,1;{err.decode('utf-8', errors='ignore')}" + sout = out.decode("utf-8", errors="ignore") + + if "ONNXRuntimeError" in serr or "ONNXRuntimeError" in sout: + if stop_if_exception: # pylint: disable=no-else-raise + raise RuntimeError( + f"Unable to continue with config {config} due to the " + f"following error\n{serr}" + f"\n----OUTPUT--\n{sout}" + ) + + metrics = _extract_metrics(sout) + if len(metrics) == 0: + if stop_if_exception: # pylint: disable=no-else-raise + raise BenchmarkError( + f"Unable (2) to continue with config {config}, no metric was " + f"collected.\n--ERROR--\n{serr}\n--OUTPUT--\n{sout}" + ) + else: + metrics = {} + metrics.update(config) + metrics["ERROR"] = serr + metrics["OUTPUT"] = sout + metrics["CMD"] = f"[{' '.join(cmd)}]" + data.append(metrics) # type: ignore[arg-type] + if verbose > 5: + print("--------------- ERROR") + print(serr) + if verbose >= 10: + print("--------------- OUTPUT") + print(sout) + + return data diff --git a/onnxscript/tools/benchmark/export_model.py b/onnxscript/tools/benchmark/export_model.py index 88d40dc27..b6bbc37fd 100644 --- a/onnxscript/tools/benchmark/export_model.py +++ b/onnxscript/tools/benchmark/export_model.py @@ -19,6 +19,10 @@ def main(args=None): This script can be used to quickly evaluate the improvment made by a pattern optimization for a particular model. + If one value contains ",", the script understand multiple commands + must be run. It computes all the possible configurations. + In that case, it produces a csv file (if output_data is not empty) with all the results. + Example with a large phi model:: python -m onnxscript.tools.benchmark.export_model --model phi --device cuda --config large --num_hidden_layers=6 --dtype=float32 --dynamic=0 --verbose=1 --exporter=dynamo @@ -50,130 +54,153 @@ def main(args=None): ), implementation=("eager", "eager or sdpa"), memory_peak=(0, "measure the memory peak during conversion"), + output_data=( + "export_model.csv", + "produces a csv file with the data if multiple configurations are tested", + ), new_args=args, ) - - print("-------------------") - print("[export_model]") - pprint.pprint(kwargs) - print("-------------------") - - # Import is delayed so that help is being display faster (without having to import heavy packages). 
- import onnxscript.tools - import onnxscript.tools.memory_peak - import onnxscript.tools.transformers_models - - print( - f"[export_model] create the model and inputs for {kwargs['model']!r} and config {kwargs['config']!r}" - ) - begin = time.perf_counter() - model, example_inputs, dynamic_shapes = ( - onnxscript.tools.transformers_models.get_model_and_inputs( - warmup=kwargs["warmup"], - repeat=kwargs["repeat"], - model=kwargs["model"], - config=kwargs["config"], - dynamic_shapes=kwargs["dynamic"], - device=kwargs["device"], - num_hidden_layers=kwargs["num_hidden_layers"], - with_mask=kwargs["with_mask"], - implementation=kwargs["implementation"], - dtype=kwargs["dtype"], + if onnxscript.tools.benchmark.multi_run(kwargs): + import onnxscript.tools.benchmark.benchmark_run + + configs = onnxscript.tools.benchmark.make_configs(kwargs) + data = onnxscript.tools.benchmark.benchmark_run.run_benchmark( + "onnxscript.tools.benchmark.export_model", + configs, + kwargs["verbose"], + stop_if_exception=False, ) - ) - print(f"[export_model] model created in {time.perf_counter() - begin}") - if kwargs["dynamic"]: - print(f"[export_model] dynamic_shapes={dynamic_shapes}") - msg = [tuple(i.shape for i in inp) for inp in example_inputs] - print(f"[export_model] input_shapes={msg}") - conversion: dict[str, Any] = {} - memory_stats: dict[str, float] = {} - - if kwargs["exporter"] == "eager": - print("[export_model] start benchmark") - begin = time.perf_counter() - result = onnxscript.tools.benchmark.run_inference( - model, - example_inputs, - warmup=kwargs["warmup"], - repeat=kwargs["repeat"], - verbose=kwargs["verbose"], - ) - print(f"[export_model] benchmark done in {time.perf_counter() - begin}") + if kwargs["verbose"] > 2: + pprint.pprint(data if kwargs["verbose"] > 3 else data[:2]) + if kwargs["output_data"]: + df = onnxscript.tools.benchmark.make_dataframe_from_benchmark_data(data) + df.to_csv(kwargs["output_data"], index=False) + df.to_excel(kwargs["output_data"] + ".xlsx", index=False) + if kwargs["verbose"]: + print(df) else: + print("-------------------") + print("[export_model]") + pprint.pprint(kwargs) + print("-------------------") + + # Import is delayed so that help is being display faster (without having to import heavy packages). 
+ import onnxscript.tools + import onnxscript.tools.memory_peak + import onnxscript.tools.transformers_models + print( - f"[export_model] export to onnx with exporter={kwargs['exporter']!r} " - f"and optimization={kwargs['optimization']!r}" + f"[export_model] create the model and inputs for {kwargs['model']!r} and config {kwargs['config']!r}" ) begin = time.perf_counter() - if kwargs["optimization"]: - m = hashlib.sha256() - m.update(kwargs["optimization"].encode()) - so = m.hexdigest()[:5] - else: - so = "" - name = "_".join( - [ - kwargs["model"], - kwargs["exporter"], - "dynamic" if kwargs["dynamic"] else "static", - kwargs["dtype"].replace("float", "fp"), - kwargs["device"], - kwargs["config"], - f"h{kwargs['num_hidden_layers']}", - so, - ], - ) - filename = f"em_{name}.onnx" - - memory_session = ( - onnxscript.tools.memory_peak.start_spying_on(cuda=kwargs["device"] == "cuda") - if kwargs["memory_peak"] - else None - ) - print(f"[export_model] start memory peak monitoring {memory_session}") - proto = onnxscript.tools.benchmark.common_export( - model=model, - inputs=example_inputs[0], - exporter=kwargs["exporter"], - target_opset=kwargs["target_opset"], - folder=kwargs["dump_folder"], - filename=filename, - dynamic_shapes=dynamic_shapes if kwargs["dynamic"] else None, - optimization=kwargs["optimization"], - verbose=kwargs["verbose"], - stats=conversion, + model, example_inputs, dynamic_shapes = ( + onnxscript.tools.transformers_models.get_model_and_inputs( + warmup=kwargs["warmup"], + repeat=kwargs["repeat"], + model=kwargs["model"], + config=kwargs["config"], + dynamic_shapes=kwargs["dynamic"], + device=kwargs["device"], + num_hidden_layers=kwargs["num_hidden_layers"], + with_mask=kwargs["with_mask"], + implementation=kwargs["implementation"], + dtype=kwargs["dtype"], + ) ) - print(f"[export_model] export to onnx done in {time.perf_counter() - begin}") - if memory_session is not None: - memory_results = memory_session.stop() - print(f"[export_model] ends memory monitoring {memory_results}") - memory_stats = onnxscript.tools.memory_peak.flatten( - memory_results, prefix="memory_" + print(f"[export_model] model created in {time.perf_counter() - begin}") + if kwargs["dynamic"]: + print(f"[export_model] dynamic_shapes={dynamic_shapes}") + msg = [tuple(i.shape for i in inp) for inp in example_inputs] + print(f"[export_model] input_shapes={msg}") + conversion: dict[str, Any] = {} + memory_stats: dict[str, float] = {} + + if kwargs["exporter"] == "eager": + print("[export_model] start benchmark") + begin = time.perf_counter() + result = onnxscript.tools.benchmark.run_inference( + model, + example_inputs, + warmup=kwargs["warmup"], + repeat=kwargs["repeat"], + verbose=kwargs["verbose"], ) + print(f"[export_model] benchmark done in {time.perf_counter() - begin}") else: - memory_stats = {} - - result = onnxscript.tools.benchmark.run_onnx_inference( - proto, - example_inputs, - warmup=kwargs["warmup"], - repeat=kwargs["repeat"], - verbose=kwargs["verbose"], - ort_optimize=kwargs["ort_optimize"], - ) + print( + f"[export_model] export to onnx with exporter={kwargs['exporter']!r} " + f"and optimization={kwargs['optimization']!r}" + ) + begin = time.perf_counter() + if kwargs["optimization"]: + m = hashlib.sha256() + m.update(kwargs["optimization"].encode()) + so = m.hexdigest()[:5] + else: + so = "" + name = "_".join( + [ + kwargs["model"], + kwargs["exporter"], + "dynamic" if kwargs["dynamic"] else "static", + kwargs["dtype"].replace("float", "fp"), + kwargs["device"], + kwargs["config"], + 
f"h{kwargs['num_hidden_layers']}", + so, + ], + ) + filename = f"em_{name}.onnx" - print("[export_model] end") - print("------------------------------") - for k, v in sorted(kwargs.items()): - print(f":{k},{v};") - for k, v in sorted(conversion.items()): - print(f":{k},{v};") - if memory_stats: - for k, v in memory_stats.items(): + memory_session = ( + onnxscript.tools.memory_peak.start_spying_on(cuda=kwargs["device"] == "cuda") + if kwargs["memory_peak"] + else None + ) + print(f"[export_model] start memory peak monitoring {memory_session}") + proto = onnxscript.tools.benchmark.common_export( + model=model, + inputs=example_inputs[0], + exporter=kwargs["exporter"], + target_opset=kwargs["target_opset"], + folder=kwargs["dump_folder"], + filename=filename, + dynamic_shapes=dynamic_shapes if kwargs["dynamic"] else None, + optimization=kwargs["optimization"], + verbose=kwargs["verbose"], + stats=conversion, + ) + print(f"[export_model] export to onnx done in {time.perf_counter() - begin}") + if memory_session is not None: + memory_results = memory_session.stop() + print(f"[export_model] ends memory monitoring {memory_results}") + memory_stats = onnxscript.tools.memory_peak.flatten( + memory_results, prefix="memory_" + ) + else: + memory_stats = {} + + result = onnxscript.tools.benchmark.run_onnx_inference( + proto, + example_inputs, + warmup=kwargs["warmup"], + repeat=kwargs["repeat"], + verbose=kwargs["verbose"], + ort_optimize=kwargs["ort_optimize"], + torch_model=model, + ) + + print("[export_model] end") + print("------------------------------") + for k, v in sorted(kwargs.items()): + print(f":{k},{v};") + for k, v in sorted(conversion.items()): + print(f":{k},{v};") + if memory_stats: + for k, v in memory_stats.items(): + print(f":{k},{v};") + for k, v in sorted(result.items()): print(f":{k},{v};") - for k, v in sorted(result.items()): - print(f":{k},{v};") if __name__ == "__main__": diff --git a/onnxscript/tools/transformers_models/llama_test.py b/onnxscript/tools/transformers_models/llama_test.py index 858e46447..ea4844476 100644 --- a/onnxscript/tools/transformers_models/llama_test.py +++ b/onnxscript/tools/transformers_models/llama_test.py @@ -136,7 +136,7 @@ def test_llama_dort_static(self): expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors) gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors) - torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5) + torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1.0e-5, rtol=1e-5) if __name__ == "__main__": diff --git a/tools/function_rewriter_testing/function_unittest_producer.py b/tools/function_rewriter_testing/function_unittest_producer.py index fc94adaa0..b2d484531 100644 --- a/tools/function_rewriter_testing/function_unittest_producer.py +++ b/tools/function_rewriter_testing/function_unittest_producer.py @@ -16,7 +16,6 @@ import logging import os import sys -from typing import Dict, List, Tuple import numpy as np import onnx @@ -73,14 +72,11 @@ def visit_model(self, model: onnx.ModelProto) -> None: super().visit_model(model) -FunctionMetaDict = Dict[Tuple[str, str], Tuple[List[str], List[str]]] - - class TargetFunctionMetaVisitor(visitor.ProtoVisitorCore): def __init__(self, function_keyword): self.function_keyword = function_keyword # Map from (domain, name) to (actual_input_names, actual_output_names) - self.function_meta: FunctionMetaDict = {} + self.function_meta: dict[tuple[str, str], tuple[list[str], list[str]]] 
= {} self._functions = {} super().__init__()
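
A minimal usage sketch of the new multi_run/make_configs pair, mirroring the expectations encoded in benchmark_helpers_test.py; the argument values are illustrative::

    from onnxscript.tools.benchmark import make_configs, multi_run

    kwargs = {"warmup": 5, "model": "llama,phi", "device": "cpu,cuda"}
    assert multi_run(kwargs)        # True: at least one value contains ","
    configs = make_configs(kwargs)  # cartesian product over the split values
    # -> 4 configurations: (llama, cpu), (llama, cuda), (phi, cpu), (phi, cuda),
    #    each carrying warmup=5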
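
A minimal sketch of measure_discrepancies, the helper that run_onnx_inference now calls during warmup when a torch_model is supplied; the tensors below are illustrative only::

    import torch

    from onnxscript.tools.benchmark.benchmark_helpers import measure_discrepancies

    expected = [(torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]))]  # torch outputs
    got = [(torch.tensor([1.0, 2.1]), torch.tensor([3.0, 4.0]))]       # onnx outputs
    abs_err, rel_err = measure_discrepancies(expected, got)
    # abs_err ~= 0.1 (max |expected - got|), rel_err ~= 0.05 (0.1 / 2.0)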
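
The coupling between export_model.py and run_benchmark relies on a simple stdout protocol: the child process prints one ``:key,value;`` line per metric and the parent recovers them with the private helper _extract_metrics. The sketch below only illustrates the format; parsed values come back as strings::

    from onnxscript.tools.benchmark.benchmark_run import _extract_metrics

    captured_stdout = ":model,llama;\n:device,cpu;\n:warmup_time,0.0123;\n"
    metrics = _extract_metrics(captured_stdout)
    # {'model': 'llama', 'device': 'cpu', 'warmup_time': '0.0123'}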
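
A sketch of the multi-configuration path that export_model.main now takes when one argument holds comma-separated values; the real script forwards its full parsed argument set, only a subset is shown here, and pandas is needed for the CSV step::

    import onnxscript.tools.benchmark as bench
    import onnxscript.tools.benchmark.benchmark_run as bench_run

    kwargs = {"model": "llama,phi", "device": "cpu", "config": "medium", "verbose": 1}
    if bench.multi_run(kwargs):
        configs = bench.make_configs(kwargs)
        data = bench_run.run_benchmark(
            "onnxscript.tools.benchmark.export_model",
            configs,
            verbose=kwargs["verbose"],
            stop_if_exception=False,
        )  # one dict of parsed ":key,value;" metrics per configuration
        df = bench.make_dataframe_from_benchmark_data(data)
        df.to_csv("export_model.csv", index=False)

From the command line this corresponds to an invocation along the lines of::

    python -m onnxscript.tools.benchmark.export_model --model llama,phi --device cpu --config medium --output_data export_model.csv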
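
get_machine, also new in benchmark_run.py, returns a plain dict describing the host; the values below are only an example of its shape::

    from onnxscript.tools.benchmark.benchmark_run import get_machine

    spec = get_machine()
    # e.g. {'machine': 'x86_64', 'processor': '...', 'version': '3.11 ...',
    #       'config': 16,              # cpu_count
    #       'executable': '/usr/bin/python3',
    #       'has_cuda': True, 'capability': (8, 0), 'device_name': '...'}
    # has_cuda is added only when torch is importable; capability and
    # device_name only when CUDA is actually available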