diff --git a/.clang-tidy b/.clang-tidy index caecc0cb295..f4262fd3c13 100755 --- a/.clang-tidy +++ b/.clang-tidy @@ -115,3 +115,5 @@ CheckOptions: value: UPPER_CASE - key: readability-identifier-naming.MacroDefinitionPrefix value: MIGRAPHX_ + - key: readability-identifier-naming.ConstexprMethodIgnoredRegexp + value: 'quiet_NaN|signaling_NaN' diff --git a/CMakeLists.txt b/CMakeLists.txt index e4a884ffbd0..440bb55b1dc 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,7 +181,6 @@ rocm_enable_clang_tidy( -bugprone-multi-level-implicit-pointer-conversion -bugprone-signed-char-misuse -bugprone-unchecked-optional-access - -bugprone-unused-local-non-trivial-variable # Disable the aliased reserved identifiers -cert-dcl37-c -cert-dcl51-cpp diff --git a/docs/install/installing_with_package.rst b/docs/install/installing_with_package.rst index 0e9b67ff2a9..b2aa21ca741 100644 --- a/docs/install/installing_with_package.rst +++ b/docs/install/installing_with_package.rst @@ -10,7 +10,7 @@ ROCm must be installed before installing MIGraphX. See `ROCm installation for Li Installing MIGraphX using the package installer is sufficient for users who want to use the MIGraphX API. -If you want to develop for MIGraphX and contribute to the source code, see `Building MIGraphX `_ and `Developing for MIGraphX `_ +If you want to develop for MIGraphX and contribute to the source code, see :doc:`Building MIGraphX ` and :doc:`Developing for MIGraphX <../dev/contributing-to-migraphx>`. The package installer will install all the prerequisites needed for MIGraphX. diff --git a/requirements.txt b/requirements.txt index 48ed2212bde..7a9e7ff85c4 100755 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@e454b5d06fc2f099f7de3ee43450e7a6b1efe015 -DBUILD_FAT_LIBROCKCOMPILER=On +ROCm/rocMLIR@99fc9d24714ee7eae75ef8e414df4f2dacd3af16 -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 909c0f6bc26..334c9615f68 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -66,7 +66,7 @@ add_library(migraphx insert_pad.cpp instruction.cpp json.cpp - layout_nhwc.cpp + layout_convolution.cpp lexing.cpp load_save.cpp make_op.cpp @@ -89,7 +89,6 @@ add_library(migraphx propagate_constant.cpp promote_literals.cpp quantization.cpp - quantize_fp16.cpp quantize_int4.cpp quantize_8bits.cpp reduce_dims.cpp @@ -115,6 +114,7 @@ add_library(migraphx split_single_dyn_dim.cpp target.cpp tmp_dir.cpp + truncate_float.cpp value.cpp verify_args.cpp ) @@ -122,6 +122,7 @@ add_library(migraphx if(WIN32) # Due to compilation crashing, we need to use type-erased matchers on Windows. 
target_compile_definitions(migraphx PUBLIC MIGRAPHX_USE_TYPE_ERASED_MATCHERS=1) + target_compile_options(migraphx PUBLIC "-mno-ms-bitfields") endif() configure_file(version.h.in include/migraphx/version.h) diff --git a/src/include/migraphx/bit_cast.hpp b/src/include/migraphx/bit_cast.hpp index 951b34bc340..fc4aab2e3b6 100644 --- a/src/include/migraphx/bit_cast.hpp +++ b/src/include/migraphx/bit_cast.hpp @@ -25,6 +25,7 @@ #if defined(__GNUC__) && !defined(__clang__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wduplicated-branches" #endif #include diff --git a/src/include/migraphx/float8.hpp b/src/include/migraphx/float8.hpp index 445b8ebeb1e..98a4a7b10fa 100644 --- a/src/include/migraphx/float8.hpp +++ b/src/include/migraphx/float8.hpp @@ -42,6 +42,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -379,52 +380,73 @@ class numeric_limits // ================================================================================================= // define numeric limits for the new data type -// NOLINTBEGIN +// NOLINTBEGIN(cert-dcl58-cpp) namespace std { -#define MIGRAPHX_FP8_STD_OVERLOADS(T) \ - inline bool isfinite(T x) { return not x.is_inf() and not x.is_nan(); } \ - inline bool isnan(T x) { return x.is_nan(); } \ - template <> \ - class numeric_limits : public migraphx::fp8::numeric_limits \ - { \ - }; \ - template \ - struct common_type : std::common_type \ - { \ - }; \ - template \ - struct common_type : std::common_type \ - { \ - }; \ - template <> \ - struct common_type \ - { \ - using type = T; \ - }; -MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn) -MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2) -MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz) -MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz) - -// needed to resolve between multiple ambiguous definition from previous templates -#define MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(T, U) \ - template <> \ - struct common_type : std::common_type \ - { \ - }; \ - template <> \ - struct common_type : std::common_type \ - { \ - }; +template +inline bool isfinite(migraphx::fp8::float8 x) +{ + return not x.is_inf() and not x.is_nan(); +} + +template +inline bool isnan(migraphx::fp8::float8 x) +{ + return x.is_nan(); +} + +template +class numeric_limits> + : public migraphx::fp8::numeric_limits> +{ +}; +template +struct common_type, U> : std::common_type +{ +}; +template +struct common_type> : std::common_type +{ +}; +template +struct common_type, migraphx::fp8::float8> +{ + using type = migraphx::fp8::float8; +}; + +template +struct common_type, migraphx::fp8::float8> +{ + using type = float; +}; + +template +struct common_type, + migraphx::fp8::float8> +{ + using type = float; +}; + +template +struct common_type, + migraphx::generic_float> +{ + using type = float; +}; + +template +struct common_type, migraphx::fp8::float8> + : std::common_type +{ +}; + +template +struct common_type, migraphx::generic_float> + : std::common_type +{ +}; -MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fn, migraphx::fp8::fp8e5m2) -MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fn, migraphx::fp8::fp8e4m3fnuz) -MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fn, migraphx::fp8::fp8e5m2fnuz) -MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e5m2, migraphx::fp8::fp8e4m3fnuz) -MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e5m2, 
migraphx::fp8::fp8e5m2fnuz) -MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fnuz, migraphx::fp8::fp8e5m2fnuz) } // namespace std -// NOLINTEND +// NOLINTEND(cert-dcl58-cpp) // ================================================================================================= #endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP diff --git a/src/include/migraphx/generic_float.hpp b/src/include/migraphx/generic_float.hpp new file mode 100644 index 00000000000..a09c19a3a26 --- /dev/null +++ b/src/include/migraphx/generic_float.hpp @@ -0,0 +1,479 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#ifndef MIGRAPHX_GUARD_MIGRAPHX_GENERIC_FLOAT_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_GENERIC_FLOAT_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +constexpr unsigned int all_ones() noexcept +{ + return (1u << N) - 1u; +} + +template +constexpr int countl_zero(T value) +{ + unsigned int r = 0; + for(; value != 0u; value >>= 1u) + r++; + return 8 * sizeof(value) - r; +} + +constexpr std::size_t bit_ceil(std::size_t v) +{ + if(v <= 1) + return 1; + v--; + v |= v >> 1u; + v |= v >> 2u; + v |= v >> 4u; + v |= v >> 8u; + v |= v >> 16u; + v |= v >> 32u; + return v + 1; +} + +constexpr std::size_t integer_divide_ceil(std::size_t x, std::size_t y) +{ + return (x + y - std::size_t{1}) / y; +} + +template +struct unsigned_type +{ +}; + +template <> +struct unsigned_type<1> +{ + using type = std::uint8_t; +}; + +template <> +struct unsigned_type<2> +{ + using type = std::uint16_t; +}; + +template <> +struct unsigned_type<4> +{ + using type = std::uint32_t; +}; + +template <> +struct unsigned_type<8> +{ + using type = std::uint64_t; +}; + +struct float32_parts +{ + unsigned int mantissa : 23; + unsigned int exponent : 8; + unsigned int sign : 1; + + static constexpr unsigned int mantissa_width() { return 23; } + + static constexpr unsigned int max_exponent() { return all_ones<8>(); } + + static constexpr int exponent_bias() { return all_ones<7>(); } + + constexpr float to_float() const noexcept { return migraphx::bit_cast(*this); } +}; + +constexpr float32_parts get_parts(float f) { return migraphx::bit_cast(f); } + +#ifdef _MSC_VER +#pragma pack(push, 1) +#endif +template +struct __attribute__((packed, may_alias)) generic_float +{ + using type = typename 
unsigned_type::type; + + type mantissa : MantissaSize; + type exponent : ExponentSize; + type sign : 1; + + static constexpr int exponent_bias() { return all_ones(); } + + explicit constexpr generic_float(float f = 0.0) noexcept { from_float(get_parts(f)); } + + constexpr generic_float& operator=(float f) noexcept + { + from_float(get_parts(f)); + return *this; + } + + constexpr generic_float operator-() const noexcept + { + generic_float result = *this; + result.sign = not this->sign; + return result; + } + + constexpr generic_float operator+() const noexcept { return *this; } + + constexpr float to_float() const noexcept + { + float32_parts f{}; + f.sign = sign; + + if(exponent == 0) // subnormal fps + { + + if(mantissa == 0) + { + f.exponent = 0; + f.mantissa = 0; + } + else + { + type shift = 0; + f.mantissa = mantissa; + + if(MantissaSize < float32_parts::mantissa_width()) + { + shift = MantissaSize - ((sizeof(type) * 8) - countl_zero(mantissa)); + f.mantissa <<= (shift + 1u); + } + + f.exponent = float32_parts::exponent_bias() - exponent_bias() - shift; + f.mantissa = f.mantissa << (float32_parts::mantissa_width() - MantissaSize); + } + } + else if(exponent == all_ones()) + { + f.mantissa = mantissa << (float32_parts::mantissa_width() - MantissaSize); + f.exponent = float32_parts::max_exponent(); + } + else + { + f.mantissa = mantissa << (float32_parts::mantissa_width() - MantissaSize); + constexpr const int diff = float32_parts::exponent_bias() - exponent_bias(); + f.exponent = int(exponent) + diff; + } + + return f.to_float(); + } + + constexpr void from_float(float32_parts f) noexcept + { + sign = f.sign; + + if(f.exponent == 0) + { + exponent = 0; + mantissa = f.mantissa >> (float32_parts::mantissa_width() - MantissaSize); + } + else if(f.exponent == float32_parts::max_exponent()) + { + exponent = all_ones(); + mantissa = f.mantissa >> (float32_parts::mantissa_width() - MantissaSize); + } + else + { + constexpr const int diff = float32_parts::exponent_bias() - exponent_bias(); + auto e = int(f.exponent) - diff; + + if(e >= static_cast(all_ones())) + { + exponent = all_ones(); + mantissa = 0; + } + else if(e < 1) + { + exponent = 0; + + auto shift = diff - int(f.exponent); + auto shift_amount = shift + (float32_parts::mantissa_width() - MantissaSize) + 1; + + if(shift_amount < (sizeof(unsigned int) * 8)) + { + mantissa = (f.mantissa | (1u << float32_parts::mantissa_width())) >> + (shift + (float32_parts::mantissa_width() - MantissaSize) + 1); + } + else + { + mantissa = 0; + } + } + else + { + exponent = int(f.exponent) - diff; + mantissa = f.mantissa >> (float32_parts::mantissa_width() - MantissaSize); + } + } + + exponent = std::min(exponent, all_ones()); + } + + constexpr bool is_normal() const noexcept + { + return exponent != all_ones() and exponent != 0; + } + + constexpr bool is_inf() const noexcept + { + return exponent == all_ones() and mantissa == 0; + } + + constexpr bool is_nan() const noexcept + { + return exponent == all_ones() and mantissa != 0; + } + + constexpr bool is_finite() const noexcept { return exponent != all_ones(); } + + constexpr operator float() const noexcept { return this->to_float(); } + + static constexpr generic_float infinity() + { + generic_float x{}; + x.exponent = all_ones(); + return x; + } + + static constexpr generic_float snan() + { + generic_float x{}; + x.exponent = all_ones(); + x.mantissa = 1u << (MantissaSize - 2u); + return x; + } + + static constexpr generic_float qnan() + { + generic_float x{}; + x.exponent = all_ones(); + 
x.mantissa = 1u << (MantissaSize - 1u); + return x; + } + + static constexpr generic_float min() + { + generic_float x{}; + x.exponent = 1; + x.mantissa = 0; + return x; + } + + static constexpr generic_float denorm_min() + { + generic_float x{}; + x.exponent = 0; + x.mantissa = 1; + x.sign = 0; + return x; + } + + static constexpr generic_float lowest() + { + generic_float x{}; + x.exponent = all_ones() - 1; + x.mantissa = all_ones(); + x.sign = 1; + return x; + } + + static constexpr generic_float max() + { + generic_float x{}; + x.exponent = all_ones() - 1; + x.mantissa = all_ones(); + x.sign = 0; + return x; + } + + static constexpr generic_float epsilon() + { + generic_float x{1.0}; + x.mantissa++; + return generic_float{x.to_float() - 1.0f}; + } +// NOLINTNEXTLINE +#define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \ + constexpr generic_float& operator op(const generic_float & rhs) \ + { \ + float self = *this; \ + float frhs = rhs; \ + self op frhs; \ + *this = generic_float(self); \ + return *this; \ + } + MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=) + MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=) + MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=) + MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=) +// NOLINTNEXTLINE +#define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \ + friend constexpr generic_float operator op(const generic_float& x, const generic_float& y) \ + { \ + return generic_float(float(x) op float(y)); \ + } + MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*) + MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-) + MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+) + MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/) +// NOLINTNEXTLINE +#define MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(op) \ + friend constexpr bool operator op(const generic_float& x, const generic_float& y) \ + { \ + return float(x) op float(y); \ + } + MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<) + MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<=) + MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>) + MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>=) + + friend constexpr bool operator==(const generic_float& x, const generic_float& y) + { + if(not x.is_finite() or not y.is_finite()) + return false; + + if((x.mantissa == 0 and x.exponent == 0) and (y.mantissa == 0 and y.exponent == 0)) + { + return true; + } + + return std::tie(x.mantissa, x.exponent, x.sign) == std::tie(y.mantissa, y.exponent, y.sign); + } + + friend constexpr bool operator!=(const generic_float& x, const generic_float& y) + { + return not(x == y); + } + + constexpr generic_float& operator++() noexcept + { + *this += generic_float(1.0f); + return *this; + } + + const generic_float operator++(int) noexcept // NOLINT(readability-const-return-type) + { + generic_float temp = *this; + *this += generic_float(1.0f); + return temp; + } +}; +#ifdef _MSC_VER +#pragma pack(pop) +#endif + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +// NOLINTBEGIN(cert-dcl58-cpp) +namespace std { + +template +class numeric_limits> +{ + public: + static constexpr bool has_infinity = true; + static constexpr migraphx::generic_float epsilon() + { + return migraphx::generic_float::epsilon(); + } + + static constexpr migraphx::generic_float quiet_NaN() + { + return migraphx::generic_float::qnan(); + } + + static constexpr migraphx::generic_float signaling_NaN() + { + return migraphx::generic_float::snan(); + } + + static constexpr migraphx::generic_float max() + { + return migraphx::generic_float::max(); + } + + static constexpr migraphx::generic_float min() + { + return migraphx::generic_float::min(); + } + + static constexpr migraphx::generic_float lowest() + { + return migraphx::generic_float::lowest(); + } + + 
static constexpr migraphx::generic_float infinity() + { + return migraphx::generic_float::infinity(); + } + + static constexpr migraphx::generic_float denorm_min() + { + return migraphx::generic_float::denorm_min(); + } +}; + +template +struct common_type, T> : std::common_type +{ +}; + +template +struct common_type> : std::common_type +{ +}; + +template +struct common_type, migraphx::generic_float> +{ + using type = migraphx::generic_float; +}; + +template +struct common_type, migraphx::generic_float> +{ + using type = float; +}; + +} // namespace std +// NOLINTEND(cert-dcl58-cpp) + +#endif // MIGRAPHX_GUARD_MIGRAPHX_GENERIC_FLOAT_HPP diff --git a/src/include/migraphx/half.hpp b/src/include/migraphx/half.hpp index 3296e8c328d..b92942557a4 100644 --- a/src/include/migraphx/half.hpp +++ b/src/include/migraphx/half.hpp @@ -25,14 +25,14 @@ #ifndef MIGRAPHX_GUARD_RTGLIB_HALF_HPP #define MIGRAPHX_GUARD_RTGLIB_HALF_HPP -#include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -using half = half_float::half; +using half = migraphx::generic_float<10, 5>; namespace detail { template @@ -40,14 +40,6 @@ struct deduce { using type = T; }; - -#ifdef HAS_HALF_V1 -template <> -struct deduce -{ - using type = half; -}; -#endif } // namespace detail template @@ -56,60 +48,4 @@ using deduce = typename detail::deduce::type; } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx -namespace std { - -template -struct common_type : std::common_type // NOLINT -{ -}; - -template -struct common_type : std::common_type // NOLINT -{ -}; - -template <> -struct common_type -{ - using type = float; -}; - -template <> -struct common_type -{ - using type = float; -}; - -template <> -struct common_type -{ - using type = float; -}; - -template <> -struct common_type -{ - using type = float; -}; - -template <> -struct common_type -{ - using type = float; -}; - -template <> -struct common_type -{ - using type = float; -}; - -template <> -struct common_type -{ - using type = migraphx::half; -}; - -} // namespace std - #endif diff --git a/src/include/migraphx/layout_nhwc.hpp b/src/include/migraphx/layout_convolution.hpp similarity index 81% rename from src/include/migraphx/layout_nhwc.hpp rename to src/include/migraphx/layout_convolution.hpp index faf097a4d9d..9e45033a8db 100644 --- a/src/include/migraphx/layout_nhwc.hpp +++ b/src/include/migraphx/layout_convolution.hpp @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*/ -#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP -#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP +#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP #include #include @@ -34,14 +34,15 @@ inline namespace MIGRAPHX_INLINE_NS { struct module_pass_manager; /** - * Transform convolutions to nhwc + * Transform the layout of convolutions */ -struct MIGRAPHX_EXPORT layout_nhwc +struct MIGRAPHX_EXPORT layout_convolution { - std::string name() const { return "layout_nhwc"; } + bool channels_last = false; + std::string name() const { return "layout_convolution"; } void apply(module_pass_manager& mpm) const; }; } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx -#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP +#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP diff --git a/src/include/migraphx/op/dequantizelinear.hpp b/src/include/migraphx/op/dequantizelinear.hpp index 3cd2d89fd96..60500b168d6 100644 --- a/src/include/migraphx/op/dequantizelinear.hpp +++ b/src/include/migraphx/op/dequantizelinear.hpp @@ -54,7 +54,7 @@ struct dequantizelinear { MIGRAPHX_THROW("DEQUANTIZELINEAR: Zero point and input should be the same type."); } - return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()}; + return inputs[0].with_lens(inputs[1].type(), inputs[0].lens()); } argument compute(const shape& output_shape, std::vector args) const diff --git a/src/include/migraphx/op/quantizelinear.hpp b/src/include/migraphx/op/quantizelinear.hpp index 77208444bfa..7a0de31cf5a 100644 --- a/src/include/migraphx/op/quantizelinear.hpp +++ b/src/include/migraphx/op/quantizelinear.hpp @@ -63,13 +63,9 @@ struct quantizelinear } if(inputs.size() == 3) { - return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()}; + return inputs[0].with_lens(inputs[2].type(), inputs[0].lens()); } - if(out_type.has_value()) - { - return {out_type.value(), inputs[0].lens(), inputs[0].strides()}; - } - return {shape::uint8_type, inputs[0].lens(), inputs[0].strides()}; + return inputs[0].with_lens(out_type.value_or(shape::uint8_type), inputs[0].lens()); } argument compute(const shape& output_shape, std::vector args) const diff --git a/src/include/migraphx/output_iterator.hpp b/src/include/migraphx/output_iterator.hpp index 7aced4a08a3..e4d670b8537 100644 --- a/src/include/migraphx/output_iterator.hpp +++ b/src/include/migraphx/output_iterator.hpp @@ -72,6 +72,12 @@ auto join_back_inserter(Container& c) [&](const auto& r) { c.insert(c.end(), r.begin(), r.end()); }); } +template +auto push_inserter(Container& c) +{ + return make_function_output_iterator([&](const auto& x) { c.push(x); }); +} } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx + #endif // MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP diff --git a/src/include/migraphx/quantize_fp16.hpp b/src/include/migraphx/truncate_float.hpp similarity index 82% rename from src/include/migraphx/quantize_fp16.hpp rename to src/include/migraphx/truncate_float.hpp index 7233fdf2e2e..426a445c02a 100644 --- a/src/include/migraphx/quantize_fp16.hpp +++ b/src/include/migraphx/truncate_float.hpp @@ -21,12 +21,13 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE.
*/ -#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP -#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP +#ifndef MIGRAPHX_GUARD_RTGLIB_TRUNCATE_FLOAT_HPP +#define MIGRAPHX_GUARD_RTGLIB_TRUNCATE_FLOAT_HPP #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -35,12 +36,13 @@ struct program; struct module; /** - * quantize a program to fp16 + * Truncate a program's floating point operations to a lower precision floating point type */ -struct MIGRAPHX_EXPORT quantize_fp16_pass +struct MIGRAPHX_EXPORT truncate_float_pass { std::vector ins_names = {"all"}; - std::string name() const { return "quantize_fp16"; } + shape::type_t float_type = shape::float_type; + std::string name() const { return "truncate_float"; } void apply(module& m) const; }; diff --git a/src/instruction.cpp b/src/instruction.cpp index 47bea70379e..235c551d709 100644 --- a/src/instruction.cpp +++ b/src/instruction.cpp @@ -26,7 +26,8 @@ #include #include #include -#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -58,22 +59,43 @@ instruction::instruction(literal l) { } +struct replace_shape_order +{ + instruction_ref start; + + std::size_t location(instruction_ref x) const { return std::distance(start, x); } + + bool operator()(instruction_ref x, instruction_ref y) const + { + return location(x) > location(y); + } +}; + void instruction::replace(const shape& r) { if(r != result) { result = r; - std::deque q(output.begin(), output.end()); + if(output.empty()) + { + return; + } + auto start = std::find_if(output.front()->inputs().begin(), + output.front()->inputs().end(), + [&](instruction_ref x) { return this == as_address(x); }); + assert(as_address(*start) == this); + std::priority_queue, replace_shape_order> q( + output.begin(), output.end(), replace_shape_order{*start}); while(not q.empty()) { - instruction_ref ins = q.front(); - q.pop_front(); + instruction_ref ins = q.top(); + q.pop(); assert(ins->name() == "@return" or ins->name().front() != '@'); shape new_r = compute_shape(ins->op, ins->arguments, ins->module_args); if(new_r != ins->result) { ins->result = new_r; - std::copy(ins->output.begin(), ins->output.end(), std::back_inserter(q)); + std::copy(ins->output.begin(), ins->output.end(), migraphx::push_inserter(q)); } } } diff --git a/src/layout_nhwc.cpp b/src/layout_convolution.cpp similarity index 62% rename from src/layout_nhwc.cpp rename to src/layout_convolution.cpp index 9d2a0083a34..83acb839ce6 100644 --- a/src/layout_nhwc.cpp +++ b/src/layout_convolution.cpp @@ -21,7 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE.
*/ -#include +#include #include #include #include @@ -32,49 +32,61 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -template -std::vector find_lasts(const module& m, Predicate pred) +namespace { +std::vector get_permutation(instruction_ref ins, const layout_convolution& lc) { - std::vector result; - fix([&](auto self, auto ins) { - if(pred(ins)) - { - result.push_back(ins); - return; - } - for(auto input : ins->inputs()) - self(input); - })(std::prev(m.end())); - return result; + if(lc.channels_last) + { + std::vector perm(ins->get_shape().ndim()); + std::iota(perm.begin() + 1, perm.end() - 1, 2); + perm.back() = 1; + return perm; + } + return find_permutation(ins->inputs().front()->get_shape()); +} + +bool skip_layout(const shape& s) +{ + return s.ndim() == 1 or s.dynamic() or s.type() == shape::tuple_type; } void preserve_output_layout(module& m) { auto last = std::prev(m.end()); - std::vector outputs; if(last->name() == "@return") - outputs = last->inputs(); - else - outputs = {last}; - - for(auto output : outputs) { - auto permutation = find_permutation(output->get_shape()); - auto layout = m.insert_instruction( - std::next(output), make_op("layout", {{"permutation", permutation}}), output); - m.replace_instruction(output, layout); + std::vector outputs; + std::transform(last->inputs().begin(), + last->inputs().end(), + std::back_inserter(outputs), + [&](instruction_ref ins) { + if(skip_layout(ins->get_shape())) + return ins; + auto permutation = find_permutation(ins->get_shape()); + return m.insert_instruction( + last, make_op("layout", {{"permutation", permutation}}), ins); + }); + m.replace_return(outputs); + } + else if(not skip_layout(last->get_shape())) + { + auto permutation = find_permutation(last->get_shape()); + m.add_instruction(make_op("layout", {{"permutation", permutation}}), last); } } -void transform_convolutions(module& m) +void transform_convolutions(module& m, const layout_convolution& lc) { for(auto ins : iterator_for(m)) { - if(ins->name() != "convolution") + if(not contains({"convolution", "quant_convolution"}, ins->name())) + continue; + if(ins->get_shape().dynamic()) continue; if(ins->get_shape().lens().size() != 4) continue; @@ -82,8 +94,9 @@ void transform_convolutions(module& m) if(v.at("group").to() > 1) continue; auto args = ins->inputs(); + auto perm = get_permutation(ins, lc); std::transform(args.begin(), args.end(), args.begin(), [&](const auto& i) { - return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i); + return m.insert_instruction(ins, make_op("layout", {{"permutation", perm}}), i); }); auto conv = m.insert_instruction(ins, ins->get_operator(), args); auto c = m.insert_instruction(ins, make_op("contiguous"), conv); @@ -102,11 +115,12 @@ void remove_layout(module& m) m.replace_instruction(ins, ins->inputs().front()); } } +} // namespace -void layout_nhwc::apply(module_pass_manager& mpm) const +void layout_convolution::apply(module_pass_manager& mpm) const { preserve_output_layout(mpm.get_module()); - transform_convolutions(mpm.get_module()); + transform_convolutions(mpm.get_module(), *this); mpm.run_pass(dead_code_elimination{}); mpm.run_pass(eliminate_contiguous{"contiguous"}); mpm.run_pass(dead_code_elimination{}); diff --git a/src/module.cpp b/src/module.cpp index 44148fd73fa..7e02478b385 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -355,7 +355,6 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref { impl->changed.notify(); 
assert(has_instruction(ins)); - assert(has_instruction(rep)); assert(ins != rep); if(ins == std::prev(this->end())) @@ -541,7 +540,6 @@ instruction_ref module::insert_parameter(instruction_ref ins, std::string name, instruction_ref module::replace_return(std::vector args) { impl->changed.notify(); - assert(std::all_of(args.begin(), args.end(), [&](auto ins) { return has_instruction(ins); })); auto last = std::prev(this->end()); // If there is no return then add a return if(last->name() != "@return") @@ -1124,7 +1122,7 @@ void module::debug_print(instruction_ref ins, std::cout << "Instruction not part of module" << std::endl; return; } - std::stringstream ss; + names = this->print( [&](auto x, auto ins_names) { if(x == ins) diff --git a/src/program.cpp b/src/program.cpp index cac833803b3..b39f3936371 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -1005,7 +1005,6 @@ void program::debug_print(instruction_ref ins) const return; } - std::stringstream ss; this->print(names, [&](auto x, auto ins_names) { if(x == ins) { diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 04daa5e35a3..9d05e32e67d 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -48,7 +48,7 @@ #include #endif -using half = half_float::half; +using half = migraphx::half; namespace py = pybind11; #ifdef __clang__ diff --git a/src/quantization.cpp b/src/quantization.cpp index a9b47d1d503..7e02ae66685 100644 --- a/src/quantization.cpp +++ b/src/quantization.cpp @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include @@ -69,7 +69,7 @@ void quantize_fp16(program& prog, const std::vector& ins_names) run_passes(prog, {normalize_ops{}, optimize_module{{"quantizelinear", "dequantizelinear"}}, - quantize_fp16_pass{ins_names}, + truncate_float_pass{ins_names, shape::half_type}, optimize_module{{"quantizelinear", "dequantizelinear"}}}, quant_tracer()); } diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 3fb781a0f8d..5eab6ab392b 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -416,6 +416,28 @@ void remove_qdq_pairs(module& m) } } +void remove_zero_point(module& m) +{ + for(auto ins : iterator_for(m)) + { + if(ins->name() != "dequantizelinear") + continue; + if(ins->inputs().size() != 3) + continue; + auto zp = ins->inputs().at(2); + if(not zp->can_eval()) + continue; + auto a = zp->eval(); + bool is_zero = false; + a.visit([&](auto t) { + is_zero = std::all_of(t.begin(), t.end(), [](auto x) { return float_equal(x, 0); }); + }); + if(not is_zero) + continue; + m.replace_instruction(ins, ins->get_operator(), ins->inputs().at(0), ins->inputs().at(1)); + } +} + void add_int4_pack_unpack_pair(module& m) { for(auto ins : iterator_for(m)) @@ -446,6 +468,8 @@ void simplify_qdq::apply(module& m) const remove_qdq_pairs(m); migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); match::find_matches(m, match_qlinear_reused{}); + migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); + remove_zero_point(m); } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/cpu/target.cpp b/src/targets/cpu/target.cpp index 6e4e4051a80..e148aa5b6f3 100644 --- a/src/targets/cpu/target.cpp +++ b/src/targets/cpu/target.cpp @@ -33,7 +33,6 @@ #include #include #include -#include #include #include #include diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp index 39ca6b3f4e1..3a66249a88a 100644 --- a/src/targets/gpu/hip_gemm_impl.cpp +++ b/src/targets/gpu/hip_gemm_impl.cpp @@ -70,8 +70,14 @@ hipDataType 
get_type_hipblas(shape::type_t type) { case shape::int32_type: return HIP_R_32I; case shape::uint32_type: return HIP_R_32U; case shape::fp8e4m3fnuz_type: return HIP_R_8F_E4M3_FNUZ; +// TODO: remove this preprocessor conditional once the default HIP version provides these types +#ifdef ROCM_USE_FLOAT8 + case shape::fp8e4m3fn_type: return HIP_R_8F_E4M3; + case shape::fp8e5m2_type: return HIP_R_8F_E5M2; +#else case shape::fp8e4m3fn_type: case shape::fp8e5m2_type: +#endif case shape::tuple_type: case shape::bool_type: case shape::uint16_type: diff --git a/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp b/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp index cdedd9cfb07..8186767289f 100644 --- a/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp +++ b/src/targets/gpu/include/migraphx/gpu/code_object_op.hpp @@ -55,7 +55,8 @@ struct code_object_op f(self.global, "global"), f(self.local, "local"), f(self.expected_inputs, "expected_inputs"), - f(self.output, "output")); + f(self.output, "output"), + f(self.output_arg, "output_arg")); } value attributes() const { return {{"group", group()}}; } @@ -83,6 +84,8 @@ struct code_object_op os << "symbol_name=" << op.symbol_name << ","; os << "global=" << op.global << ","; os << "local=" << op.local << ","; + if(op.output_arg != -1) + os << "output_arg=" << op.output_arg << ","; os << "]"; return os; } diff --git a/src/targets/gpu/jit/mlir.cpp b/src/targets/gpu/jit/mlir.cpp index 88b1594bc90..4893743c2bc 100644 --- a/src/targets/gpu/jit/mlir.cpp +++ b/src/targets/gpu/jit/mlir.cpp @@ -125,14 +125,56 @@ struct mlir_compiler : compiler return {std::vector{mco.cop}, [=](module& m, instruction_ref ins, const std::vector& ops) { std::vector inputs = ins->inputs(); + + // Tuple inputs not supported + assert(std::all_of(inputs.begin(), inputs.end() - 1, [](auto i) { + return i->get_shape().sub_shapes().empty(); + })); + + // Multiple output case (allocate ins will give a tuple) + std::vector flat_inputs(inputs); + bool multi_out = not flat_inputs.back()->get_shape().sub_shapes().empty(); + if(multi_out) + { + auto allocs = flat_inputs.back(); + flat_inputs.pop_back(); + auto sub_shape_idx = range(allocs->get_shape().sub_shapes().size()); + std::transform(sub_shape_idx.begin(), + sub_shape_idx.end(), + std::back_inserter(flat_inputs), + [&](int i) { + return m.insert_instruction( + ins, + migraphx::make_op("get_tuple_elem", {{"index", i}}), + allocs); + }); + } + std::vector tuple_replacements; + for(const auto i : range(mco.prefill_indices.size())) { auto prefilled_ins = m.insert_instruction( ins, migraphx::make_op("hip::fill", {{"value", mco.prefill_values[i]}}), - inputs[mco.prefill_indices[i]]); - replace(inputs, inputs[mco.prefill_indices[i]], prefilled_ins); + flat_inputs[mco.prefill_indices[i]]); + if(not multi_out or mco.prefill_indices[i] < inputs.size() - 1) + { + replace(inputs, inputs[mco.prefill_indices[i]], prefilled_ins); + } + else + { + tuple_replacements.push_back(prefilled_ins); + } } + + if(multi_out and not tuple_replacements.empty()) + { + // Add identity to make sure fill operations happen before kernel call + tuple_replacements.insert(tuple_replacements.begin(), inputs.back()); + inputs.back() = m.insert_instruction( + ins, migraphx::make_op("identity"), tuple_replacements); + } + + auto mlir = insert_mlir(m, ins, any_cast(ops.front()), inputs); return m.replace_instruction(ins, mlir); }, @@ -212,7 +254,7 @@ struct mlir_compiler : compiler const operation&, bool exhaustive) const { - static const auto mxr_loc =
string_value_of(MIGRAPHX_MLIR_DUMP_TO_MXR{}); + static const auto mxr_loc = string_value_of(MIGRAPHX_MLIR_DUMP_TO_MXR{}); static const auto mlir_loc = string_value_of(MIGRAPHX_MLIR_DUMP{}); auto shapes = to_shapes(ins->inputs()); diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index c8fda0d9b7d..a0edac5eb17 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -35,7 +35,7 @@ #include #include #include -#include +#include #include #include #include @@ -82,6 +82,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) #ifndef _WIN32 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPBLASLT_GEMM) std::vector target::get_passes(migraphx::context& gctx, const compile_options& options) const { @@ -130,9 +131,12 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_fp8e4m3fnuz_ops.insert("argmin"); std::set unsupported_fp8ocp_ops = {}; - // TODO update with hipBLASLt support - unsupported_fp8ocp_ops.insert("dot"); - unsupported_fp8ocp_ops.insert("quant_dot"); + // TODO: remove this when the flag is removed + if(not enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{})) + { + unsupported_fp8ocp_ops.insert("dot"); + unsupported_fp8ocp_ops.insert("quant_dot"); + } #if MIGRAPHX_USE_MIOPEN // MIOpen doesn't have support for fp8 pooling yet. unsupported_fp8ocp_ops.insert("pooling"); @@ -141,6 +145,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti { unsupported_fp8ocp_ops.insert("convolution"); unsupported_fp8ocp_ops.insert("quant_convolution"); + unsupported_fp8ocp_ops.insert("dot"); + unsupported_fp8ocp_ops.insert("quant_dot"); } // add all device kernels unsupported_fp8ocp_ops.insert("logsoftmax"); @@ -183,7 +189,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, rewrite_gelu{options.fast_math}, optimize_module{}, - enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}), + layout_convolution{.channels_last = enabled(MIGRAPHX_ENABLE_NHWC{})}, dead_code_elimination{}, prefuse_ops{}, dead_code_elimination{}, diff --git a/src/quantize_fp16.cpp b/src/truncate_float.cpp similarity index 90% rename from src/quantize_fp16.cpp rename to src/truncate_float.cpp index 2e7e9f00a9e..15f807684d3 100644 --- a/src/quantize_fp16.cpp +++ b/src/truncate_float.cpp @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include #include @@ -35,7 +35,8 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -static void quantize_module(module& m, const std::vector& ins_names) +static void +quantize_module(module& m, const std::vector& ins_names, shape::type_t float_type) { for(auto ins : iterator_for(m)) { @@ -52,14 +53,14 @@ static void quantize_module(module& m, const std::vector& ins_names auto mod_inputs = ins->module_inputs(); auto s = ins->get_shape(); - // Convert each of the inputs that are floating point to fp16 + // Convert each of the inputs that are floating point to float type auto inputs = ins->inputs(); std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { auto input_type = input->get_shape().type(); if(input_type != shape::float_type and input_type != shape::double_type) return input; return m.insert_instruction( - ins, make_op("convert", {{"target_type", shape::half_type}}), input); + ins, make_op("convert", {{"target_type", float_type}}), input); }); // Insert quantized ins @@ -71,13 +72,13 @@ static void quantize_module(module& m, const std::vector& ins_names auto 
outputs = ins->outputs(); std::transform( outputs.begin(), outputs.end(), outputs.begin(), [&](const auto gte_ins) { - auto gte_ins_half = + auto gte_ins_float_type = m.insert_instruction(ins, gte_ins->get_operator(), converted_ins); // Convert back to output type after quantizing auto gte_converted = m.insert_instruction( ins, make_op("convert", {{"target_type", gte_ins->get_shape().type()}}), - gte_ins_half); + gte_ins_float_type); // Replace output instruction return m.replace_instruction(gte_ins, gte_converted); }); @@ -96,7 +97,7 @@ static void quantize_module(module& m, const std::vector& ins_names } } -void quantize_fp16_pass::apply(module& m) const { quantize_module(m, ins_names); } +void truncate_float_pass::apply(module& m) const { quantize_module(m, ins_names, float_type); } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx diff --git a/test/float32.cpp b/test/float32.cpp new file mode 100644 index 00000000000..cf6ad1f12ad --- /dev/null +++ b/test/float32.cpp @@ -0,0 +1,63 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#include +#include +#include +#include +#include "test.hpp" +#include + +#include + +using fp32 = migraphx::generic_float<23, 8>; + +template +bool bit_equal(const T& x, const U& y) +{ + static_assert(sizeof(T) == sizeof(U)); + using type = std::array; + return migraphx::bit_cast(x) == migraphx::bit_cast(y); +} +// NOLINTNEXTLINE +#define MIGRAPHX_CHECK_FLOAT(x, y) \ + CHECK(bit_equal(x, y)); \ + CHECK(bit_equal(x, y.to_float())); \ + CHECK(bit_equal(fp32{x}, y)); \ + CHECK(bit_equal(fp32{x}.to_float(), y.to_float())) + +TEST_CASE(fp32_values_working) +{ + MIGRAPHX_CHECK_FLOAT(1.0f, fp32{1.0f}); + MIGRAPHX_CHECK_FLOAT(-1.0f, fp32{-1.0f}); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits::min(), fp32::min()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits::lowest(), fp32::lowest()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits::max(), fp32::max()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits::epsilon(), fp32::epsilon()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits::denorm_min(), fp32::denorm_min()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits::infinity(), fp32::infinity()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits::quiet_NaN(), fp32::qnan()); + MIGRAPHX_CHECK_FLOAT(std::numeric_limits::signaling_NaN(), fp32::snan()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/half.cpp b/test/half.cpp new file mode 100644 index 00000000000..6b0a5f330a4 --- /dev/null +++ b/test/half.cpp @@ -0,0 +1,1243 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#include +#include +#include +#include +#include "test.hpp" + +#include +#include +#include + +template +bool bit_equal(const T& x, const U& y) +{ + static_assert(sizeof(T) == sizeof(U)); + using type = std::array; + return migraphx::bit_cast(x) == migraphx::bit_cast(y); +} + +TEST_CASE(check_numeric_limits) +{ + CHECK(bit_equal(std::numeric_limits::min(), uint16_t{0x0400})); + CHECK(bit_equal(std::numeric_limits::lowest(), uint16_t{0xfbff})); + CHECK(bit_equal(std::numeric_limits::max(), uint16_t{0x7bff})); + CHECK(bit_equal(std::numeric_limits::epsilon(), uint16_t{0x1400})); + CHECK(bit_equal(std::numeric_limits::denorm_min(), uint16_t{0x0001})); + CHECK(bit_equal(std::numeric_limits::infinity(), uint16_t{0x7c00})); + CHECK(bit_equal(std::numeric_limits::quiet_NaN(), uint16_t{0x7e00})); + CHECK(bit_equal(std::numeric_limits::signaling_NaN(), uint16_t{0x7d00})); +} + +const std::map& half_lut() // NOLINT(readability-function-size) +{ + static const std::map result = { + {0x0000, 0}, + {0x0058, 0.0000052452087402}, + {0x0079, 0.0000072121620178}, + {0x0097, 0.0000090003013611}, + {0x009e, 0.0000094175338745}, + {0x0125, 0.0000174641609192}, + {0x0167, 0.0000213980674744}, + {0x0196, 0.0000241994857788}, + {0x01c4, 0.0000269412994385}, + {0x01c8, 0.0000271797180176}, + {0x0236, 0.0000337362289429}, + {0x029f, 0.0000399947166443}, + {0x02bf, 0.0000419020652771}, + {0x02d6, 0.0000432729721069}, + {0x03a6, 0.0000556707382202}, + {0x03b7, 0.0000566840171814}, + {0x03d4, 0.0000584125518799}, + {0x03d8, 0.000058650970459}, + {0x03ed, 0.0000599026679993}, + {0x0427, 0.0000633597373962}, + {0x0430, 0.0000638961791992}, + {0x0435, 0.0000641942024231}, + {0x0454, 0.0000660419464111}, + {0x047a, 0.0000683069229126}, + {0x04b6, 0.0000718832015991}, + {0x056a, 0.0000826120376587}, + {0x056f, 0.0000829100608826}, + {0x0584, 0.0000841617584229}, + {0x05a1, 0.0000858902931213}, + {0x05a4, 0.0000860691070557}, + {0x05b8, 0.0000872611999512}, + {0x05bc, 0.0000874996185303}, + {0x0635, 0.0000947117805481}, + {0x0641, 0.0000954270362854}, + {0x0686, 0.0000995397567749}, + {0x0694, 0.0001003742218018}, + {0x06db, 0.0001046061515808}, + {0x0725, 0.0001090168952942}, + {0x0777, 0.0001139044761658}, + {0x07b2, 0.0001174211502075}, + {0x0812, 0.0001242160797119}, + {0x082e, 0.0001275539398193}, + {0x0859, 0.00013267993927}, + {0x0895, 0.0001398324966431}, + {0x08af, 0.0001429319381714}, + {0x08fc, 0.0001521110534668}, + {0x092e, 0.0001580715179443}, + {0x0971, 0.0001660585403442}, + {0x0991, 0.0001698732376099}, + {0x09ca, 0.0001766681671143}, + {0x0a63, 0.0001949071884155}, + {0x0a8e, 0.0002000331878662}, + {0x0a93, 0.000200629234314}, + {0x0b2a, 0.0002186298370361}, + {0x0b3a, 0.0002205371856689}, + {0x0b3c, 0.000220775604248}, + {0x0b4e, 0.00022292137146}, + {0x0bae, 0.0002343654632568}, + {0x0bff, 0.0002440214157104}, + {0x0c08, 0.0002460479736328}, + {0x0c56, 0.0002646446228027}, + {0x0c61, 0.0002672672271729}, + {0x0c70, 0.0002708435058594}, + {0x0c7c, 0.0002737045288086}, + {0x0cd8, 0.0002956390380859}, + {0x0cdd, 0.0002968311309814}, + {0x0d05, 0.0003063678741455}, + {0x0d61, 0.0003283023834229}, + {0x0d85, 0.0003368854522705}, + {0x0d8c, 0.0003385543823242}, + {0x0d90, 0.0003395080566406}, + {0x0d9e, 0.000342845916748}, + {0x0da5, 0.0003445148468018}, + {0x0dda, 0.0003571510314941}, + {0x0dde, 0.0003581047058105}, + {0x0df6, 0.000363826751709}, + {0x0eec, 0.000422477722168}, + {0x0f1c, 0.0004339218139648}, + {0x0f99, 0.0004637241363525}, + {0x0fac, 0.0004682540893555}, + {0x0fb0, 0.0004692077636719}, 
+ {0x0ff5, 0.0004856586456299}, + {0x107f, 0.0005488395690918}, + {0x1096, 0.0005598068237305}, + {0x10c8, 0.0005836486816406}, + {0x10e9, 0.0005993843078613}, + {0x110a, 0.000615119934082}, + {0x118a, 0.000676155090332}, + {0x11b5, 0.0006966590881348}, + {0x1293, 0.0008025169372559}, + {0x133f, 0.0008845329284668}, + {0x1342, 0.0008859634399414}, + {0x1372, 0.0009088516235352}, + {0x13cf, 0.000953197479248}, + {0x140c, 0.0009880065917969}, + {0x1437, 0.0010290145874023}, + {0x14a3, 0.0011320114135742}, + {0x14a6, 0.0011348724365234}, + {0x14b2, 0.0011463165283203}, + {0x14ba, 0.0011539459228516}, + {0x14d9, 0.0011835098266602}, + {0x14da, 0.0011844635009766}, + {0x14e7, 0.0011968612670898}, + {0x14fe, 0.0012187957763672}, + {0x1521, 0.0012521743774414}, + {0x153d, 0.0012788772583008}, + {0x15ad, 0.0013856887817383}, + {0x15fd, 0.0014619827270508}, + {0x1649, 0.0015344619750977}, + {0x1658, 0.0015487670898438}, + {0x168a, 0.0015964508056641}, + {0x169d, 0.0016145706176758}, + {0x16b3, 0.0016355514526367}, + {0x16c9, 0.0016565322875977}, + {0x16d1, 0.0016641616821289}, + {0x16e0, 0.001678466796875}, + {0x170a, 0.0017185211181641}, + {0x176d, 0.0018129348754883}, + {0x185b, 0.0021266937255859}, + {0x185e, 0.0021324157714844}, + {0x187e, 0.0021934509277344}, + {0x18ca, 0.0023384094238281}, + {0x18e9, 0.0023975372314453}, + {0x1901, 0.0024433135986328}, + {0x191e, 0.0024986267089844}, + {0x1963, 0.0026302337646484}, + {0x199f, 0.0027446746826172}, + {0x19b2, 0.0027809143066406}, + {0x19d4, 0.0028457641601562}, + {0x1a31, 0.0030231475830078}, + {0x1a4a, 0.0030708312988281}, + {0x1a7a, 0.0031623840332031}, + {0x1ace, 0.0033226013183594}, + {0x1b03, 0.0034236907958984}, + {0x1b22, 0.0034828186035156}, + {0x1d49, 0.0051612854003906}, + {0x1d5a, 0.0052261352539062}, + {0x1d6c, 0.0052947998046875}, + {0x1e02, 0.0058670043945312}, + {0x1e19, 0.0059547424316406}, + {0x1e4c, 0.0061492919921875}, + {0x1eb3, 0.0065422058105469}, + {0x1f32, 0.0070266723632812}, + {0x1f36, 0.0070419311523438}, + {0x1f41, 0.0070838928222656}, + {0x1f7a, 0.0073013305664062}, + {0x1f8d, 0.0073738098144531}, + {0x200b, 0.0078964233398438}, + {0x205f, 0.0085372924804688}, + {0x2060, 0.008544921875}, + {0x2067, 0.0085983276367188}, + {0x20e2, 0.0095367431640625}, + {0x2164, 0.010528564453125}, + {0x22a4, 0.012969970703125}, + {0x22b4, 0.013092041015625}, + {0x22f2, 0.0135650634765625}, + {0x230c, 0.013763427734375}, + {0x2314, 0.013824462890625}, + {0x2341, 0.0141677856445312}, + {0x2356, 0.0143280029296875}, + {0x236e, 0.0145111083984375}, + {0x2371, 0.0145339965820312}, + {0x23cd, 0.0152359008789062}, + {0x2405, 0.0157012939453125}, + {0x24a2, 0.018096923828125}, + {0x24ba, 0.018463134765625}, + {0x24e7, 0.0191497802734375}, + {0x266c, 0.02508544921875}, + {0x26a2, 0.025909423828125}, + {0x26cc, 0.02655029296875}, + {0x26f0, 0.027099609375}, + {0x271e, 0.027801513671875}, + {0x2798, 0.0296630859375}, + {0x287d, 0.035064697265625}, + {0x28a2, 0.03619384765625}, + {0x28ca, 0.03741455078125}, + {0x2933, 0.040618896484375}, + {0x298d, 0.043365478515625}, + {0x299e, 0.04388427734375}, + {0x29c0, 0.044921875}, + {0x29c2, 0.04498291015625}, + {0x29cf, 0.045379638671875}, + {0x29fa, 0.04669189453125}, + {0x2a06, 0.04705810546875}, + {0x2aa5, 0.051910400390625}, + {0x2bcb, 0.060882568359375}, + {0x2c18, 0.06396484375}, + {0x2c65, 0.06866455078125}, + {0x2c66, 0.0687255859375}, + {0x2c93, 0.07147216796875}, + {0x2d24, 0.080322265625}, + {0x2d35, 0.08135986328125}, + {0x2d4c, 0.082763671875}, + {0x2db7, 0.08929443359375}, + {0x2dec, 
0.092529296875}, + {0x2e31, 0.09674072265625}, + {0x2ec9, 0.10601806640625}, + {0x2f85, 0.11749267578125}, + {0x2f94, 0.118408203125}, + {0x302b, 0.1302490234375}, + {0x3094, 0.14306640625}, + {0x3096, 0.143310546875}, + {0x30ae, 0.146240234375}, + {0x30b9, 0.1475830078125}, + {0x310c, 0.15771484375}, + {0x31bd, 0.1793212890625}, + {0x3213, 0.1898193359375}, + {0x325b, 0.1986083984375}, + {0x32aa, 0.208251953125}, + {0x32c0, 0.2109375}, + {0x32d7, 0.2137451171875}, + {0x3391, 0.2364501953125}, + {0x340d, 0.253173828125}, + {0x343d, 0.264892578125}, + {0x3566, 0.33740234375}, + {0x35e6, 0.36865234375}, + {0x35f4, 0.3720703125}, + {0x363b, 0.389404296875}, + {0x363e, 0.39013671875}, + {0x3650, 0.39453125}, + {0x3698, 0.412109375}, + {0x36e7, 0.431396484375}, + {0x36fe, 0.43701171875}, + {0x374a, 0.45556640625}, + {0x3760, 0.4609375}, + {0x3761, 0.461181640625}, + {0x379e, 0.47607421875}, + {0x37cc, 0.4873046875}, + {0x37fd, 0.499267578125}, + {0x3828, 0.51953125}, + {0x3841, 0.53173828125}, + {0x3877, 0.55810546875}, + {0x38a4, 0.580078125}, + {0x38d3, 0.60302734375}, + {0x39b2, 0.7119140625}, + {0x3a60, 0.796875}, + {0x3aa3, 0.82958984375}, + {0x3aa6, 0.8310546875}, + {0x3ac9, 0.84814453125}, + {0x3acf, 0.85107421875}, + {0x3b14, 0.884765625}, + {0x3b42, 0.9072265625}, + {0x3b5c, 0.919921875}, + {0x3bde, 0.9833984375}, + {0x3c67, 1.1005859375}, + {0x3cb5, 1.1767578125}, + {0x3cca, 1.197265625}, + {0x3cdd, 1.2158203125}, + {0x3cfc, 1.24609375}, + {0x3d1f, 1.2802734375}, + {0x3e0c, 1.51171875}, + {0x3e1c, 1.52734375}, + {0x3e5b, 1.5888671875}, + {0x3e7f, 1.6240234375}, + {0x3eae, 1.669921875}, + {0x3efe, 1.748046875}, + {0x3f3e, 1.810546875}, + {0x3f9d, 1.9033203125}, + {0x400a, 2.01953125}, + {0x4070, 2.21875}, + {0x40a0, 2.3125}, + {0x40ce, 2.40234375}, + {0x40e6, 2.44921875}, + {0x410e, 2.52734375}, + {0x4129, 2.580078125}, + {0x4144, 2.6328125}, + {0x41a4, 2.8203125}, + {0x41f3, 2.974609375}, + {0x42f1, 3.470703125}, + {0x438f, 3.779296875}, + {0x43b0, 3.84375}, + {0x43c3, 3.880859375}, + {0x43de, 3.93359375}, + {0x4483, 4.51171875}, + {0x44f8, 4.96875}, + {0x4505, 5.01953125}, + {0x45dd, 5.86328125}, + {0x45f3, 5.94921875}, + {0x460e, 6.0546875}, + {0x46ce, 6.8046875}, + {0x4704, 7.015625}, + {0x471a, 7.1015625}, + {0x475e, 7.3671875}, + {0x4761, 7.37890625}, + {0x479f, 7.62109375}, + {0x47ca, 7.7890625}, + {0x47db, 7.85546875}, + {0x47fc, 7.984375}, + {0x481e, 8.234375}, + {0x4839, 8.4453125}, + {0x483d, 8.4765625}, + {0x48ac, 9.34375}, + {0x48da, 9.703125}, + {0x4919, 10.1953125}, + {0x4950, 10.625}, + {0x4987, 11.0546875}, + {0x49bb, 11.4609375}, + {0x4a14, 12.15625}, + {0x4a92, 13.140625}, + {0x4b25, 14.2890625}, + {0x4b81, 15.0078125}, + {0x4b99, 15.1953125}, + {0x4bbe, 15.484375}, + {0x4bf8, 15.9375}, + {0x4c1f, 16.484375}, + {0x4c49, 17.140625}, + {0x4d21, 20.515625}, + {0x4d4a, 21.15625}, + {0x4d51, 21.265625}, + {0x4de2, 23.53125}, + {0x4e05, 24.078125}, + {0x4ea3, 26.546875}, + {0x4eb0, 26.75}, + {0x4f0e, 28.21875}, + {0x4f4a, 29.15625}, + {0x4f6b, 29.671875}, + {0x4fa6, 30.59375}, + {0x4fae, 30.71875}, + {0x4ff6, 31.84375}, + {0x503c, 33.875}, + {0x50e4, 39.125}, + {0x514e, 42.4375}, + {0x516b, 43.34375}, + {0x51d3, 46.59375}, + {0x5213, 48.59375}, + {0x526e, 51.4375}, + {0x52a6, 53.1875}, + {0x52b4, 53.625}, + {0x52b6, 53.6875}, + {0x52bc, 53.875}, + {0x5300, 56}, + {0x5389, 60.28125}, + {0x5406, 64.375}, + {0x5498, 73.5}, + {0x54bd, 75.8125}, + {0x54cf, 76.9375}, + {0x5502, 80.125}, + {0x558e, 88.875}, + {0x5597, 89.4375}, + {0x55eb, 94.6875}, + {0x55f6, 95.375}, + {0x5629, 
98.5625}, + {0x562b, 98.6875}, + {0x5635, 99.3125}, + {0x564e, 100.875}, + {0x5671, 103.0625}, + {0x5681, 104.0625}, + {0x56d1, 109.0625}, + {0x571c, 113.75}, + {0x5756, 117.375}, + {0x5790, 121}, + {0x57fd, 127.8125}, + {0x582d, 133.625}, + {0x5869, 141.125}, + {0x58ab, 149.375}, + {0x58ad, 149.625}, + {0x58c9, 153.125}, + {0x58f7, 158.875}, + {0x5904, 160.5}, + {0x59c2, 184.25}, + {0x59e6, 188.75}, + {0x5a88, 209}, + {0x5ada, 219.25}, + {0x5aef, 221.875}, + {0x5af5, 222.625}, + {0x5b7f, 239.875}, + {0x5ba4, 244.5}, + {0x5c08, 258}, + {0x5cbf, 303.75}, + {0x5d4d, 339.25}, + {0x5dc2, 368.5}, + {0x5dc4, 369}, + {0x5e31, 396.25}, + {0x5e38, 398}, + {0x5e7c, 415}, + {0x5e8d, 419.25}, + {0x5ead, 427.25}, + {0x5eb4, 429}, + {0x5ec0, 432}, + {0x5eef, 443.75}, + {0x5f04, 449}, + {0x5f41, 464.25}, + {0x5f58, 470}, + {0x5f61, 472.25}, + {0x5f77, 477.75}, + {0x5f7b, 478.75}, + {0x6029, 532.5}, + {0x6046, 547}, + {0x6055, 554.5}, + {0x60a8, 596}, + {0x60d7, 619.5}, + {0x6139, 668.5}, + {0x6167, 691.5}, + {0x61b5, 730.5}, + {0x61c0, 736}, + {0x61e6, 755}, + {0x625b, 813.5}, + {0x62c4, 866}, + {0x62fd, 894.5}, + {0x62fe, 895}, + {0x6332, 921}, + {0x636a, 949}, + {0x6374, 954}, + {0x6376, 955}, + {0x639f, 975.5}, + {0x63d6, 1003}, + {0x6417, 1047}, + {0x642e, 1070}, + {0x6431, 1073}, + {0x644f, 1103}, + {0x6459, 1113}, + {0x645b, 1115}, + {0x6480, 1152}, + {0x648d, 1165}, + {0x649f, 1183}, + {0x64bb, 1211}, + {0x6516, 1302}, + {0x6571, 1393}, + {0x6585, 1413}, + {0x65aa, 1450}, + {0x660c, 1548}, + {0x6694, 1684}, + {0x66d0, 1744}, + {0x6721, 1825}, + {0x672d, 1837}, + {0x6734, 1844}, + {0x6766, 1894}, + {0x6773, 1907}, + {0x677d, 1917}, + {0x679a, 1946}, + {0x690f, 2590}, + {0x6934, 2664}, + {0x6955, 2730}, + {0x697d, 2810}, + {0x698e, 2844}, + {0x6a3a, 3188}, + {0x6a63, 3270}, + {0x6a67, 3278}, + {0x6a7c, 3320}, + {0x6a87, 3342}, + {0x6b07, 3598}, + {0x6b11, 3618}, + {0x6b36, 3692}, + {0x6b3c, 3704}, + {0x6b75, 3818}, + {0x6b88, 3856}, + {0x6be6, 4044}, + {0x6bee, 4060}, + {0x6c62, 4488}, + {0x6c8b, 4652}, + {0x6d30, 5312}, + {0x6d48, 5408}, + {0x6ddd, 6004}, + {0x6de9, 6052}, + {0x6e39, 6372}, + {0x6e7e, 6648}, + {0x6ea5, 6804}, + {0x6ec5, 6932}, + {0x6ee1, 7044}, + {0x6ef1, 7108}, + {0x6fa2, 7816}, + {0x6fbc, 7920}, + {0x704c, 8800}, + {0x7083, 9240}, + {0x7108, 10304}, + {0x7115, 10408}, + {0x7128, 10560}, + {0x71af, 11640}, + {0x7222, 12560}, + {0x7228, 12608}, + {0x72a5, 13608}, + {0x72e0, 14080}, + {0x72e6, 14128}, + {0x731e, 14576}, + {0x7377, 15288}, + {0x741d, 16848}, + {0x7423, 16944}, + {0x7424, 16960}, + {0x7466, 18016}, + {0x74b0, 19200}, + {0x74ce, 19680}, + {0x74f0, 20224}, + {0x754b, 21680}, + {0x7575, 22352}, + {0x7594, 22848}, + {0x75b1, 23312}, + {0x7614, 24896}, + {0x7618, 24960}, + {0x7631, 25360}, + {0x7660, 26112}, + {0x76c8, 27776}, + {0x7773, 30512}, + {0x77af, 31472}, + {0x77b9, 31632}, + {0x77de, 32224}, + {0x7844, 34944}, + {0x78d2, 39488}, + {0x7924, 42112}, + {0x793b, 42848}, + {0x79db, 47968}, + {0x7a0f, 49632}, + {0x7a1a, 49984}, + {0x7a6c, 52608}, + {0x7a99, 54048}, + {0x7ada, 56128}, + {0x7b0f, 57824}, + {0x7b15, 58016}, + {0x7b41, 59424}, + {0x7b51, 59936}, + {0x7b9c, 62336}, + {0x7ba3, 62560}, + {0x7c00, std::numeric_limits::infinity()}, + {0x7c05, std::numeric_limits::quiet_NaN()}, + {0x7c0e, std::numeric_limits::quiet_NaN()}, + {0x7c3e, std::numeric_limits::quiet_NaN()}, + {0x7c4e, std::numeric_limits::quiet_NaN()}, + {0x7c55, std::numeric_limits::quiet_NaN()}, + {0x7c58, std::numeric_limits::quiet_NaN()}, + {0x7c66, std::numeric_limits::quiet_NaN()}, + {0x7cc9, 
std::numeric_limits::quiet_NaN()}, + {0x7cd8, std::numeric_limits::quiet_NaN()}, + {0x7d2d, std::numeric_limits::quiet_NaN()}, + {0x7d60, std::numeric_limits::quiet_NaN()}, + {0x7d79, std::numeric_limits::quiet_NaN()}, + {0x7dc7, std::numeric_limits::quiet_NaN()}, + {0x7dcf, std::numeric_limits::quiet_NaN()}, + {0x7dd8, std::numeric_limits::quiet_NaN()}, + {0x7dfb, std::numeric_limits::quiet_NaN()}, + {0x7e0f, std::numeric_limits::quiet_NaN()}, + {0x7e56, std::numeric_limits::quiet_NaN()}, + {0x7e89, std::numeric_limits::quiet_NaN()}, + {0x7e9c, std::numeric_limits::quiet_NaN()}, + {0x7eb2, std::numeric_limits::quiet_NaN()}, + {0x7ec3, std::numeric_limits::quiet_NaN()}, + {0x7ef9, std::numeric_limits::quiet_NaN()}, + {0x7f36, std::numeric_limits::quiet_NaN()}, + {0x8040, -0.0000038146972656}, + {0x8101, -0.0000153183937073}, + {0x813d, -0.0000188946723938}, + {0x81a8, -0.0000252723693848}, + {0x81bc, -0.0000264644622803}, + {0x81c2, -0.0000268220901489}, + {0x8259, -0.00003582239151}, + {0x8330, -0.0000486373901367}, + {0x8366, -0.0000518560409546}, + {0x8392, -0.0000544786453247}, + {0x83e4, -0.0000593662261963}, + {0x83ee, -0.000059962272644}, + {0x8402, -0.0000611543655396}, + {0x845e, -0.0000666379928589}, + {0x84ac, -0.0000712871551514}, + {0x84b1, -0.0000715851783752}, + {0x84fb, -0.0000759959220886}, + {0x8546, -0.0000804662704468}, + {0x856f, -0.0000829100608826}, + {0x85b5, -0.0000870823860168}, + {0x8638, -0.0000948905944824}, + {0x8656, -0.0000966787338257}, + {0x86b9, -0.0001025795936584}, + {0x86ba, -0.0001026391983032}, + {0x86fe, -0.0001066923141479}, + {0x8731, -0.0001097321510315}, + {0x8740, -0.0001106262207031}, + {0x8793, -0.0001155734062195}, + {0x87bd, -0.0001180768013}, + {0x87f1, -0.0001211762428284}, + {0x87f4, -0.0001213550567627}, + {0x8809, -0.000123143196106}, + {0x882a, -0.0001270771026611}, + {0x8848, -0.0001306533813477}, + {0x8852, -0.0001318454742432}, + {0x8874, -0.0001358985900879}, + {0x8892, -0.0001394748687744}, + {0x88a7, -0.000141978263855}, + {0x88c8, -0.0001459121704102}, + {0x8927, -0.0001572370529175}, + {0x892a, -0.0001575946807861}, + {0x8989, -0.0001689195632935}, + {0x89b9, -0.0001746416091919}, + {0x8b18, -0.0002164840698242}, + {0x8b4b, -0.0002225637435913}, + {0x8b62, -0.000225305557251}, + {0x8b7f, -0.0002287626266479}, + {0x8bca, -0.0002377033233643}, + {0x8bcf, -0.000238299369812}, + {0x8bff, -0.0002440214157104}, + {0x8c0b, -0.0002467632293701}, + {0x8c55, -0.0002644062042236}, + {0x8c63, -0.0002677440643311}, + {0x8d53, -0.0003249645233154}, + {0x8dba, -0.0003495216369629}, + {0x8e03, -0.0003669261932373}, + {0x8e82, -0.0003972053527832}, + {0x8e9c, -0.0004034042358398}, + {0x8faa, -0.0004677772521973}, + {0x902f, -0.0005106925964355}, + {0x9051, -0.0005269050598145}, + {0x9066, -0.0005369186401367}, + {0x907e, -0.0005483627319336}, + {0x9080, -0.00054931640625}, + {0x908e, -0.0005559921264648}, + {0x9102, -0.0006113052368164}, + {0x91eb, -0.0007224082946777}, + {0x9215, -0.0007424354553223}, + {0x9252, -0.0007715225219727}, + {0x9294, -0.0008029937744141}, + {0x9297, -0.0008044242858887}, + {0x933d, -0.0008835792541504}, + {0x936f, -0.0009074211120605}, + {0x93aa, -0.0009355545043945}, + {0x93f2, -0.0009698867797852}, + {0x941d, -0.0010042190551758}, + {0x945a, -0.0010623931884766}, + {0x94ad, -0.0011415481567383}, + {0x94d2, -0.0011768341064453}, + {0x951c, -0.0012474060058594}, + {0x9520, -0.001251220703125}, + {0x952f, -0.0012655258178711}, + {0x953f, -0.0012807846069336}, + {0x9549, -0.0012903213500977}, + {0x95c6, 
-0.0014095306396484}, + {0x9602, -0.0014667510986328}, + {0x969b, -0.001612663269043}, + {0x96fa, -0.0017032623291016}, + {0x977d, -0.0018281936645508}, + {0x97c3, -0.0018949508666992}, + {0x97c6, -0.0018978118896484}, + {0x97db, -0.001917839050293}, + {0x97f9, -0.0019464492797852}, + {0x983f, -0.0020732879638672}, + {0x984e, -0.0021018981933594}, + {0x985a, -0.0021247863769531}, + {0x988c, -0.0022201538085938}, + {0x990d, -0.0024662017822266}, + {0x9958, -0.0026092529296875}, + {0x9971, -0.0026569366455078}, + {0x9a4e, -0.0030784606933594}, + {0x9a8f, -0.0032024383544922}, + {0x9abe, -0.0032920837402344}, + {0x9ace, -0.0033226013183594}, + {0x9b1e, -0.0034751892089844}, + {0x9b3e, -0.0035362243652344}, + {0x9b77, -0.0036449432373047}, + {0x9b89, -0.0036792755126953}, + {0x9b90, -0.003692626953125}, + {0x9bec, -0.0038681030273438}, + {0x9c03, -0.0039176940917969}, + {0x9c75, -0.0043525695800781}, + {0x9d6c, -0.0052947998046875}, + {0x9d74, -0.0053253173828125}, + {0x9da7, -0.0055198669433594}, + {0x9e73, -0.0062980651855469}, + {0x9e94, -0.0064239501953125}, + {0x9f17, -0.0069236755371094}, + {0x9f3a, -0.0070571899414062}, + {0x9f6c, -0.0072479248046875}, + {0x9f89, -0.0073585510253906}, + {0x9fbd, -0.0075569152832031}, + {0xa003, -0.0078353881835938}, + {0xa014, -0.007965087890625}, + {0xa019, -0.0080032348632812}, + {0xa01d, -0.0080337524414062}, + {0xa090, -0.0089111328125}, + {0xa1cf, -0.0113449096679688}, + {0xa1dd, -0.0114517211914062}, + {0xa249, -0.0122756958007812}, + {0xa26d, -0.0125503540039062}, + {0xa288, -0.01275634765625}, + {0xa2fb, -0.0136337280273438}, + {0xa390, -0.0147705078125}, + {0xa3b3, -0.0150375366210938}, + {0xa3ed, -0.0154800415039062}, + {0xa434, -0.01641845703125}, + {0xa476, -0.017425537109375}, + {0xa571, -0.0212554931640625}, + {0xa57d, -0.0214385986328125}, + {0xa597, -0.0218353271484375}, + {0xa5d1, -0.0227203369140625}, + {0xa5f9, -0.0233306884765625}, + {0xa680, -0.025390625}, + {0xa6e3, -0.0269012451171875}, + {0xa6f0, -0.027099609375}, + {0xa72d, -0.0280303955078125}, + {0xa77e, -0.029266357421875}, + {0xa7d0, -0.030517578125}, + {0xa7ee, -0.030975341796875}, + {0xa7f3, -0.0310516357421875}, + {0xa80c, -0.0316162109375}, + {0xa827, -0.032440185546875}, + {0xa89f, -0.036102294921875}, + {0xa8a0, -0.0361328125}, + {0xa8a5, -0.036285400390625}, + {0xa948, -0.041259765625}, + {0xaa0c, -0.0472412109375}, + {0xaa16, -0.04754638671875}, + {0xaa9a, -0.05157470703125}, + {0xaaeb, -0.054046630859375}, + {0xab5c, -0.0574951171875}, + {0xac7e, -0.0701904296875}, + {0xad33, -0.08123779296875}, + {0xad37, -0.08148193359375}, + {0xad90, -0.0869140625}, + {0xada0, -0.087890625}, + {0xade5, -0.09210205078125}, + {0xadf8, -0.09326171875}, + {0xae02, -0.0938720703125}, + {0xae04, -0.093994140625}, + {0xae4f, -0.09857177734375}, + {0xae63, -0.09979248046875}, + {0xaebe, -0.1053466796875}, + {0xaee1, -0.10748291015625}, + {0xaef9, -0.10894775390625}, + {0xaf0b, -0.11004638671875}, + {0xaf78, -0.11669921875}, + {0xaf7d, -0.11700439453125}, + {0xaf7f, -0.11712646484375}, + {0xaf8c, -0.117919921875}, + {0xafcb, -0.12176513671875}, + {0xb06b, -0.1380615234375}, + {0xb07b, -0.1400146484375}, + {0xb088, -0.1416015625}, + {0xb0b2, -0.146728515625}, + {0xb0ed, -0.1539306640625}, + {0xb0f9, -0.1553955078125}, + {0xb16c, -0.16943359375}, + {0xb189, -0.1729736328125}, + {0xb1c5, -0.1802978515625}, + {0xb1f7, -0.1864013671875}, + {0xb22d, -0.1929931640625}, + {0xb23c, -0.19482421875}, + {0xb258, -0.1982421875}, + {0xb2c7, -0.2117919921875}, + {0xb2de, -0.214599609375}, + {0xb2e1, 
-0.2149658203125}, + {0xb317, -0.2215576171875}, + {0xb31d, -0.2222900390625}, + {0xb3ef, -0.2479248046875}, + {0xb3f8, -0.2490234375}, + {0xb45a, -0.27197265625}, + {0xb548, -0.330078125}, + {0xb5d8, -0.365234375}, + {0xb64e, -0.39404296875}, + {0xb69f, -0.413818359375}, + {0xb6e6, -0.43115234375}, + {0xb6ed, -0.432861328125}, + {0xb6f7, -0.435302734375}, + {0xb79a, -0.47509765625}, + {0xb7b6, -0.48193359375}, + {0xb7ee, -0.49560546875}, + {0xb856, -0.5419921875}, + {0xb8c0, -0.59375}, + {0xb96f, -0.67919921875}, + {0xb9a5, -0.70556640625}, + {0xba1e, -0.7646484375}, + {0xba2d, -0.77197265625}, + {0xba48, -0.78515625}, + {0xba65, -0.79931640625}, + {0xbaaf, -0.83544921875}, + {0xbab0, -0.8359375}, + {0xbb12, -0.8837890625}, + {0xbb35, -0.90087890625}, + {0xbb47, -0.90966796875}, + {0xbb97, -0.94873046875}, + {0xbba3, -0.95458984375}, + {0xbbcb, -0.97412109375}, + {0xbbe8, -0.98828125}, + {0xbbee, -0.9912109375}, + {0xbd03, -1.2529296875}, + {0xbd4b, -1.3232421875}, + {0xbd4c, -1.32421875}, + {0xbd8a, -1.384765625}, + {0xbdb6, -1.427734375}, + {0xbde1, -1.4697265625}, + {0xbe04, -1.50390625}, + {0xbe50, -1.578125}, + {0xbe54, -1.58203125}, + {0xbe6a, -1.603515625}, + {0xbf31, -1.7978515625}, + {0xbf87, -1.8818359375}, + {0xbfa2, -1.908203125}, + {0xc016, -2.04296875}, + {0xc074, -2.2265625}, + {0xc0ca, -2.39453125}, + {0xc100, -2.5}, + {0xc1b7, -2.857421875}, + {0xc1b9, -2.861328125}, + {0xc1d3, -2.912109375}, + {0xc23f, -3.123046875}, + {0xc2d5, -3.416015625}, + {0xc32f, -3.591796875}, + {0xc3e3, -3.943359375}, + {0xc412, -4.0703125}, + {0xc49a, -4.6015625}, + {0xc4ca, -4.7890625}, + {0xc4cf, -4.80859375}, + {0xc523, -5.13671875}, + {0xc55d, -5.36328125}, + {0xc5aa, -5.6640625}, + {0xc604, -6.015625}, + {0xc61b, -6.10546875}, + {0xc642, -6.2578125}, + {0xc68b, -6.54296875}, + {0xc69e, -6.6171875}, + {0xc6b0, -6.6875}, + {0xc6ca, -6.7890625}, + {0xc71e, -7.1171875}, + {0xc721, -7.12890625}, + {0xc73b, -7.23046875}, + {0xc7d4, -7.828125}, + {0xc831, -8.3828125}, + {0xc89a, -9.203125}, + {0xc8be, -9.484375}, + {0xc8dc, -9.71875}, + {0xc8e4, -9.78125}, + {0xc8fa, -9.953125}, + {0xc8fe, -9.984375}, + {0xc969, -10.8203125}, + {0xca0f, -12.1171875}, + {0xca1a, -12.203125}, + {0xca6f, -12.8671875}, + {0xca7b, -12.9609375}, + {0xca8f, -13.1171875}, + {0xcaca, -13.578125}, + {0xcafd, -13.9765625}, + {0xcb05, -14.0390625}, + {0xcb6b, -14.8359375}, + {0xcbaf, -15.3671875}, + {0xcbb4, -15.40625}, + {0xcbdf, -15.7421875}, + {0xcc2d, -16.703125}, + {0xcc74, -17.8125}, + {0xccac, -18.6875}, + {0xcd11, -20.265625}, + {0xce04, -24.0625}, + {0xce0f, -24.234375}, + {0xceaf, -26.734375}, + {0xceb8, -26.875}, + {0xcf36, -28.84375}, + {0xcfad, -30.703125}, + {0xd019, -32.78125}, + {0xd08d, -36.40625}, + {0xd115, -40.65625}, + {0xd119, -40.78125}, + {0xd128, -41.25}, + {0xd1a4, -45.125}, + {0xd1b7, -45.71875}, + {0xd1b8, -45.75}, + {0xd203, -48.09375}, + {0xd20a, -48.3125}, + {0xd28b, -52.34375}, + {0xd2ac, -53.375}, + {0xd2ae, -53.4375}, + {0xd2c5, -54.15625}, + {0xd2f2, -55.5625}, + {0xd326, -57.1875}, + {0xd337, -57.71875}, + {0xd343, -58.09375}, + {0xd34e, -58.4375}, + {0xd40c, -64.75}, + {0xd43b, -67.6875}, + {0xd45a, -69.625}, + {0xd464, -70.25}, + {0xd4c3, -76.1875}, + {0xd505, -80.3125}, + {0xd52d, -82.8125}, + {0xd5cf, -92.9375}, + {0xd5f0, -95}, + {0xd607, -96.4375}, + {0xd635, -99.3125}, + {0xd63d, -99.8125}, + {0xd644, -100.25}, + {0xd658, -101.5}, + {0xd789, -120.5625}, + {0xd863, -140.375}, + {0xd866, -140.75}, + {0xd884, -144.5}, + {0xd88d, -145.625}, + {0xd89b, -147.375}, + {0xd8da, -155.25}, + 
{0xd93b, -167.375}, + {0xd982, -176.25}, + {0xd995, -178.625}, + {0xd99d, -179.625}, + {0xd9cf, -185.875}, + {0xdaaf, -213.875}, + {0xdabd, -215.625}, + {0xdb54, -234.5}, + {0xdc10, -260}, + {0xdca1, -296.25}, + {0xdd0a, -322.5}, + {0xdd56, -341.5}, + {0xddcf, -371.75}, + {0xde04, -385}, + {0xde0d, -387.25}, + {0xde3d, -399.25}, + {0xde4f, -403.75}, + {0xde66, -409.5}, + {0xdeae, -427.5}, + {0xdf52, -468.5}, + {0xdf63, -472.75}, + {0xdf6a, -474.5}, + {0xdf77, -477.75}, + {0xdf7b, -478.75}, + {0xdfc5, -497.25}, + {0xdfcf, -499.75}, + {0xdfd2, -500.5}, + {0xdfd8, -502}, + {0xdfe1, -504.25}, + {0xe022, -529}, + {0xe046, -547}, + {0xe092, -585}, + {0xe0b0, -600}, + {0xe0be, -607}, + {0xe0f4, -634}, + {0xe11b, -653.5}, + {0xe19c, -718}, + {0xe213, -777.5}, + {0xe232, -793}, + {0xe25b, -813.5}, + {0xe262, -817}, + {0xe279, -828.5}, + {0xe2cc, -870}, + {0xe2da, -877}, + {0xe326, -915}, + {0xe330, -920}, + {0xe3c3, -993.5}, + {0xe3cc, -998}, + {0xe566, -1382}, + {0xe57e, -1406}, + {0xe5c8, -1480}, + {0xe609, -1545}, + {0xe628, -1576}, + {0xe663, -1635}, + {0xe6ac, -1708}, + {0xe710, -1808}, + {0xe77f, -1919}, + {0xe7e7, -2023}, + {0xe868, -2256}, + {0xe885, -2314}, + {0xe8ea, -2516}, + {0xe919, -2610}, + {0xe92c, -2648}, + {0xea60, -3264}, + {0xeac1, -3458}, + {0xeacb, -3478}, + {0xeb22, -3652}, + {0xeb2c, -3672}, + {0xeb59, -3762}, + {0xeba5, -3914}, + {0xec53, -4428}, + {0xec97, -4700}, + {0xed16, -5208}, + {0xed4a, -5416}, + {0xed69, -5540}, + {0xee14, -6224}, + {0xee59, -6500}, + {0xee8a, -6696}, + {0xee93, -6732}, + {0xeed7, -7004}, + {0xef0b, -7212}, + {0xef59, -7524}, + {0xef61, -7556}, + {0xef67, -7580}, + {0xefb6, -7896}, + {0xf03a, -8656}, + {0xf04e, -8816}, + {0xf05f, -8952}, + {0xf09f, -9464}, + {0xf0c0, -9728}, + {0xf173, -11160}, + {0xf1d7, -11960}, + {0xf225, -12584}, + {0xf2ca, -13904}, + {0xf2d8, -14016}, + {0xf2e5, -14120}, + {0xf317, -14520}, + {0xf35d, -15080}, + {0xf3bd, -15848}, + {0xf3d3, -16024}, + {0xf3e6, -16176}, + {0xf3fb, -16344}, + {0xf477, -18288}, + {0xf4e0, -19968}, + {0xf4e5, -20048}, + {0xf50b, -20656}, + {0xf5a2, -23072}, + {0xf5c1, -23568}, + {0xf634, -25408}, + {0xf651, -25872}, + {0xf68a, -26784}, + {0xf69c, -27072}, + {0xf6ce, -27872}, + {0xf816, -33472}, + {0xf849, -35104}, + {0xf869, -36128}, + {0xf878, -36608}, + {0xf8cf, -39392}, + {0xf90a, -41280}, + {0xf916, -41664}, + {0xf91e, -41920}, + {0xf9c1, -47136}, + {0xfa0a, -49472}, + {0xfa11, -49696}, + {0xfa1d, -50080}, + {0xfa51, -51744}, + {0xfa86, -53440}, + {0xfaac, -54656}, + {0xfb95, -62112}, + {0xfbd1, -64032}, + {0xfbe0, -64512}, + {0xfbf5, -65184}, + {0xfc00, -std::numeric_limits::infinity()}, + {0xfca5, std::numeric_limits::quiet_NaN()}, + {0xfcb9, std::numeric_limits::quiet_NaN()}, + {0xfcc6, std::numeric_limits::quiet_NaN()}, + {0xfd72, std::numeric_limits::quiet_NaN()}, + {0xfd77, std::numeric_limits::quiet_NaN()}, + {0xfda3, std::numeric_limits::quiet_NaN()}, + {0xfe3e, std::numeric_limits::quiet_NaN()}, + {0xfe89, std::numeric_limits::quiet_NaN()}, + {0xfe91, std::numeric_limits::quiet_NaN()}, + {0xfe93, std::numeric_limits::quiet_NaN()}, + {0xfed1, std::numeric_limits::quiet_NaN()}, + {0xff7a, std::numeric_limits::quiet_NaN()}, + {0xffa3, std::numeric_limits::quiet_NaN()}, + }; + return result; +} + +TEST_CASE(check_half_values) +{ + for(auto [x, f] : half_lut()) + { + auto h = migraphx::bit_cast(x); + if(std::isnan(f)) + { + CHECK(std::isnan(h)); + } + else if(std::isinf(f)) + { + CHECK(std::isinf(h)); + CHECK((h < 0) == (f < 0)); + CHECK(bit_equal(x, migraphx::half(f))); + } + else + { + 
CHECK(bit_equal(x, migraphx::half(f))); + CHECK(migraphx::float_equal(float(h), f)); + } + } +} + +TEST_CASE(check_flows) +{ + // check positive underflow + CHECK(bit_equal(std::numeric_limits::min() * + std::numeric_limits::min(), + migraphx::half(0))); + + // check overflow + CHECK(bit_equal(std::numeric_limits::infinity() + + std::numeric_limits::infinity(), + std::numeric_limits::infinity())); + CHECK(bit_equal(std::numeric_limits::max() + + std::numeric_limits::max(), + std::numeric_limits::infinity())); + CHECK(bit_equal(std::numeric_limits::max() / + std::numeric_limits::epsilon(), + std::numeric_limits::infinity())); + CHECK(bit_equal(std::numeric_limits::max() + + std::numeric_limits::min(), + std::numeric_limits::max())); + + // check negative underflow + CHECK(bit_equal(std::numeric_limits::lowest() + + std::numeric_limits::lowest(), + -std::numeric_limits::infinity())); + CHECK(bit_equal(-std::numeric_limits::infinity() - + std::numeric_limits::infinity(), + -std::numeric_limits::infinity())); + CHECK(bit_equal(std::numeric_limits::lowest() - + std::numeric_limits::min(), + std::numeric_limits::lowest())); +} + +TEST_CASE(test_nan) +{ + float f_qnan = std::numeric_limits::quiet_NaN(); + migraphx::half half_qnan(f_qnan); + EXPECT(half_qnan.is_nan()); + EXPECT(std::isnan(half_qnan)); + + float f_snan = std::numeric_limits::signaling_NaN(); + migraphx::half half_snan(f_snan); + EXPECT(half_snan.is_nan()); + EXPECT(std::isnan(half_snan)); +} + +TEST_CASE(test_bool) +{ + float zero = 0.0; + float two = 2.0; + float other = -0.375; + migraphx::half half_zero(zero); + migraphx::half half_two(two); + migraphx::half half_other(other); + EXPECT(not static_cast(half_zero)); + EXPECT(static_cast(half_two)); + EXPECT(static_cast(half_other)); +} + +TEST_CASE(test_pos_infinity) +{ + float finf = std::numeric_limits::infinity(); + migraphx::half half_inf_1(finf); + CHECK(bit_equal(half_inf_1, std::numeric_limits::infinity())); +} + +TEST_CASE(test_neg_infinity) +{ + float finf = -1.0 * std::numeric_limits::infinity(); + migraphx::half half_neginf_1(finf); + CHECK(bit_equal(half_neginf_1, -std::numeric_limits::infinity())); +} + +TEST_CASE(test_numeric_max_1) +{ + float fmax = std::numeric_limits::max(); // fp32 max is fp16 inf + migraphx::half half_inf(fmax); + CHECK(bit_equal(half_inf, std::numeric_limits::infinity())); +} + +TEST_CASE(test_numeric_lowest_1) +{ + float flowest = std::numeric_limits::lowest(); + migraphx::half half_neginf(flowest); + CHECK(bit_equal(half_neginf, -std::numeric_limits::infinity())); +} + +TEST_CASE(test_max_eq_lowest) +{ + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); +} + +TEST_CASE(test_isfinite) +{ + EXPECT(std::isfinite(migraphx::half(0.0))); + EXPECT(std::isfinite(migraphx::half(-0.0))); + EXPECT(not std::isfinite(migraphx::half(std::numeric_limits::quiet_NaN()))); +} + +TEST_CASE(test_binary_ops) +{ + auto a = migraphx::half(-1.0); + auto b = migraphx::half(1.0); + auto c = migraphx::half(0.0); + auto d = migraphx::half(-0.0); + EXPECT(migraphx::float_equal((c + d), c)); + EXPECT(migraphx::float_equal((c + d), d)); + EXPECT(migraphx::float_equal((a + b), c)); + EXPECT(migraphx::float_equal((a + b), d)); + + auto e = migraphx::half(10.0); + auto f = migraphx::half(-10.0); + EXPECT(e > f); + EXPECT(f < e); + EXPECT(f <= e); + EXPECT(e >= f); + EXPECT(e <= e); + EXPECT(f >= f); + EXPECT(not migraphx::float_equal(f, e)); +} + +TEST_CASE(test_stream_op) +{ + auto a = migraphx::half(-1.0); + std::stringstream 
ss; + ss << a; + EXPECT(std::string("-1") == ss.str()); + ss = std::stringstream(); + auto b = std::numeric_limits::quiet_NaN(); + ss << b; + EXPECT(std::string("nan") == ss.str()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/instruction.cpp b/test/instruction.cpp index 134658e336b..0ee22e13553 100644 --- a/test/instruction.cpp +++ b/test/instruction.cpp @@ -67,4 +67,24 @@ TEST_CASE(check_replace_shape) EXPECT(add->get_shape() == r); } +TEST_CASE(check_replace_dag) +{ + migraphx::module m; + migraphx::shape s{migraphx::shape::float_type, {3, 2}}; + auto input = m.add_parameter("x", s); + auto reduce = m.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0}}}), input); + auto abs = m.add_instruction(migraphx::make_op("abs"), reduce); + auto sin = m.add_instruction(migraphx::make_op("sin"), reduce); + auto add = m.add_instruction(migraphx::make_op("add"), abs, sin); + auto add2 = m.add_instruction(migraphx::make_op("add"), add, reduce); + + reduce->replace(migraphx::make_op("reduce_sum", {{"axes", {1}}})); + + migraphx::shape r{migraphx::shape::float_type, {3, 1}}; + EXPECT(reduce->get_shape() == r); + EXPECT(abs->get_shape() == r); + EXPECT(sin->get_shape() == r); + EXPECT(add->get_shape() == r); + EXPECT(add2->get_shape() == r); +} int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/layout_nhwc.cpp b/test/layout_convolution.cpp similarity index 58% rename from test/layout_nhwc.cpp rename to test/layout_convolution.cpp index 7dae574d113..64e8830d67b 100644 --- a/test/layout_nhwc.cpp +++ b/test/layout_convolution.cpp @@ -21,7 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -#include +#include #include #include #include @@ -32,9 +32,9 @@ #include -void run_pass(migraphx::module& m) +void run_pass(migraphx::module& m, migraphx::layout_convolution lc = {}) { - migraphx::run_passes(m, {migraphx::layout_nhwc{}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(m, {lc, migraphx::dead_code_elimination{}}); } migraphx::operation layout(std::vector permutation = {0, 1, 2, 3}) @@ -47,7 +47,7 @@ migraphx::instruction_ref add_layout_nhwc(migraphx::module& m, migraphx::instruc return m.add_instruction(layout({0, 2, 3, 1}), ins); } -TEST_CASE(conv_relu) +TEST_CASE(auto_conv_nchw) { migraphx::module m1; { @@ -59,9 +59,128 @@ TEST_CASE(conv_relu) {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), x, w); - m1.add_instruction(migraphx::make_op("relu"), conv); + auto relu = m1.add_instruction(migraphx::make_op("relu"), conv); + m1.add_return({relu}); } + migraphx::module m2 = m1; run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(auto_conv_nhwc) +{ + auto transpose = migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}); + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 16, 16, 8}}); + auto xtranspose = m1.add_instruction(transpose, x); + auto w = m1.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {16, 3, 3, 8}})); + auto wtranspose = m1.add_instruction(transpose, w); + auto conv = m1.add_instruction( + migraphx::make_op("convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + xtranspose, + wtranspose); + auto relu = m1.add_instruction(migraphx::make_op("relu"), conv); + m1.add_return({relu}); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(auto_conv_mixed) +{ + migraphx::module 
m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}); + auto w = m1.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {3, 3, 16, 8}})); + auto wtranspose = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), w); + auto conv = m1.add_instruction( + migraphx::make_op("convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + wtranspose); + auto relu = m1.add_instruction(migraphx::make_op("relu"), conv); + m1.add_return({relu}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}); + auto w = m2.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {3, 3, 16, 8}})); + auto wtranspose = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), w); + auto wlayout = m2.add_instruction( + migraphx::make_op("layout", {{"permutation", {0, 1, 2, 3}}}), wtranspose); + auto conv = m2.add_instruction( + migraphx::make_op("convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + wlayout); + auto relu = m2.add_instruction(migraphx::make_op("relu"), conv); + m2.add_return({relu}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(auto_quant_conv_mixed) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 16, 16}}); + auto w = + m1.add_literal(migraphx::generate_literal({migraphx::shape::int8_type, {3, 3, 16, 8}})); + auto wtranspose = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), w); + auto conv = m1.add_instruction( + migraphx::make_op("quant_convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + wtranspose); + auto relu = m1.add_instruction(migraphx::make_op("relu"), conv); + m1.add_return({relu}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 16, 16}}); + auto w = + m2.add_literal(migraphx::generate_literal({migraphx::shape::int8_type, {3, 3, 16, 8}})); + auto wtranspose = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 3, 0, 1}}}), w); + auto wlayout = m2.add_instruction( + migraphx::make_op("layout", {{"permutation", {0, 1, 2, 3}}}), wtranspose); + auto conv = m2.add_instruction( + migraphx::make_op("quant_convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + wlayout); + auto relu = m2.add_instruction(migraphx::make_op("relu"), conv); + m2.add_return({relu}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(nhwc_conv_relu) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}); + auto w = m1.add_literal( + migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}})); + auto conv = m1.add_instruction( + migraphx::make_op("convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + w); + m1.add_instruction(migraphx::make_op("relu"), conv); + } + run_pass(m1, {.channels_last = true}); migraphx::module m2; { @@ -81,7 +200,7 @@ TEST_CASE(conv_relu) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(conv_add) +TEST_CASE(nhwc_conv_add) { migraphx::module m1; { @@ -99,7 +218,7 @@ TEST_CASE(conv_add) y); m1.add_instruction(migraphx::make_op("add"), conv, b); } - run_pass(m1); + run_pass(m1, {.channels_last = true}); migraphx::module m2; { @@ -114,7 +233,7 @@ TEST_CASE(conv_add) {{"padding", {1, 
1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), x, w); - auto b = m2.add_instruction( + auto b = m2.add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}), y); auto add = m2.add_instruction(migraphx::make_op("add"), conv, b); @@ -123,7 +242,49 @@ TEST_CASE(conv_add) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(conv_conv) +TEST_CASE(nhwc_quant_conv_add) +{ + migraphx::module m1; + { + auto x = m1.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 16, 16}}); + auto w = + m1.add_literal(migraphx::generate_literal({migraphx::shape::int8_type, {16, 8, 3, 3}})); + auto y = m1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {16}})); + auto conv = m1.add_instruction( + migraphx::make_op("quant_convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + w); + auto b = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}), + y); + m1.add_instruction(migraphx::make_op("add"), conv, b); + } + run_pass(m1, {.channels_last = true}); + + migraphx::module m2; + { + auto x = add_layout_nhwc( + m2, m2.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 16, 16}})); + auto w = add_layout_nhwc(m2, + m2.add_literal(migraphx::generate_literal( + {migraphx::shape::int8_type, {16, 8, 3, 3}}))); + auto y = m2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {16}})); + auto conv = m2.add_instruction( + migraphx::make_op("quant_convolution", + {{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}), + x, + w); + auto b = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}), + y); + auto add = m2.add_instruction(migraphx::make_op("add"), conv, b); + m2.add_instruction(layout(), add); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(nhwc_conv_conv) { migraphx::module m1; { @@ -149,7 +310,7 @@ TEST_CASE(conv_conv) auto relu2 = m1.add_instruction(migraphx::make_op("relu"), add2); m1.add_return({relu2}); } - run_pass(m1); + run_pass(m1, {.channels_last = true}); migraphx::module m2; { @@ -182,7 +343,7 @@ TEST_CASE(conv_conv) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(conv_reduce) +TEST_CASE(nhwc_conv_reduce) { migraphx::module m1; { @@ -201,7 +362,7 @@ TEST_CASE(conv_reduce) auto squeeze = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), reduce); m1.add_return({squeeze}); } - run_pass(m1); + run_pass(m1, {.channels_last = true}); migraphx::module m2; { diff --git a/test/onnx/.onnxrt-commit b/test/onnx/.onnxrt-commit index 7da7b2e4909..39722ee5a0e 100644 --- a/test/onnx/.onnxrt-commit +++ b/test/onnx/.onnxrt-commit @@ -1 +1 @@ -8fbbf2fd4feba517e4cef84086f26fc2a9eeb218 +4d614e15bd9e6949bc3066754791da403e00d66c diff --git a/test/onnx/verify/negativelogliklihood_kd_dim_weighted.cpp b/test/onnx/verify/negativelogliklihood_kd_dim_weighted.cpp index 06865e637b2..69de5d2c15f 100644 --- a/test/onnx/verify/negativelogliklihood_kd_dim_weighted.cpp +++ b/test/onnx/verify/negativelogliklihood_kd_dim_weighted.cpp @@ -170,7 +170,7 @@ TEST_CASE(negativeloglikelihoodloss_kd_mean_reduction_weighted_test) pp["2"] = migraphx::argument(weight_shape, weight_data.data()); auto result = p.eval(pp).back(); - std::vector result_vector; + std::vector result_vector; result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); std::vector gold = {half{-35.266666666666666}}; EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); @@ -200,7 +200,7 @@ 
TEST_CASE(negativeloglikelihoodloss_kd_mean_reduction_weighted_test2) migraphx::shape label_shape{migraphx::shape::int32_type, {2, 2}}; std::vector label_data = {2, 1, 0, 2}; migraphx::shape weight_shape{migraphx::shape::half_type, {3}}; - std::vector weight_data = {half(0.2), half(0.3), half(0.1)}; + std::vector weight_data = {half(0.2), half(0.3), half(0.1)}; migraphx::parameter_map pp; pp["0"] = migraphx::argument(score_shape, score_data.data()); @@ -208,7 +208,7 @@ TEST_CASE(negativeloglikelihoodloss_kd_mean_reduction_weighted_test2) pp["2"] = migraphx::argument(weight_shape, weight_data.data()); auto result = p.eval(pp).back(); - std::vector result_vector; + std::vector result_vector; result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); std::vector gold = {half{-1.5714285714285714}}; EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); diff --git a/test/onnx/verify/softmaxcrossentropyloss_kd_dim_weighted.cpp b/test/onnx/verify/softmaxcrossentropyloss_kd_dim_weighted.cpp index 14b5a0da963..34fb82c9070 100644 --- a/test/onnx/verify/softmaxcrossentropyloss_kd_dim_weighted.cpp +++ b/test/onnx/verify/softmaxcrossentropyloss_kd_dim_weighted.cpp @@ -180,7 +180,7 @@ TEST_CASE(softmaxcrossentropyloss_kd_mean_reduction_weighted_test) pp["2"] = migraphx::argument(weight_shape, weight_data.data()); auto result = p.eval(pp).back(); - std::vector result_vector; + std::vector result_vector; result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); std::vector gold = {half{1.38629436}}; EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); @@ -207,7 +207,7 @@ TEST_CASE(softmaxcrossentropyloss_kd_mean_reduction_uneven_weighted_test) pp["2"] = migraphx::argument(weight_shape, weight_data.data()); auto result = p.eval(pp).back(); - std::vector result_vector; + std::vector result_vector; result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); std::vector gold = {half{1.38629436}}; diff --git a/test/quantization.cpp b/test/quantization.cpp index 3c8968cfdf4..6eafcca377c 100644 --- a/test/quantization.cpp +++ b/test/quantization.cpp @@ -30,8 +30,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -261,8 +261,9 @@ TEST_CASE(param_add_sub) }; auto p0 = create_program_float(); - migraphx::run_passes( - p0, {migraphx::quantize_fp16_pass{{"all"}}, migraphx::dead_code_elimination{}}); + migraphx::run_passes(p0, + {migraphx::truncate_float_pass{{"all"}, migraphx::shape::half_type}, + migraphx::dead_code_elimination{}}); EXPECT(p0 == create_program_fp16()); auto p1 = create_program_float(); @@ -669,7 +670,6 @@ TEST_CASE(dot_float) auto pb = mm->add_parameter("b", sb); auto zp = mm->add_literal(static_cast(0)); auto scale = mm->add_literal(10.0f); - auto zp_out = mm->add_literal(std::int32_t{0}); auto scale_a = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale); auto zp_a = @@ -684,10 +684,7 @@ TEST_CASE(dot_float) auto scale_mb = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb); - zp_out = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), zp_out); - auto r = - mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale, zp_out); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, 
out_scale); mm->add_return({r}); return p; @@ -704,11 +701,11 @@ TEST_CASE(dot_float) migraphx::dead_code_elimination{}}); auto qp = create_int8_quantized_prog(); - EXPECT(p == qp); + EXPECT(p.sort() == qp.sort()); optimize_prog_int8(p); auto op = create_int8_optimized_prog(); - EXPECT(p == op); + EXPECT(p.sort() == op.sort()); } TEST_CASE(dot_double_2args) @@ -784,11 +781,7 @@ TEST_CASE(dot_double_2args) migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale_b_lit); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_a_mb, scale_b_mb); - auto zp_out = mm->add_literal(std::int32_t{0}); - zp_out = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), zp_out); - auto r = - mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale, zp_out); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale); mm->add_return({r}); return p; }; @@ -855,7 +848,6 @@ TEST_CASE(dot_half_1arg) auto zp = mm->add_literal(static_cast(0)); auto scale_lit = mm->add_literal(migraphx::literal({sa.type()}, {10.0})); - auto zp_out = mm->add_literal(std::int32_t{0}); auto scale = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_lit); zp = @@ -863,10 +855,7 @@ TEST_CASE(dot_half_1arg) auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp); auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale, scale); - zp_out = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), zp_out); - auto r = - mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale, zp_out); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale); mm->add_return({r}); return p; }; @@ -922,11 +911,7 @@ TEST_CASE(conv_float) migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale_lit); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb); - auto zp_out = mm->add_literal(std::int32_t{0}); - zp_out = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), zp_out); - auto r = - mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale, zp_out); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale); mm->add_return({r}); return p; @@ -1004,11 +989,7 @@ TEST_CASE(conv_half) migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale_lit); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb); - auto zp_out = mm->add_literal(std::int32_t{0}); - zp_out = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), zp_out); - auto r = - mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale, zp_out); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale); mm->add_return({r}); return p; @@ -1256,10 +1237,7 @@ TEST_CASE(int8_subgraph) auto s1_mb = then_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), s1); auto so = then_mod->add_instruction(migraphx::make_op("mul"), s1_mb, s1_mb); - auto zp_out = then_mod->add_literal(std::int32_t{0}); - zp_out = then_mod->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", 
qdot->get_shape().lens()}}), zp_out); - auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so, zp_out); + auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so); then_mod->add_return({r}); migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}}; @@ -1285,13 +1263,8 @@ TEST_CASE(int8_subgraph) auto ssw_mb = else_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}), ssw_lit); - auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb); - auto zp1_out = else_mod->add_literal(std::int32_t{0}); - zp1_out = else_mod->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}), - zp1_out); - auto r1 = - else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1, zp1_out); + auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb); + auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1); else_mod->add_return({r1}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index 61ec6fa0ae9..c3c50cb4172 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -41,7 +41,10 @@ namespace match = migraphx::match; bool is_convolution(const migraphx::instruction& ins) { return ins.name() == "convolution"; } bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot"; } -void run_pass(migraphx::module& m) { run_passes(m, {migraphx::simplify_qdq{}}); } +void run_pass(migraphx::module& m) +{ + run_passes(m, {migraphx::simplify_qdq{}, migraphx::dead_code_elimination{}}); +} void run_cse(migraphx::module& m) { run_passes(m, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}}); @@ -162,7 +165,7 @@ TEST_CASE(qdq_different_scales) auto t2 = m1.add_parameter("t2", sh2); auto scale1 = m1.add_literal(0.5f); auto scale2 = m1.add_literal(0.4f); - auto zero = m1.add_literal(std::int8_t{0}); + auto zero = m1.add_literal(std::int8_t{1}); auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero); auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale2, zero); @@ -210,8 +213,7 @@ TEST_CASE(dot) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); m2.add_return({d3}); } @@ -262,8 +264,7 @@ TEST_CASE(dot_fp16) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); auto d3h = m2.add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), d3); m2.add_return({d3h}); @@ -308,8 +309,7 @@ TEST_CASE(dot_multi_scale) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); auto out_scale = add_scale_mul(m2, scale1, scale2, 0, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", 
dot, out_scale); m2.add_return({d3}); } @@ -353,8 +353,7 @@ TEST_CASE(dot_broadcasted) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); m2.add_return({d3}); } @@ -398,8 +397,7 @@ TEST_CASE(dot_transposed) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_t); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); m2.add_return({d3}); } @@ -441,8 +439,7 @@ TEST_CASE(dot_reshaped) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_t); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); m2.add_return({d3}); } @@ -496,8 +493,7 @@ TEST_CASE(dot_multi_scale_all_skip_post_dq_ops) auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb); auto out_scale = add_scale_mul(m2, scale1, scale2, 2, 3, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); m2.add_return({d3}); } @@ -757,8 +753,7 @@ TEST_CASE(dot_add) auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab); m2.add_return({add}); } @@ -811,13 +806,11 @@ TEST_CASE(dot_add_multiple_dq_use) auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); auto dot_1 = m2.add_instruction(migraphx::make_op("quant_dot"), q1_tmbc, q2); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot_1->get_shape().lens()); - auto out_zp = init_zero_point(m2, dot_1); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot_1, out_scale, out_zp); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot_1, out_scale); auto d3_q = add_quantize_op(m2, "quantizelinear", d3, scale, zero); auto dot_2 = m2.add_instruction(migraphx::make_op("quant_dot"), d3_q, q1); auto out_scale_2 = add_scale_mul(m2, scale, scale, 1, 1, dot_2->get_shape().lens()); - auto out_zp_2 = init_zero_point(m2, dot_2); - auto d4 = add_quantize_op(m2, "dequantizelinear", dot_2, out_scale_2, out_zp_2); + auto d4 = add_quantize_op(m2, "dequantizelinear", dot_2, out_scale_2); auto add = m2.add_instruction(migraphx::make_op("add"), d4, t1); m2.add_return({add}); } @@ -868,8 +861,7 @@ TEST_CASE(conv) q1, weights); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens()); - auto out_zp = init_zero_point(m2, c1); - auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale, out_zp); + auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale); 
m2.add_return({d6}); } @@ -986,8 +978,7 @@ TEST_CASE(conv_multi_scale) q_inp, weights); auto out_scale = add_scale_mul(m2, inp_scale, w_scale, 1, 1, c1->get_shape().lens()); - auto out_zp = init_zero_point(m2, c1); - auto d1 = add_quantize_op(m2, "dequantizelinear", c1, out_scale, out_zp); + auto d1 = add_quantize_op(m2, "dequantizelinear", c1, out_scale); m2.add_return({d1}); } @@ -1027,9 +1018,8 @@ TEST_CASE(conv_multi_scale_unsupported_axis) auto input = m2.add_parameter("input", s7); auto weights = m2.add_parameter("weights", s4); auto scale = m2.add_literal(migraphx::generate_literal(s8, 0)); - auto zero = m2.add_literal(std::int8_t{0}); - auto d1 = add_quantize_op(m2, "dequantizelinear", weights, scale, zero); + auto d1 = add_quantize_op(m2, "dequantizelinear", weights, scale); auto c1 = m2.add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}, {"stride", {1, 1}}, @@ -1085,9 +1075,8 @@ TEST_CASE(conv_bias_add) auto bias = m2.add_parameter("bias", s6); auto scale = m2.add_literal(0.5f); auto zero = m2.add_literal(std::int8_t{0}); - auto zero32 = m2.add_literal(std::int32_t{0}); - auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32); + auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale); auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", {{"padding", {0, 0, 0, 0}}, @@ -1098,8 +1087,7 @@ TEST_CASE(conv_bias_add) q1, weights); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens()); - auto out_zp = init_zero_point(m2, c1); - auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale, out_zp); + auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale); auto b1 = m2.add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1); @@ -1176,10 +1164,9 @@ TEST_CASE(conv_pooling_dot) auto input = m2.add_parameter("input", s7); auto scale = m2.add_literal(0.5f); auto zero = m2.add_literal(std::int8_t{0}); - auto zero32 = m2.add_literal(std::int32_t{0}); - auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32); - auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero); + auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale); + auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale); auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", {{"padding", {0, 0, 0, 0}}, @@ -1190,8 +1177,7 @@ TEST_CASE(conv_pooling_dot) q1, weights); auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens()); - auto out_zp1 = init_zero_point(m2, c1); - auto d5 = add_quantize_op(m2, "dequantizelinear", c1, out_scale1, out_zp1); + auto d5 = add_quantize_op(m2, "dequantizelinear", c1, out_scale1); auto bc1 = m2.add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1); @@ -1208,8 +1194,7 @@ TEST_CASE(conv_pooling_dot) auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero); auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db); auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 0, dot->get_shape().lens()); - auto out_zp2 = init_zero_point(m2, dot); - auto d9 = add_quantize_op(m2, "dequantizelinear", dot, out_scale2, out_zp2); + auto d9 = add_quantize_op(m2, "dequantizelinear", dot, 
out_scale2); auto mb1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1); @@ -1517,22 +1502,21 @@ TEST_CASE(dot_reused) auto w2 = m2.add_parameter("w2", sh); auto scale = m2.add_literal(0.5f); auto zero = m2.add_literal(std::int8_t{0}); - auto zero2 = m2.add_literal(std::int32_t{0}); auto q1 = add_quantize_op(m2, "quantizelinear", x, scale, zero); auto q2 = add_quantize_op(m2, "quantizelinear", w1, scale, zero); auto dot1 = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens()); - auto d1 = add_quantize_op(m2, "dequantizelinear", dot1, out_scale1, zero2); + auto d1 = add_quantize_op(m2, "dequantizelinear", dot1, out_scale1); auto add1 = m2.add_instruction(migraphx::make_op("add"), d1, y); auto q3 = add_quantize_op(m2, "quantizelinear", add1, scale, zero); auto q4 = add_quantize_op(m2, "quantizelinear", w2, scale, zero); auto dot2 = m2.add_instruction(migraphx::make_op("quant_dot"), q3, q4); auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens()); - auto d2 = add_quantize_op(m2, "dequantizelinear", dot2, out_scale2, zero2); - auto d3 = add_quantize_op(m2, "dequantizelinear", q3, q3->inputs()[1], q3->inputs()[2]); + auto d2 = add_quantize_op(m2, "dequantizelinear", dot2, out_scale2); + auto d3 = add_quantize_op(m2, "dequantizelinear", q3, q3->inputs()[1]); auto add2 = m2.add_instruction(migraphx::make_op("add"), d2, d3); m2.add_return({add2}); }