Skip to content

Commit

Permalink
Dump more stats on output mismatch (#2328)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2328

Helps with debugging random numerical issues

Reviewed By: mcr229

Differential Revision: D54210764
  • Loading branch information
digantdesai authored and facebook-github-bot committed Mar 11, 2024
1 parent 0f28206 commit d7aaa85
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions backends/xnnpack/test/tester/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,12 +501,25 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
assert len(model_output) == len(ref_output)

for i in range(len(model_output)):
model = model_output[i]
ref = ref_output[i]
assert torch.allclose(
model_output[i],
ref_output[i],
model,
ref,
atol=atol,
rtol=rtol,
), f" Output {i} does not match reference output. Max difference: {torch.max(torch.abs(model_output[i] - ref_output[i]))}"
), (
f"Output {i} does not match reference output.\n"
f"\tGiven atol: {atol}, rtol: {rtol}.\n"
f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}.\n"
f"\t-- Model vs. Reference --\n"
f"\t Numel: {model.numel()}, {ref.numel()}\n"
f"\tMedian: {model.median()}, {ref.median()}\n"
f"\t Mean: {model.mean()}, {ref.mean()}\n"
f"\t Max: {model.max()}, {ref.max()}\n"
f"\t Min: {model.min()}, {ref.min()}\n"
)

def compare_outputs(self, atol=1e-03, rtol=1e-03, qtol=0):
"""
Expand Down

0 comments on commit d7aaa85

Please sign in to comment.