Skip to content

Commit

Permalink
Compare nan
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Apr 26, 2024
1 parent 90c078b commit 4efc5a5
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions onnxscript/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
]

import difflib
import math
from typing import Any, Collection, Sequence

import google.protobuf.message
Expand Down Expand Up @@ -448,6 +449,14 @@ def assert_onnx_proto_equal(
error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}, a_value_i: {a_value_i}, b_value_i: {b_value_i}"
raise AssertionError(error_message) from e
elif a_value_i != b_value_i:
if (
isinstance(a_value_i, float)
and isinstance(b_value_i, float)
and math.isnan(a_value_i)
and math.isnan(b_value_i)
):
# Consider NaNs equal
continue
error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}"
for line in difflib.ndiff(
str(a_value_i).splitlines(), str(b_value_i).splitlines()
Expand All @@ -459,5 +468,13 @@ def assert_onnx_proto_equal(
):
assert_onnx_proto_equal(a_value, b_value)
elif a_value != b_value:
if (
isinstance(a_value, float)
and isinstance(b_value, float)
and math.isnan(a_value)
and math.isnan(b_value)
):
# Consider NaNs equal
continue

Check warning on line 478 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L478

Added line #L478 was not covered by tests
error_message = f"Field {field} not equal. field_a: {a_value}, field_b: {b_value}"
raise AssertionError(error_message)

0 comments on commit 4efc5a5

Please sign in to comment.