Commit: add discrepancies

Signed-off-by: Xavier Dupre <[email protected]>
xadupre committed Jun 24, 2024
1 parent 8d5b411 commit 5735461
Showing 2 changed files with 59 additions and 1 deletion.
59 changes: 58 additions & 1 deletion onnxscript/tools/benchmark/benchmark_helpers.py
@@ -196,6 +196,51 @@ def run_benchmark(
    return data


def measure_discrepancies(
    expected: list[tuple[Any, ...]], outputs: list[tuple[Any, ...]]
) -> tuple[float, float]:
"""
Computes the discrepancies.
Args:
expected: list of outputs coming from a torch model
outputs: list of outputs coming from an onnx model
Returns:
max absole errors, max relative errors
"""

    def _flatten(outputs):
        # Recursively flattens nested tuples into one flat tuple of tensors.
        flat = []
        for tensor in outputs:
            if isinstance(tensor, tuple):
                flat.extend(_flatten(tensor))
            else:
                flat.append(tensor)
        return tuple(flat)

    abs_errs = []
    rel_errs = []
    for torch_outputs_mixed_types, onnx_outputs in zip(expected, outputs):
        torch_outputs = _flatten(torch_outputs_mixed_types)
        assert len(torch_outputs) == len(
            onnx_outputs
        ), f"Length mismatch {len(torch_outputs)} != {len(onnx_outputs)}"
        for torch_tensor, onnx_tensor in zip(torch_outputs, onnx_outputs):
            assert (
                torch_tensor.dtype == onnx_tensor.dtype
            ), f"Type mismatch {torch_tensor.dtype} != {onnx_tensor.dtype}"
            assert (
                torch_tensor.shape == onnx_tensor.shape
            ), f"Shape mismatch {torch_tensor.shape} != {onnx_tensor.shape}"
            diff = torch_tensor - onnx_tensor
            abs_err = float(diff.abs().max())
            # Use the absolute value of the reference tensor so the relative
            # error stays non-negative for negative reference values.
            rel_err = float((diff.abs() / torch_tensor.abs()).max())
            abs_errs.append(abs_err)
            rel_errs.append(rel_err)
    return max(abs_errs), max(rel_errs)

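As a usage sketch (not part of the commit; tensor values are illustrative): the expected torch outputs may contain nested tuples, which _flatten unnests, while the onnx outputs are assumed to be flat.

import torch

from onnxscript.tools.benchmark.benchmark_helpers import measure_discrepancies

# One warmup run each; the nested tuple on the torch side is flattened internally.
expected = [(torch.tensor([1.0, 2.0]), (torch.tensor([3.0]),))]
got = [(torch.tensor([1.0, 2.001]), torch.tensor([2.999]))]

abs_err, rel_err = measure_discrepancies(expected, got)
print(abs_err, rel_err)  # maximum absolute and relative error over all tensors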


def common_export(
    model: Any,
    inputs: Sequence[Any],

@@ -621,6 +666,7 @@ def run_onnx_inference(
    repeat: int = 5,
    verbose: int = 0,
    ort_optimize: bool = True,
    torch_model: Any | None = None,
) -> dict[str, Any]:
"""
Runs multiple times the same inference with onnxruntime.
Expand All @@ -632,6 +678,7 @@ def run_onnx_inference(
repeat: number of iterations to repeat
verbose: verbosity
ort_optimize: enable, disable onnxruntime optimizations
torch_model: if not empty, measure the discrepancies
Returns:
statistcs
@@ -668,16 +715,26 @@
print(f"[run_inference] created session in {end}")
print(f"[run_inference] start {warmup} warmup iterations")

    if torch_model:
        # Run the torch model on the warmup inputs to collect reference outputs.
        expected = [
            torch_model(*example_inputs[i % len(example_inputs)]) for i in range(warmup)
        ]

    got = []
    iterations = []
    begin = time.perf_counter()
    for i in range(warmup):
        t0 = time.perf_counter()
-        wrapped_session.run_dlpack(*example_inputs[i % len(example_inputs)])
+        got.append(wrapped_session.run_dlpack(*example_inputs[i % len(example_inputs)]))
        iterations.append(time.perf_counter() - t0)
    end = time.perf_counter() - begin
    stats["warmup"] = warmup
    stats["warmup_time"] = end / warmup
    stats["warmup_iter"] = iterations
    if torch_model:
        abs_err, rel_err = measure_discrepancies(expected, got)
        stats["discrepancies_abs"] = abs_err
        stats["discrepancies_rel"] = rel_err

    if verbose:
        print(f"[run_inference] warmup done in {time.perf_counter() - begin}")
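A hedged sketch of calling the extended entry point; the leading parameters of run_onnx_inference are not visible in this diff, so model_proto and example_inputs are assumed names:

stats = run_onnx_inference(
    model_proto,       # assumed name: the onnx model to benchmark
    example_inputs,    # assumed name: list of input tuples
    repeat=5,
    verbose=1,
    ort_optimize=True,
    torch_model=torch_model,  # new in this commit: enables discrepancy measurement
)
print(stats["discrepancies_abs"], stats["discrepancies_rel"])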
1 change: 1 addition & 0 deletions onnxscript/tools/benchmark/export_model.py
@@ -187,6 +187,7 @@ def main(args=None):
repeat=kwargs["repeat"],
verbose=kwargs["verbose"],
ort_optimize=kwargs["ort_optimize"],
torch_model=model,
)

print("[export_model] end")