Skip to content

Commit

Permalink
update threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Mar 6, 2024
1 parent 976faf4 commit d92f250
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ void TestMatMulIntegerToFloat(bool is_matrix_b_constant,

if (std::is_same_v<OType, float>) {
test.AddOutput<float>("Y", {M, N}, Y_data);
test.SetOutputAbsErr("Y", 0.0001f);
test.SetOutputRelErr("Y", 0.02f);
} else {
test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(Y_data));
Expand Down
70 changes: 50 additions & 20 deletions onnxruntime/test/providers/checkers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,50 @@ namespace test {
namespace {

template <typename T>
T get_tolerance(float absolute_tolerance, float relative_tolerance, T expected_value) {
struct DefaultTolerance;

template <>
struct DefaultTolerance<double> {
static constexpr float absolute = 1e-6f;
static constexpr float relative = 1e-5f;
};

template <>
struct DefaultTolerance<float> {
static constexpr float absolute = 1e-5f;
static constexpr float relative = 1e-4f;
};

template <>
struct DefaultTolerance<MLFloat16> {
// The thresholds are estimated with PyTorch script like the following:
// x = torch.rand(1000, 1000)
// absolute = ((x + 1e-6).to(torch.float16) - x).abs().max() * 10
// x[abs(x) < absolute] = absolute
// relative = ((x - x.to(torch.float16)) / x).abs().max() * 2
static constexpr float absolute = 0.0025f;
static constexpr float relative = 0.001f;
};

template <>
struct DefaultTolerance<BFloat16> {
static constexpr float absolute = 0.02f;
static constexpr float relative = 0.01f;
};

template <typename T>
T get_tolerance(float absolute, float relative, T expected_value) {
static_assert(std::is_floating_point<T>::value, "T must be a floating point type");

// The formula is similar to numpy.isclose: https://numpy.org/doc/stable/reference/generated/numpy.isclose.html
return static_cast<T>(absolute_tolerance) + static_cast<T>(relative_tolerance) * std::abs(expected_value);
return static_cast<T>(absolute) + static_cast<T>(relative) * std::abs(expected_value);
}

template <typename T>
template <typename T, typename D> // D is the original data type
T get_tolerance(const ValidateOutputParams& params, T expected_value) {
constexpr float default_absolute_tolerance = 1e-5f;
constexpr float default_relative_tolerance = 1e-4f;
float absolute_tolerance = (params.absolute_error.has_value() ? *(params.absolute_error) : default_absolute_tolerance);
float relative_tolerance = (params.relative_error.has_value() ? *(params.relative_error) : default_relative_tolerance);
return get_tolerance<T>(absolute_tolerance, relative_tolerance, expected_value);
float absolute = (params.absolute_error.has_value() ? *(params.absolute_error) : DefaultTolerance<D>::absolute);
float relative = (params.relative_error.has_value() ? *(params.relative_error) : DefaultTolerance<D>::relative);
return get_tolerance<T>(absolute, relative, expected_value);
}

template <typename T>
Expand Down Expand Up @@ -223,14 +253,14 @@ struct TensorCheck<double> {
} else if (std::isinf(cur_expected[i])) { // Test infinity for equality
EXPECT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i;
} else {
double tolerance = has_tolerance ? get_tolerance<double>(params, cur_expected[i]) : threshold;
double tolerance = has_tolerance ? get_tolerance<double, double>(params, cur_expected[i]) : threshold;
EXPECT_NEAR(cur_expected[i], cur_actual[i], tolerance) << "i:" << i;
}
}
}
};

template <typename TypeToCheck>
template <typename T>
void InternalNumericalCheck(const Tensor& expected,
const Tensor& actual,
const ValidateOutputParams& params,
Expand All @@ -240,16 +270,16 @@ void InternalNumericalCheck(const Tensor& expected,
// deal with rare cases in which order of output data from a kernel MAY be
// undefined
Tensor expected_sorted, actual_sorted;
const TypeToCheck* cur_expected;
const TypeToCheck* cur_actual;
const T* cur_expected;
const T* cur_actual;
auto size = actual.Shape().Size();
if (params.sort_output) {
sort_expected_and_actual_buffers<TypeToCheck>(expected, expected_sorted, actual, actual_sorted);
cur_expected = expected_sorted.Data<TypeToCheck>();
cur_actual = actual_sorted.Data<TypeToCheck>();
sort_expected_and_actual_buffers<T>(expected, expected_sorted, actual, actual_sorted);
cur_expected = expected_sorted.Data<T>();
cur_actual = actual_sorted.Data<T>();
} else {
cur_expected = expected.Data<TypeToCheck>();
cur_actual = actual.Data<TypeToCheck>();
cur_expected = expected.Data<T>();
cur_actual = actual.Data<T>();
}

#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
Expand All @@ -266,7 +296,7 @@ void InternalNumericalCheck(const Tensor& expected,
} else if (std::isinf(cur_expected[i])) { // Test infinity for equality
EXPECT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i;
} else {
TypeToCheck tolerance = has_tolerance ? get_tolerance<TypeToCheck>(params, cur_expected[i]) : threshold;
T tolerance = has_tolerance ? get_tolerance<T, T>(params, cur_expected[i]) : threshold;
EXPECT_NEAR(cur_expected[i], cur_actual[i], tolerance) << "i:" << i;
}
}
Expand Down Expand Up @@ -317,7 +347,7 @@ struct TensorCheck<MLFloat16> {
} else if (std::isinf(f_expected[i])) { // Test infinity for equality
EXPECT_EQ(f_expected[i], f_actual[i]) << "Expected infinity. i:" << i;
} else {
float tolerance = has_tolerance ? get_tolerance<float>(params, f_expected[i]) : threshold;
float tolerance = has_tolerance ? get_tolerance<float, MLFloat16>(params, f_expected[i]) : threshold;
EXPECT_NEAR(f_expected[i], f_actual[i], tolerance) << "i:" << i;
}
}
Expand Down Expand Up @@ -360,7 +390,7 @@ struct TensorCheck<BFloat16> {
EXPECT_EQ(f_expected[i], f_actual[i]) << "Expected infinity. i:" << i;
} else {
float tolerance = has_tolerance
? get_tolerance<float>(params, f_expected[i])
? get_tolerance<float, BFloat16>(params, f_expected[i])
: get_tolerance<float>(abs_threshold, rel_threshold, f_expected[i]);
EXPECT_NEAR(f_expected[i], f_actual[i], tolerance) << "i:" << i;
}
Expand Down

0 comments on commit d92f250

Please sign in to comment.