diff --git a/onnxscript/testing/__init__.py b/onnxscript/testing/__init__.py index e62a44d9a..bacfe9777 100644 --- a/onnxscript/testing/__init__.py +++ b/onnxscript/testing/__init__.py @@ -8,6 +8,7 @@ ] import difflib +import math from typing import Any, Collection, Sequence import google.protobuf.message @@ -448,6 +449,14 @@ def assert_onnx_proto_equal( error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}, a_value_i: {a_value_i}, b_value_i: {b_value_i}" raise AssertionError(error_message) from e elif a_value_i != b_value_i: + if ( + isinstance(a_value_i, float) + and isinstance(b_value_i, float) + and math.isnan(a_value_i) + and math.isnan(b_value_i) + ): + # Consider NaNs equal + continue error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}" for line in difflib.ndiff( str(a_value_i).splitlines(), str(b_value_i).splitlines() @@ -459,5 +468,13 @@ def assert_onnx_proto_equal( ): assert_onnx_proto_equal(a_value, b_value) elif a_value != b_value: + if ( + isinstance(a_value, float) + and isinstance(b_value, float) + and math.isnan(a_value) + and math.isnan(b_value) + ): + # Consider NaNs equal + continue error_message = f"Field {field} not equal. field_a: {a_value}, field_b: {b_value}" raise AssertionError(error_message)