diff --git a/src/include/migraphx/op/dequantizelinear.hpp b/src/include/migraphx/op/dequantizelinear.hpp index 3cd2d89fd96..5f0b7a6e0c1 100644 --- a/src/include/migraphx/op/dequantizelinear.hpp +++ b/src/include/migraphx/op/dequantizelinear.hpp @@ -49,10 +49,13 @@ struct dequantizelinear std::string name() const { return "dequantizelinear"; } shape compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.same_dims().has(2, 3); - if(inputs.size() == 3 and inputs[0].type() != inputs[2].type()) + if(inputs.size() == 3) { - MIGRAPHX_THROW("DEQUANTIZELINEAR: Zero point and input should be the same type."); + check_shapes{inputs, *this}.same_dims().has(2, 3); + if(inputs[0].type() != inputs[2].type()) + { + MIGRAPHX_THROW("DEQUANTIZELINEAR: Zero point and input should be the same type."); + } } return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()}; } diff --git a/test/onnx/verify/dequantizelinear_simple_test.cpp b/test/onnx/verify/dequantizelinear_simple_test.cpp new file mode 100644 index 00000000000..960b48ae0b4 --- /dev/null +++ b/test/onnx/verify/dequantizelinear_simple_test.cpp @@ -0,0 +1,82 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include + +TEST_CASE(dequantizelinear_simple_no_zp_test) +{ + migraphx::program p = read_onnx("dequantizelinear_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::int8_type, {5}}; + std::vector x = {4, 8, 20, 2, 0}; + + migraphx::shape scale_shape{migraphx::shape::float_type, {1}, {1}}; + std::vector scale = {2.0f}; + + migraphx::parameter_map pm; + pm["0"] = migraphx::argument{x_shape, x.data()}; + pm["1"] = migraphx::argument{scale_shape, scale.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {5}}); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {8.0f, 16.0f, 40.0f, 4.0f, 0.0f}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} + +TEST_CASE(dequantizelinear_simple_with_zp_test) +{ + migraphx::program p = read_onnx("dequantizelinear_zero_point_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape x_shape{migraphx::shape::int8_type, {5}}; + std::vector x = {4, 8, 20, 2, 0}; + + migraphx::shape scale_shape{migraphx::shape::float_type, {1}, {1}}; + std::vector scale = {2.0f}; + + migraphx::shape zp_shape{migraphx::shape::int8_type, {1}, {1}}; + std::vector zp = {20}; + + migraphx::parameter_map pm; + pm["0"] = migraphx::argument{x_shape, x.data()}; + pm["1"] = migraphx::argument{scale_shape, scale.data()}; + pm["2"] = migraphx::argument{zp_shape, zp.data()}; + + auto result = p.eval(pm).back(); + EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {5}}); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {-32.0f, -24.0f, 0.0f, -36.0f, -40.0f}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +}