diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp
new file mode 100644
index 00000000000..af9f09790aa
--- /dev/null
+++ b/src/onnx/parse_matmulnbits.cpp
@@ -0,0 +1,192 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include "migraphx/errors.hpp"
+#include "migraphx/instruction_ref.hpp"
+#include "migraphx/onnx/onnx_parser.hpp"
+#include <migraphx/onnx/op_parser.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/ranges.hpp>
+#include <migraphx/stringutils.hpp>
+#include <cmath>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+struct parse_matmulnbits : op_parser<parse_matmulnbits>
+{
+    std::vector<op_desc> operators() const { return {{"MatMulNBits"}}; }
+
+    instruction_ref parse(const op_desc& /*opd*/,
+                          const onnx_parser& parser,
+                          onnx_parser::node_info info,
+                          const std::vector<instruction_ref>& args) const
+    {
+        const size_t n          = parse_attribute(parser, info, "N");
+        const size_t k          = parse_attribute(parser, info, "K");
+        const size_t bits       = parse_attribute(parser, info, "bits");
+        const size_t block_size = parse_attribute(parser, info, "block_size");
+
+        if(bits != 4)
+            MIGRAPHX_THROW("MatMulNBits: bits only supported for value of 4, actual value " +
+                           std::to_string(bits));
+
+        if(block_size < 16 or (block_size & (block_size - 1)) != 0)
+            MIGRAPHX_THROW("MatMulNBits: block_size must be a power of 2 and >=16, actual value " +
+                           std::to_string(block_size));
+
+        const size_t n_blocks_per_col = (k + block_size - 1) / block_size;
+        const size_t blob_size        = std::ceil(block_size * bits / 8.0f);
+
+        std::vector<size_t> expected_b_lens{n, n_blocks_per_col, blob_size};
+        if(args[1]->get_shape().lens() != expected_b_lens)
+            MIGRAPHX_THROW("MatMulNBits: Input B does not match expected dims: " +
+                           to_string_range(expected_b_lens) +
+                           ". Actual dims: " + to_string_range(args[1]->get_shape().lens()));
+
+        std::vector<size_t> expected_scales_lens{n * n_blocks_per_col};
+        if(args[2]->get_shape().lens() != expected_scales_lens)
+            MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " +
+                           to_string_range(expected_scales_lens) +
+                           ". Actual dims: " + to_string_range(args[2]->get_shape().lens()));
+
+        if(args.size() > 3)
+        {
+            std::vector<size_t> expected_zp_lens{
+                static_cast<size_t>(n * std::ceil(n_blocks_per_col * bits / 8.0f))};
+            if(args[3]->get_shape().lens() != expected_zp_lens)
+                MIGRAPHX_THROW("MatMulNBits: Input zero_points does not match expected dims: " +
+                               to_string_range(expected_zp_lens) +
+                               ". Actual dims: " + to_string_range(args[3]->get_shape().lens()));
+        }
+
+        auto b = dequantize_b(info, n, k, block_size, args);
+        b      = info.add_instruction(make_op("transpose", {{"permutation", {1, 0}}}), b);
+        return matmul(info, args[0], b);
+    }
+
+    private:
+    int parse_attribute(const onnx_parser& parser,
+                        onnx_parser::node_info& info,
+                        const std::string& attribute_name) const
+    {
+        if(not contains(info.attributes, attribute_name))
+            MIGRAPHX_THROW("MatMulNBits: Attribute " + attribute_name +
+                           " required, but is missing");
+
+        return parser.parse_value(info.attributes[attribute_name]).at<int>();
+    }
+
+    instruction_ref dequantize_b(onnx_parser::node_info& info,
+                                 int n,
+                                 int k,
+                                 int block_size,
+                                 const std::vector<instruction_ref>& args) const
+    {
+        auto b = unpack(info, n, k, args[1]);
+
+        auto n_blocks_per_col = (k + block_size - 1) / block_size;
+        auto scales = info.add_instruction(make_op("reshape", {{"dims", {n, -1}}}), args[2]);
+        scales      = prepare_blockwise_dq_arg(info, n, k, block_size, scales);
+
+        instruction_ref zp;
+        if(args.size() == 4)
+        {
+            zp = unpack(info, n, n_blocks_per_col, args[3]);
+            zp = prepare_blockwise_dq_arg(info, n, k, block_size, zp);
+        }
+        else
+        {
+            zp = info.add_literal(literal{shape{shape::uint8_type, {1}}, {8}});
+            zp = info.add_instruction(
+                make_op("multibroadcast", {{"out_lens", b->get_shape().lens()}}), zp);
+        }
+        return info.add_instruction(make_op("dequantizelinear"), {b, scales, zp});
+    }
+
+    instruction_ref unpack(onnx_parser::node_info& info, int n, int dim1, instruction_ref x) const
+    {
+        x = info.add_instruction(make_op("reshape", {{"dims", {n, -1}}}), x);
+        x = info.add_instruction(make_op("unpack_int4"), x);
+        if(x->get_shape().lens()[1] > dim1)
+        {
+            x = info.add_instruction(
+                make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {dim1}}}), x);
+        }
+        return x;
+    }
+
+    instruction_ref prepare_blockwise_dq_arg(
+        onnx_parser::node_info& info, int n, int k, int block_size, instruction_ref x) const
+    {
+        x = info.add_instruction(make_op("unsqueeze", {{"axes", {2}}}), x);
+
+        auto bc_lens = x->get_shape().lens();
+        bc_lens[2]   = block_size;
+        x = info.add_instruction(make_op("multibroadcast", {{"out_lens", bc_lens}}), x);
+        x = info.add_instruction(make_op("reshape", {{"dims", {n, -1}}}), x);
+
+        // Detect runt block
+        if(x->get_shape().lens()[1] > k)
+        {
+            x = info.add_instruction(
+                make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {k}}}), x);
+        }
+
+        return x;
+    }
+
+    instruction_ref matmul(onnx_parser::node_info& info, instruction_ref a, instruction_ref b) const
+    {
+        const auto a_rank = a->get_shape().ndim();
+        // B is always rank 2:
+        // If A is rank 1, unsqueeze A to make it rank 2 to prepare for dot
+        // If A is rank 2, just a regular dot
+        // If A is rank > 2, broadcast B to match outer dims of A to prepare for dot
+        if(a_rank == 1)
+        {
+            a = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), a);
+        }
+        else if(a_rank > 2)
+        {
+            auto b_lens    = b->get_shape().lens();
+            auto b_bc_lens = a->get_shape().lens();
+            std::copy(b_lens.begin(), b_lens.end(), b_bc_lens.end() - 2);
+            b = info.add_instruction(make_op("multibroadcast", {{"out_lens", b_bc_lens}}), b);
+        }
+
+        auto dot = info.add_instruction(make_op("dot"), a, b);
+
+        if(a_rank == 1)
+            dot = info.add_instruction(
+                make_op("squeeze", {{"axes", {dot->get_shape().ndim() - 2}}}), dot);
+
+        return dot;
+    }
+};
+
+} // namespace onnx
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp
index d58465185d9..6a3a403799c 100644
--- a/src/targets/gpu/lowering.cpp
+++ b/src/targets/gpu/lowering.cpp
@@ -123,6 +123,7 @@ struct miopen_apply
         add_select_module_op();
         add_reshape_lazy_op();
         add_scan_slice_op();
+        add_unpack_int4_op();
     }
 
     void copy_params() const
@@ -527,6 +528,26 @@ struct miopen_apply
                 ins, mod->insert_instruction(ins, ins->get_operator(), inputs));
         });
     }
+
+    void add_unpack_int4_op()
+    {
+        apply_map.emplace("unpack_int4", [=](instruction_ref ins) {
+            auto inputs = ins->inputs();
+            auto output = insert_allocation(ins, ins->get_shape());
+            std::vector<instruction_ref> cpu_inputs;
+            auto gpu_inputs = ins->inputs();
+            std::transform(
+                gpu_inputs.begin(), gpu_inputs.end(), std::back_inserter(cpu_inputs), [&](auto in) {
+                    return mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), in);
+                });
+            cpu_inputs.front() =
+                mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_inputs);
+            auto cpu_out = mod->insert_instruction(ins, ins->get_operator(), cpu_inputs);
+            auto gpu_out =
+                mod->insert_instruction(ins, make_op("hip::copy_to_gpu"), cpu_out, output);
+            return mod->replace_instruction(ins, gpu_out);
+        });
+    }
 };
 
 void lowering::apply(module_pass_manager& mpm) const
diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py
index 66ab038cfc7..60113e31ba2 100644
--- a/test/onnx/gen_onnx.py
+++ b/test/onnx/gen_onnx.py
@@ -9020,6 +9020,150 @@ def qlinearmatmul_3D_test():
                     [sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])
 
 
+@onnx_test()
+def matmulnbits_mm_test():
+    a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT,
+                                           [2, 16])
+    b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8,
+                                           [4, 1, 8])
+    scales = onnx.helper.make_tensor_value_info("scales",
+                                                onnx.TensorProto.FLOAT, [4])
+    zp = onnx.helper.make_tensor_value_info("zp", onnx.TensorProto.UINT8, [4])
+    c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, [2, 4])
+
+    node = onnx.helper.make_node("MatMulNBits",
+                                 inputs=["a", "b", "scales", "zp"],
+                                 outputs=["c"],
+                                 bits=4,
+                                 block_size=16,
+                                 K=16,
+                                 N=4,
+                                 domain='com.microsoft')
+    return ([node], [a, b, scales, zp], [c])
+
+
+@onnx_test()
+def matmulnbits_mm2_test():
+    a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT,
+                                           [2, 33])
+    b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8,
+                                           [2, 3, 8])
+    scales = onnx.helper.make_tensor_value_info("scales",
+                                                onnx.TensorProto.FLOAT, [6])
+    c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, [2, 2])
+
+    node = onnx.helper.make_node("MatMulNBits",
+                                 inputs=["a", "b", "scales"],
+                                 outputs=["c"],
+                                 bits=4,
+                                 block_size=16,
+                                 K=33,
+                                 N=2,
+                                 domain='com.microsoft')
+    return ([node], [a, b, scales], [c])
+
+
+@onnx_test()
+def matmulnbits_vm_test():
+    a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT, [20])
+    b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8,
+                                           [3, 2, 8])
+    scales = onnx.helper.make_tensor_value_info("scales",
+                                                onnx.TensorProto.FLOAT, [6])
+    zp = onnx.helper.make_tensor_value_info("zp", onnx.TensorProto.UINT8, [3])
+    c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, [3])
+
+    node = onnx.helper.make_node("MatMulNBits",
+                                 inputs=["a", "b", "scales", "zp"],
+                                 outputs=["c"],
+                                 bits=4,
+                                 block_size=16,
+                                 K=20,
+                                 N=3,
+                                 domain='com.microsoft')
+    return ([node], [a, b, scales, zp], [c])
+
+
+@onnx_test()
+def matmulnbits_bmm_test():
+    a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT,
+                                           [2, 3, 8])
+    b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8,
+                                           [2, 1,
8]) + scales = onnx.helper.make_tensor_value_info("scales", + onnx.TensorProto.FLOAT, [2]) + c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, + [2, 3, 2]) + + node = onnx.helper.make_node("MatMulNBits", + inputs=["a", "b", "scales"], + outputs=["c"], + bits=4, + block_size=16, + K=8, + N=2, + domain='com.microsoft') + return ([node], [a, b, scales], [c]) + + +def matmulnbits_negative_test(bits=4, + block_size=16, + a_dims=[2, 16], + b_dims=[4, 1, 8], + scales_dims=[4], + zp_dims=[4], + out_dims=[2, 4]): + a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT, a_dims) + b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8, b_dims) + scales = onnx.helper.make_tensor_value_info("scales", + onnx.TensorProto.FLOAT, + scales_dims) + zp = onnx.helper.make_tensor_value_info("zp", onnx.TensorProto.UINT8, + zp_dims) + c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, + out_dims) + + node = onnx.helper.make_node("MatMulNBits", + inputs=["a", "b", "scales", "zp"], + outputs=["c"], + bits=bits, + block_size=block_size, + K=16, + N=4, + domain='com.microsoft') + return ([node], [a, b, scales, zp], [c]) + + +@onnx_test() +def matmulnbits_invalid_bits_value_test(): + return matmulnbits_negative_test(bits=5) + + +@onnx_test() +def matmulnbits_block_size_too_small_test(): + return matmulnbits_negative_test(block_size=8) + + +@onnx_test() +def matmulnbits_block_size_not_power_of_two_test(): + return matmulnbits_negative_test(block_size=20) + + +@onnx_test() +def matmulnbits_invalid_b_dims_test(): + return matmulnbits_negative_test(b_dims=[4, 2, 8]) + + +@onnx_test() +def matmulnbits_invalid_scales_dims_test(): + return matmulnbits_negative_test(scales_dims=[3]) + + +@onnx_test() +def matmulnbits_invalid_zp_dims_test(): + return matmulnbits_negative_test(zp_dims=[5]) + + @onnx_test() def qlinearmul_test(): a = helper.make_tensor_value_info('A', TensorProto.UINT8, [64]) diff --git a/test/onnx/matmulnbits_block_size_not_power_of_two_test.onnx b/test/onnx/matmulnbits_block_size_not_power_of_two_test.onnx new file mode 100644 index 00000000000..52dfbe82146 --- /dev/null +++ b/test/onnx/matmulnbits_block_size_not_power_of_two_test.onnx @@ -0,0 +1,32 @@ + ,matmulnbits_block_size_not_power_of_two_test:€ +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoft,matmulnbits_block_size_not_power_of_two_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_block_size_too_small_test.onnx b/test/onnx/matmulnbits_block_size_too_small_test.onnx new file mode 100644 index 00000000000..e4fc07fb3de --- /dev/null +++ b/test/onnx/matmulnbits_block_size_too_small_test.onnx @@ -0,0 +1,32 @@ + %matmulnbits_block_size_too_small_test:ù +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoft%matmulnbits_block_size_too_small_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_bmm_test.onnx b/test/onnx/matmulnbits_bmm_test.onnx new file mode 100644 index 00000000000..3e792bf7d5d --- /dev/null +++ b/test/onnx/matmulnbits_bmm_test.onnx @@ -0,0 +1,29 @@ + matmulnbits_bmm_test:Ú +a +a +b +scalesc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoftmatmulnbits_bmm_testZ +a + + + +Z +b + + + +Z +scales + + +b +c + + + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_invalid_b_dims_test.onnx b/test/onnx/matmulnbits_invalid_b_dims_test.onnx new file mode 
100644 index 00000000000..8e33dafc2e9 --- /dev/null +++ b/test/onnx/matmulnbits_invalid_b_dims_test.onnx @@ -0,0 +1,32 @@ + matmulnbits_invalid_b_dims_test:ó +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoftmatmulnbits_invalid_b_dims_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_invalid_bits_value_test.onnx b/test/onnx/matmulnbits_invalid_bits_value_test.onnx new file mode 100644 index 00000000000..e084c8002a5 --- /dev/null +++ b/test/onnx/matmulnbits_invalid_bits_value_test.onnx @@ -0,0 +1,32 @@ + #matmulnbits_invalid_bits_value_test:÷ +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoft#matmulnbits_invalid_bits_value_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_invalid_scales_dims_test.onnx b/test/onnx/matmulnbits_invalid_scales_dims_test.onnx new file mode 100644 index 00000000000..675e3cfc26f --- /dev/null +++ b/test/onnx/matmulnbits_invalid_scales_dims_test.onnx @@ -0,0 +1,32 @@ + $matmulnbits_invalid_scales_dims_test:ø +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoft$matmulnbits_invalid_scales_dims_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_invalid_zp_dims_test.onnx b/test/onnx/matmulnbits_invalid_zp_dims_test.onnx new file mode 100644 index 00000000000..63242a15592 --- /dev/null +++ b/test/onnx/matmulnbits_invalid_zp_dims_test.onnx @@ -0,0 +1,32 @@ +  matmulnbits_invalid_zp_dims_test:ô +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoft matmulnbits_invalid_zp_dims_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_mm2_test.onnx b/test/onnx/matmulnbits_mm2_test.onnx new file mode 100644 index 00000000000..2c2283f08c5 --- /dev/null +++ b/test/onnx/matmulnbits_mm2_test.onnx @@ -0,0 +1,27 @@ + matmulnbits_mm2_test:Ò +a +a +b +scalesc" MatMulNBits* +K! * +N * +bits * + +block_size : com.microsoftmatmulnbits_mm2_testZ +a +  + +!Z +b + + + +Z +scales + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_mm_test.onnx b/test/onnx/matmulnbits_mm_test.onnx new file mode 100644 index 00000000000..14efed868cd --- /dev/null +++ b/test/onnx/matmulnbits_mm_test.onnx @@ -0,0 +1,32 @@ + matmulnbits_mm_test:ç +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoftmatmulnbits_mm_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_vm_test.onnx b/test/onnx/matmulnbits_vm_test.onnx new file mode 100644 index 00000000000..e526131b288 --- /dev/null +++ b/test/onnx/matmulnbits_vm_test.onnx @@ -0,0 +1,32 @@ + matmulnbits_vm_test:ß +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoftmatmulnbits_vm_testZ +a + + +Z +b + + + +Z +scales + + +Z +zp + + +b +c + + +B \ No newline at end of file diff --git a/test/onnx/parse/matmulnbits_negative_tests.cpp b/test/onnx/parse/matmulnbits_negative_tests.cpp new file mode 100644 index 00000000000..e5e8ad44336 --- /dev/null +++ b/test/onnx/parse/matmulnbits_negative_tests.cpp @@ -0,0 +1,55 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(matmulnbits_invalid_bits_value_test) +{ + EXPECT(test::throws([&] { read_onnx("matmulnbits_invalid_bits_value_test.onnx"); })); +} + +TEST_CASE(matmulnbits_block_size_too_small_test) +{ + EXPECT(test::throws([&] { read_onnx("matmulnbits_block_size_too_small_test.onnx"); })); +} + +TEST_CASE(matmulnbits_block_size_not_power_of_two_test) +{ + EXPECT(test::throws([&] { read_onnx("matmulnbits_block_size_not_power_of_two_test.onnx"); })); +} + +TEST_CASE(matmulnbits_invalid_b_dims_test) +{ + EXPECT(test::throws([&] { read_onnx("matmulnbits_invalid_b_dims_test.onnx"); })); +} + +TEST_CASE(matmulnbits_invalid_scales_dims_test) +{ + EXPECT(test::throws([&] { read_onnx("matmulnbits_invalid_scales_dims_test.onnx"); })); +} + +TEST_CASE(matmulnbits_invalid_zp_dims_test) +{ + EXPECT(test::throws([&] { read_onnx("matmulnbits_invalid_zp_dims_test.onnx"); })); +} diff --git a/test/onnx/parse/matmulnbits_tests.cpp b/test/onnx/parse/matmulnbits_tests.cpp new file mode 100644 index 00000000000..f4bed4f5868 --- /dev/null +++ b/test/onnx/parse/matmulnbits_tests.cpp @@ -0,0 +1,176 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "migraphx/make_op.hpp" +#include + +TEST_CASE(matmulnbits_mm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {2, 16}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::uint8_type, {4, 1, 8}}); + auto scales = mm->add_parameter("scales", migraphx::shape{migraphx::shape::float_type, {4}}); + auto zp = mm->add_parameter("zp", migraphx::shape{migraphx::shape::uint8_type, {4}}); + + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), scales); + scales = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), scales); + scales = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 1, 16}}}), + scales); + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), scales); + + zp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), zp); + zp = mm->add_instruction(migraphx::make_op("unpack_int4"), zp); + zp = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), zp); + zp = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), zp); + zp = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 1, 16}}}), zp); + zp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), zp); + + b = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), b); + b = mm->add_instruction(migraphx::make_op("unpack_int4"), b); + b = mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scales, zp); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); + mm->add_instruction(migraphx::make_op("dot"), a, b); + + auto prog = optimize_onnx("matmulnbits_mm_test.onnx"); + + p.sort(); + prog.sort(); + EXPECT(p == prog); +} + +TEST_CASE(matmulnbits_mm2_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {2, 33}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::uint8_type, {2, 3, 8}}); + auto scales = mm->add_parameter("scales", migraphx::shape{migraphx::shape::float_type, {6}}); + + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), scales); + scales = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), scales); + scales = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 16}}}), + scales); + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), scales); + scales = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {33}}}), scales); + + auto zp = + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::uint8_type, {1}}, {8}}); + zp = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 33}}}), zp); + + b = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), b); + b = mm->add_instruction(migraphx::make_op("unpack_int4"), b); + b = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {33}}}), b); + b = mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scales, zp); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); + mm->add_instruction(migraphx::make_op("dot"), a, b); + + auto prog = optimize_onnx("matmulnbits_mm2_test.onnx"); + + p.sort(); + prog.sort(); + 
EXPECT(p == prog); +} + +TEST_CASE(matmulnbits_vm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {20}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::uint8_type, {3, 2, 8}}); + auto scales = mm->add_parameter("scales", migraphx::shape{migraphx::shape::float_type, {6}}); + auto zp = mm->add_parameter("zp", migraphx::shape{migraphx::shape::uint8_type, {3}}); + + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1}}}), scales); + scales = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), scales); + scales = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 16}}}), + scales); + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1}}}), scales); + scales = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {20}}}), scales); + + zp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1}}}), zp); + zp = mm->add_instruction(migraphx::make_op("unpack_int4"), zp); + zp = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), zp); + zp = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 16}}}), zp); + zp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1}}}), zp); + zp = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {20}}}), zp); + + b = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1}}}), b); + b = mm->add_instruction(migraphx::make_op("unpack_int4"), b); + b = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {20}}}), b); + b = mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scales, zp); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); + + a = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), a); + auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), dot); + + auto prog = optimize_onnx("matmulnbits_vm_test.onnx"); + + p.sort(); + prog.sort(); + EXPECT(p == prog); +} + +TEST_CASE(matmulnbits_bmm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {2, 3, 8}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::uint8_type, {2, 1, 8}}); + auto scales = mm->add_parameter("scales", migraphx::shape{migraphx::shape::float_type, {2}}); + + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), scales); + scales = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), scales); + scales = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 16}}}), + scales); + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), scales); + scales = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {8}}}), scales); + + auto zp = + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::uint8_type, {1}}, {8}}); + zp = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 8}}}), zp); + + b = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), b); + b = mm->add_instruction(migraphx::make_op("unpack_int4"), b); + b = mm->add_instruction( + migraphx::make_op("slice", 
{{"axes", {1}}, {"starts", {0}}, {"ends", {8}}}), b); + b = mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scales, zp); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); + b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 2}}}), b); + mm->add_instruction(migraphx::make_op("dot"), a, b); + + auto prog = optimize_onnx("matmulnbits_bmm_test.onnx"); + + p.sort(); + prog.sort(); + EXPECT(p == prog); +} diff --git a/test/onnx/verify/matmulnbits_tests.cpp b/test/onnx/verify/matmulnbits_tests.cpp new file mode 100644 index 00000000000..4fa4c22768c --- /dev/null +++ b/test/onnx/verify/matmulnbits_tests.cpp @@ -0,0 +1,187 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */
+
+#include "migraphx/argument.hpp"
+#include "migraphx/module.hpp"
+#include <migraphx/register_target.hpp>
+#include <migraphx/verify.hpp>
+#include <migraphx/onnx.hpp>
+#include <onnx_test.hpp>
+#include <numeric>
+
+TEST_CASE(matmulnbits_mm_test)
+{
+    auto p = optimize_onnx("matmulnbits_mm_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::parameter_map pm;
+    auto a_shape = migraphx::shape{migraphx::shape::float_type, {2, 16}};
+    std::vector<float> a(a_shape.elements());
+    std::iota(a.begin(), a.end(), 0);
+    pm["a"] = migraphx::argument(a_shape, a.data());
+
+    auto b_shape = migraphx::shape{migraphx::shape::uint8_type, {4, 1, 8}};
+    std::vector<uint8_t> b{0x2,  0xe3, 0xc7, 0x89, 0xbd, 0xbe, 0x50, 0x41, 0xe9, 0xb4, 0xd4,
+                           0x54, 0xc6, 0xb2, 0xfa, 0x27, 0x14, 0x3d, 0xbb, 0xe7, 0xa5, 0x0,
+                           0x52, 0x28, 0xc1, 0xd9, 0x1f, 0x33, 0x16, 0x1e, 0x8b, 0x3c};
+    pm["b"] = migraphx::argument(b_shape, b.data());
+
+    auto scales_shape = migraphx::shape{migraphx::shape::float_type, {4}};
+    std::vector<float> scales{1, 2, 3, 4};
+    pm["scales"] = migraphx::argument(scales_shape, scales.data());
+
+    auto zp_shape = migraphx::shape{migraphx::shape::uint8_type, {4}};
+    std::vector<uint8_t> zp{0x08, 0x09, 0x0a, 0x0b};
+    pm["zp"] = migraphx::argument{zp_shape, zp.data()};
+
+    auto result = p.eval(pm).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{
+        -111.0f, -290.0f, -1692.0f, -1960.0f, -335.0f, -770.0f, -4764.0f, -5992.0f};
+
+    EXPECT(result.get_shape().lens() == std::vector<std::size_t>{2, 4});
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(matmulnbits_mm2_test)
+{
+    auto p = optimize_onnx("matmulnbits_mm2_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::parameter_map pm;
+    auto a_shape = migraphx::shape{migraphx::shape::float_type, {2, 33}};
+    std::vector<float> a{
+        0.15541f, 0.24434f, 0.66716f, 0.13632f, 0.76915f, 0.21328f, 0.17331f, 0.93251f, 0.14816f,
+        0.08181f, 0.54035f, 0.86664f, 0.92605f, 0.89766f, 0.02441f, 0.33504f, 0.60488f, 0.25918f,
+        0.64644f, 0.98881f, 0.27669f, 0.94888f, 0.21201f, 0.33377f, 0.95608f, 0.40923f, 0.66899f,
+        0.58904f, 0.41560f, 0.87399f, 0.74596f, 0.10849f, 0.94527f, 0.88573f, 0.66875f, 0.57536f,
+        0.81454f, 0.15699f, 0.15464f, 0.17399f, 0.08090f, 0.99368f, 0.45535f, 0.92528f, 0.91968f,
+        0.76970f, 0.59638f, 0.23635f, 0.54877f, 0.96025f, 0.48969f, 0.55297f, 0.52498f, 0.29102f,
+        0.01359f, 0.77372f, 0.81897f, 0.03003f, 0.00822f, 0.55477f, 0.54635f, 0.91918f, 0.76486f,
+        0.73698f, 0.29821f, 0.41801f};
+    pm["a"] = migraphx::argument(a_shape, a.data());
+
+    auto b_shape = migraphx::shape{migraphx::shape::uint8_type, {2, 3, 8}};
+    std::vector<uint8_t> b{0x18, 0x9,  0x8b, 0xe1, 0xfb, 0x94, 0x11, 0x56, 0x4e, 0xac, 0xd3, 0x4b,
+                           0xf7, 0x8e, 0x54, 0xef, 0x0b, 0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0x0,
+                           0x6e, 0xb7, 0x20, 0x4f, 0xa7, 0x82, 0x83, 0xbf, 0x20, 0xde, 0xa4, 0xf,
+                           0x72, 0x81, 0x8,  0x83, 0x0a, 0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0x0};
+    pm["b"] = migraphx::argument(b_shape, b.data());
+
+    auto scales_shape = migraphx::shape{migraphx::shape::float_type, {6}};
+    std::vector<float> scales{0.29033, 0.80435, 2.60200, 2.39623, 1.40796, 2.38139};
+    pm["scales"] = migraphx::argument(scales_shape, scales.data());
+
+    auto result = p.eval(pm).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{18.54672f, -62.38305f, 4.978874f, -31.228657f};
+
+    EXPECT(result.get_shape().lens() == std::vector<std::size_t>{2, 2});
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(matmulnbits_vm_test)
+{
+    auto p = optimize_onnx("matmulnbits_vm_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::parameter_map pm;
+    auto a_shape = migraphx::shape{migraphx::shape::float_type, {20}};
+    std::vector<float> a{0.10266f, 0.12772f, 0.10865f, 0.66181f, 0.49644f, 0.30307f, 0.11225f,
+                         0.65619f, 0.06290f, 0.29208f, 0.63246f, 0.22758f, 0.99302f, 0.09735f,
+                         0.68126f, 0.93334f, 0.90533f, 0.31082f, 0.58161f, 0.61385f};
+    pm["a"] = migraphx::argument(a_shape, a.data());
+
+    auto b_shape = migraphx::shape{migraphx::shape::uint8_type, {3, 2, 8}};
+    std::vector<uint8_t> b{0xb7, 0x55, 0xfc, 0xc3, 0x66, 0xf9, 0x97, 0x83, 0xdd, 0x79, 0x0,  0x0,
+                           0x0,  0x0,  0x0,  0x0,  0xcb, 0x52, 0xaf, 0x1d, 0x85, 0xbb, 0x64, 0x60,
+                           0x23, 0x42, 0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0x38, 0xc6, 0xf7, 0x7a,
+                           0x68, 0xb1, 0x5,  0xc3, 0x37, 0xbb, 0x0,  0x0,  0x0,  0x0,  0x0,  0x0};
+    pm["b"] = migraphx::argument(b_shape, b.data());
+
+    auto scales_shape = migraphx::shape{migraphx::shape::float_type, {6}};
+    std::vector<float> scales{3.74611f, 0.29444f, 0.29047f, 0.55739f, 3.94635f, 2.86177f};
+    pm["scales"] = migraphx::argument(scales_shape, scales.data());
+
+    auto zp_shape = migraphx::shape{migraphx::shape::uint8_type, {3}};
+    std::vector<uint8_t> zp{0x43, 0x28, 0x65};
+    pm["zp"] = migraphx::argument{zp_shape, zp.data()};
+
+    auto result = p.eval(pm).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{131.22989f, -1.9659958f, 75.00621f};
+
+    EXPECT(result.get_shape().lens() == std::vector<std::size_t>{3});
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(matmulnbits_bmm_test)
+{
+    auto p = optimize_onnx("matmulnbits_bmm_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::parameter_map pm;
+    auto a_shape = migraphx::shape{migraphx::shape::float_type, {2, 3, 8}};
+    std::vector<float> a{0.01602f, 0.41420f, 0.97385f, 0.31764f, 0.40434f, 0.46265f, 0.93490f,
+                         0.16076f, 0.62340f, 0.39614f, 0.45347f, 0.98619f, 0.65113f, 0.56039f,
+                         0.33137f, 0.51959f, 0.70136f, 0.73935f, 0.95997f, 0.25623f, 0.26716f,
+                         0.27764f, 0.52128f, 0.55242f, 0.31295f, 0.54679f, 0.43674f, 0.21178f,
+                         0.99311f, 0.86172f, 0.10848f, 0.34330f, 0.36977f, 0.00948f, 0.93841f,
+                         0.88137f, 0.31069f, 0.39034f, 0.22825f, 0.29626f, 0.22664f, 0.51612f,
+                         0.39870f, 0.73411f, 0.07540f, 0.36283f, 0.62662f, 0.49075f};
+    pm["a"] = migraphx::argument(a_shape, a.data());
+
+    auto b_shape = migraphx::shape{migraphx::shape::uint8_type, {2, 1, 8}};
+    std::vector<uint8_t> b{
+        0xed, 0xf8, 0xa0, 0xac, 0x0, 0x0, 0x0, 0x0, 0x34, 0xf7, 0x42, 0x1f, 0x0, 0x0, 0x0, 0x0};
+    pm["b"] = migraphx::argument(b_shape, b.data());
+
+    auto scales_shape = migraphx::shape{migraphx::shape::float_type, {2}};
+    std::vector<float> scales{1.43507, 1.28074};
+    pm["scales"] = migraphx::argument(scales_shape, scales.data());
+
+    auto result = p.eval(pm).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{
+        9.386047f,
+        0.32900935f,
+        15.317321f,
+        -7.0316725f,
+        16.28011f,
+        -11.014428f,
+        1.7608745f,
+        -17.91667f,
+        11.302611f,
+        -0.2521392f,
+        18.625961f,
+        0.38458022f,
+    };
+
+    EXPECT(result.get_shape().lens() == std::vector<std::size_t>{2, 3, 2});
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
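
Reviewer note (not part of the diff): the parser expresses MatMulNBits as reshape/unpack_int4/multibroadcast/slice feeding dequantizelinear and dot. Below is a minimal NumPy sketch of that blockwise dequantization for reasoning about the verify-test gold values, assuming the onnxruntime MatMulNBits layout (B packed as [N, n_blocks_per_col, blob_size] uint8 with the first 4-bit element in the low nibble, scales of length N * n_blocks_per_col, zero point defaulting to 8 when absent). The helper name matmulnbits_ref is illustrative only; onnxruntime remains the reference implementation.

import numpy as np


def matmulnbits_ref(a, b_packed, scales, zero_points=None, *, k, n, block_size):
    # Assumed low-nibble-first 4-bit packing (the convention unpack_int4 appears to use).
    b_packed = np.asarray(b_packed, dtype=np.uint8)
    n_blocks_per_col = (k + block_size - 1) // block_size
    # Two 4-bit weights per byte -> [N, n_blocks_per_col * block_size], trimmed to K (runt block).
    q = np.stack([b_packed & 0x0F, b_packed >> 4], axis=-1).reshape(n, -1)[:, :k]
    scales = np.asarray(scales, dtype=np.float32).reshape(n, n_blocks_per_col)
    if zero_points is None:
        # Default zero point of 8, matching the literal added by the parser.
        zp = np.full((n, n_blocks_per_col), 8, dtype=np.uint8)
    else:
        zero_points = np.asarray(zero_points, dtype=np.uint8)
        zp = np.stack([zero_points & 0x0F, zero_points >> 4],
                      axis=-1).reshape(n, -1)[:, :n_blocks_per_col]
    # Broadcast each block's scale/zero point over its block_size elements, then trim to K.
    scales = np.repeat(scales, block_size, axis=1)[:, :k]
    zp = np.repeat(zp, block_size, axis=1)[:, :k]
    dq_b = (q.astype(np.float32) - zp.astype(np.float32)) * scales  # dequantizelinear, [N, K]
    # Matches the transpose + dot emitted by the parser; broadcasts over batch dims of a.
    return np.asarray(a, dtype=np.float32) @ dq_b.T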