diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h index 6b1c883f96041..d49484a072be1 100644 --- a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h +++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h @@ -69,8 +69,7 @@ template < typename ScaleElementT, typename Layout, typename QuantBlocking> -inline -void sm80_prepack_quant_scales_ref( +inline void sm80_prepack_quant_scales_ref( int rows, int columns, const MatrixRef& tensor_scale, @@ -130,6 +129,77 @@ void sm80_prepack_quant_scales_ref( } } +template +inline void sm80_expand_prepack_quant_offsets_ref( + int rows, + int columns, + MatrixRef tensor_offset, + MatrixRef tensor_offset_prepacked) { + const auto meta_shape = make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + ORT_ENFORCE(tensor_offset_prepacked.shape() == meta_shape, + "Unexpected tensor_offset_prepacked shape (", + tensor_offset_prepacked.shape()[0], ",", tensor_offset_prepacked.shape()[1], + ")! Expected: (", meta_shape[0], ", ", meta_shape[1], ")"); + ORT_ENFORCE(tensor_offset.shape() == zp_shape, + "Unexpected tensor_offset shape (", + tensor_offset.shape()[0], ",", tensor_offset.shape()[1], + ")! Expected: (", zp_shape[0], ", ", zp_shape[1], ")"); + + // Only prepack scale and offset tensors for an often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (QuantBlocking::kRow != 1) { + return; + } + // In Ampere tensor op, each operand B tile is 8 x 8; in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for an mma instruction, due to the tile fragment layout shown + // above.
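A note on the zp_shape check above: the incoming zero points are stored two per byte along the block-row dimension, low nibble first, which is why the packed row count is (meta_shape[0] + 1) / 2. A minimal standalone sketch of that convention, with helper names that are illustrative and not part of this patch:

#include <cstdint>
#include <utility>

// Illustrative only: two 4-bit zero points share one byte along the block-row axis.
// Byte b holds the zero point of block row 2*b in the low nibble and of block row
// 2*b + 1 in the high nibble; an odd row count leaves the last high nibble unused.
inline int packed_zp_rows(int meta_rows) {
  return (meta_rows + 1) / 2;
}

inline std::pair<uint8_t, uint8_t> unpack_zp_pair(uint8_t packed) {
  return {static_cast<uint8_t>(packed & 0xf),  // block row 2*b
          static_cast<uint8_t>(packed >> 4)};  // block row 2*b + 1
}

This matches the pair01 & 0xf / pair01 >> 4 unpacking performed in the loop below.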
To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + if (tensor_offset_prepacked.good()) { + for (int col = 0; col < tensor_offset_prepacked.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_offset_prepacked.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + uint8_t pair01 = tensor_offset.at(src_idx / 2, col); + uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col); + tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf; + tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf; + tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4; + tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4; + } + } + } + } +} + template inline void sm80_prepack_quant_offsets_ref( diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h index 4db2a6340ed75..4cfb074e7df7d 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h @@ -14,8 +14,11 @@ #pragma once +#include + #include "core/util/matrix_layout.h" #include "core/common/common.h" +#include "core/mickey/blk_q4/f16_prepack_sm80.h" #include "test/cuda_host/blkq4_fp16_quant_sm80.h" namespace onnxruntime { @@ -24,6 +27,157 @@ namespace test { Status sm80_supported(); +/** + * @brief Generate a set of quantized weights, scales and offsets + * and dequantized weights for testing quantization and + * dequantization. All outputs are column major layout. + * + * @tparam ElementT The type of the dequantized weights. + * @tparam block_size The block size of the quantization. + * @tparam col_blocking Whether to use column blocking (all elements of + * a block comes from a single column) or row blocking + * @tparam has_offsets Whether to generate offsets. + * + * @param[in] rows The number of rows of the weight matrix. + * @param[in] columns The number of columns of the weight matrix. + * @param[out] dequants The dequantized weights, column major layout. + * @param[out] q_weights The quantized weights, column major layout. + * @param[out] q_scales The scales, column major layout. + * @param[out] q_zp The zero points, column major layout. 
+ */ +template +inline +void blkq4_weights_gen( + int rows, int columns, + std::vector& dequants, + std::vector& q_weights, + std::vector& q_scales, + std::vector& q_zp) { + using Base = onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + col_blocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + + static_assert(std::is_same::value); + static_assert(std::is_same::value); + static_assert(std::is_same::value); + + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution dis(0, 8192); + + const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); + const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + + // + // For testing quantization and dequantization, it is not + // straightforward to avoid flaky tests due to rounding errors. The way we + // try to achieve this is to: + // 1. Generate a set of quantized weights, scales and offsets + // 2. Dequantize the weights + // 3. Quantize the dequantized weights + // 4. Compare the dequantized-and-then-quantized weights with + // the original quantized weights + // + // Random filling of the initial values is key to getting this right. + // For weights, we must ensure each block gets a full range of + // values, i.e. must contain 0 and 15. And for scales, they must + // all be positive. + // + + q_weights.resize(q_weight_shape.product()); + MatrixRef tensor_q_weight( + q_weights, make_Position(rows / 2, columns)); + int v = 7; + for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { + for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + + tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); + } + } + + q_scales.resize(meta_shape.product()); + for (size_t i = 0; i < q_scales.size(); i++) { + uint32_t v = dis(gen); + uint32_t m = (v % 63) + 1; + uint32_t e = (v >> 6) % 4; + q_scales[i] = ElementT(m / static_cast(1 << (2 + e))); + } + MatrixRef tensor_scale( + q_scales, meta_shape); + + MatrixRef tensor_offset; + if constexpr(has_offsets) { + q_zp.resize(zp_shape.product()); + tensor_offset = MatrixRef( + q_zp, zp_shape); + for (int c = 0; c < zp_shape[1]; c++) { + for (int r = 0; r < zp_shape[0]; ++r) { + uint8_t v0 = dis(gen) % 16; + uint8_t v1 = 8; + if (r * 2 + 1 < meta_shape[0]) { + v1 = dis(gen) % 16; + } + tensor_offset.at(r, c) = static_cast(v0 | (v1 << 4)); + } + } + } + + dequants.resize(rows * columns); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + + // Dequantize weights and save into matrix B + for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { + for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { + auto weight_cord = make_Position(row / 2, col); + auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); + uint8_t offset = 8; + if constexpr(has_offsets) { + if (scale_cord[0] % 2 == 0) { + offset
= tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) & 0x0f; + } else { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) >> 4; + } + } + int w = 0; + if (row % 2 == 0) { + w = int(tensor_q_weight.at(weight_cord) & 0x0f); + } else { + w = int(tensor_q_weight.at(weight_cord) >> 4); + } + float scale = float(tensor_scale.at(scale_cord)); + float dequant = scale * float(w - offset); + tensor_dequant.at(row, col) = ElementT(dequant); + // Prints for help debugging in case of test failure + // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); + } + } + +} + template < int block_size, bool column_wise_blocking, diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index 60c9b16f4cf88..148055bd046e2 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -14,7 +14,6 @@ #include #include "core/framework/float16.h" -#include "core/mickey/blk_q4/f16_prepack_sm80.h" #include "core/mlas/inc/mlas_q4.h" #include "blkq4_fp16_gemm_sm80.h" @@ -24,15 +23,15 @@ namespace onnxruntime { namespace test { -template -void testPrepack(int rows, int columns, bool has_offset = true) { +template +void testPrepack(int rows, int columns) { using ElementT = MLFloat16; constexpr int block_size = 32; using Base = onnxruntime::cuda::BlockwiseQuantization< ElementT, block_size, 4, - ColumnMajorQuantBlocking>; + col_blocking>; using QuantBlocking = typename Base::QuantBlocking; using ElementW = typename Base::ElementW; @@ -40,147 +39,40 @@ void testPrepack(int rows, int columns, bool has_offset = true) { using ElementQOffset = typename Base::ElementQOffset; using LayoutQmeta = typename Base::LayoutQmeta; - unsigned int seed = 28571; // Replace with desired seed value - std::seed_seq seq{seed}; - std::mt19937 gen(seq); - std::uniform_int_distribution<> dis(0, 8192); - const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); - // - // For testing quantization and dequantization, it is not straight - // forward to avoid flaky tests due to rounding errors. The way we - // try to achieve this is to: - // 1. Generate a set of quantized weights, scales and offsets - // 2. Dequantize the weights - // 3. Quantize the dequantized weights - // 4. Compare the dequantied-and-then-quantized weights with - // the original quantized weights - // - // Random filling of the initial values are key to get this right. - // For weights, we must ensure each block gets a full range of - // values, i.e. must contain 0 and 15. And for scales, they must - // all be positive. 
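The round-trip strategy described in the comment above reduces to a per-element identity: with a positive scale s and 4-bit zero point z, dequantizing q to s * (q - z) and re-quantizing with round(d / s) + z must reproduce q exactly, as long as the generated scales (and the resulting products) stay exactly representable. A small sketch of that property, illustrative only and not part of the patch:

#include <cassert>
#include <cmath>
#include <cstdint>

// Illustrative round-trip check for one 4-bit element; q and zero_point in [0, 15], scale > 0.
inline uint8_t requantize(float dequantized, float scale, uint8_t zero_point) {
  int q = static_cast<int>(std::lround(dequantized / scale)) + zero_point;
  q = q < 0 ? 0 : (q > 15 ? 15 : q);  // clamp back into the 4-bit range
  return static_cast<uint8_t>(q);
}

inline void check_round_trip(uint8_t q, float scale, uint8_t zero_point) {
  float dequant = scale * static_cast<float>(static_cast<int>(q) - zero_point);
  assert(requantize(dequant, scale, zero_point) == q);
}

This is why blkq4_weights_gen draws scales as m / 2^(2 + e) with a small integer m: both the scale and the dequantized product stay exact in fp16, so the comparison against the quantization tool output cannot be thrown off by rounding.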
- // + std::vector q_weights; + std::vector q_scales; + std::vector q_zp; + std::vector dequants; + onnxruntime::cuda::test::blkq4_weights_gen( + rows, columns, dequants, q_weights, q_scales, q_zp); - std::vector q_weights(q_weight_shape.product()); - MatrixRef tensor_q_weight( + // for quantization tool, the input is row major, all outputs are column major + MatrixRef tensor_q_weight( q_weights, make_Position(rows / 2, columns)); - int v = 7; - for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { - for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - - tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); - } - } - - std::vector q_scales(meta_shape.product()); - for (size_t i = 0; i < q_scales.size(); i++) { - q_scales[i] = ElementT(((dis(gen) % 127) + 1) / 32.0f); - } - MatrixRef tensor_scale( + MatrixRef tensor_scale( q_scales, meta_shape); - - std::vector q_zp(meta_shape.product()); - for (size_t i = 0; i < q_zp.size(); i++) { - q_zp[i] = dis(gen) % 16; - } - MatrixRef tensor_offset( - q_zp, meta_shape); - -#if 0 // debug - // Fill tensor_q_weight with the patterned data, easier to debug with print - int loop_val = 0; - int offset = 3; - for (int col_tile = 0; col_tile < tensor_q_weight.extent().column()/8; ++col_tile) { - for (int row_tile = 0; row_tile < tensor_q_weight.extent().row()/4; ++row_tile) { - for (int col = 0; col < 8; ++col) { - for (int row = 0; row < 4; ++row) { - auto weight_cord = cutlass::make_Coord(row_tile * 4 + row, col_tile * 8 + col); - auto val = (loop_val + offset) % 256; - tensor_q_weight.at(weight_cord) = ElementW(val); - loop_val++; - if (loop_val == 256) { - loop_val = 0; - offset += 11; - } - } - } - } - } - for (int col = 0; col < tensor_scale.extent().column(); ++col){ - int c = col * QuantBlocking::kColumn; - for (int row = 0; row < tensor_scale.extent().row(); ++row){ - int r = row * QuantBlocking::kRow; - auto weight_cord = cutlass::make_Coord(r/2, c); - int w = 0; - if (r % 2 == 0) { - w = int(tensor_q_weight.at(weight_cord) & 0x0f); - } else { - w = int(tensor_q_weight.at(weight_cord) >> 4); - } - tensor_scale.at({row, col}) = w; - tensor_offset.at({row, col}) = ElementQOffset(w); - } - } - - int fill_val = -512; - int factor = 1; - for (int col = 0; col < tensor_scale.extent().column(); ++col){ - for (int row = 0; row < tensor_scale.extent().row(); ++row){ - tensor_scale.at({row, col}) = ElementQScale((float)fill_val * float(factor)); - fill_val++; - if (fill_val == 512) { - fill_val = -512; - factor += 1; - } - } + MatrixRef tensor_offset; + if constexpr(has_offset) { + tensor_offset = MatrixRef(q_zp, zp_shape); } -#endif // debug - - std::vector dequants(rows * columns); - MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); - - // Dequantize weights and save into matrix B for reference + // for quantization tool, the input is row major, test weight gen output is column major + std::vector dequants_transposed(dequants.size()); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + MatrixRef tensor_dequant_transposed(dequants_transposed, make_Position(rows, columns)); for (int col = 0; col < tensor_dequant.shape()[1]; 
++col) { for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { - auto weight_cord = make_Position(row / 2, col); - auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); - const uint8_t offset = has_offset ? tensor_offset.at(scale_cord) : 8; - int w = 0; - if (row % 2 == 0) { - w = int(tensor_q_weight.at(weight_cord) & 0x0f); - } else { - w = int(tensor_q_weight.at(weight_cord) >> 4); - } - float scale = float(tensor_scale.at(scale_cord)); - float dequant = scale * float(w - offset); - tensor_dequant.at(row, col) = ElementT(dequant); - // Prints for help debugging in case of test failure - // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); + tensor_dequant_transposed.at(row, col) = tensor_dequant.at(row, col); } } int q_rows, q_cols; MlasBlockwiseQuantizedShape( - block_size, ColumnMajorQuantBlocking, rows, columns, q_rows, q_cols); + block_size, col_blocking, rows, columns, q_rows, q_cols); // to be exact, q_rows are padded to multiple of block_size, deal with it when we care about strange shapes EXPECT_EQ(q_rows, q_weight_shape[0]); EXPECT_EQ(q_cols, q_weight_shape[1]); @@ -194,19 +86,18 @@ void testPrepack(int rows, int columns, bool has_offset = true) { std::vector o_scales(meta_shape.product()); MatrixRef tensor_o_scales(o_scales, meta_shape); - std::vector o_zp(((meta_shape[0] + 1) / 2) * meta_shape[1], true); - MatrixRef tensor_o_zp( - o_zp, make_Position((meta_shape[0] + 1) / 2, meta_shape[1])); + std::vector o_zp(zp_shape.product()); + MatrixRef tensor_o_zp(o_zp, zp_shape); MlasQuantizeBlockwise(o_elements.data(), o_scales.data(), has_offset ? o_zp.data() : nullptr, - tensor_dequant.data().data(), block_size, - ColumnMajorQuantBlocking, rows, columns, columns, nullptr); + dequants_transposed.data(), block_size, + col_blocking, rows, columns, columns, nullptr); for (int col = 0; col < tensor_q_weight.shape()[1]; ++col) { for (int row = 0; row < tensor_q_weight.shape()[0]; ++row) { EXPECT_EQ(tensor_o_elements.at(row, col), tensor_q_weight.at(row, col)) << "quantized value mismatch at [" << row << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; } } @@ -215,16 +106,17 @@ void testPrepack(int rows, int columns, bool has_offset = true) { for (int row = 0; row < meta_shape[0]; row += 2) { if (has_offset) { uint8_t pair01 = tensor_o_zp.at(row / 2, col); - EXPECT_EQ(tensor_offset.at(row + 0, col), pair01 & 0xf) + uint8_t expected_pair01 = tensor_offset.at(row / 2, col); + EXPECT_EQ(expected_pair01 & 0xf, pair01 & 0xf) << "quantized offset mismatch at [" << row << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; if (row + 1 < meta_shape[0]) { - EXPECT_EQ(tensor_offset.at(row + 1, col), pair01 >> 4) + EXPECT_EQ(expected_pair01 >> 4, pair01 >> 4) << "quantized offset mismatch at [" << row + 1 << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? 
"Column-wise-block" : "Row-wise-block") << std::endl; } } @@ -232,22 +124,22 @@ void testPrepack(int rows, int columns, bool has_offset = true) { EXPECT_EQ(tensor_scale.at(row + 0, col), tensor_o_scales.at(row + 0, col)) << "quantized scale mismatch at [" << row << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; if (row + 1 < meta_shape[0]) { EXPECT_EQ(tensor_scale.at(row + 1, col), tensor_o_scales.at(row + 1, col)) << "quantized scale mismatch at [" << row + 1 << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; } } } // - // Now we just setup fp16 weights tensor_dequant, quantized weights tensor_q_weight, - // quantization scale tensor_scale and quantization offset tensor_offset. The above - // testing just make sure our test setup is consistent with quantization tool output. + // Now we just setup quantized weights tensor_q_weight, quantization scale tensor_scale + // and quantization offset tensor_offset. The above tests just make sure our setup is + // consistent with quantization tool output. // // Next we test the prepack code // @@ -267,18 +159,23 @@ void testPrepack(int rows, int columns, bool has_offset = true) { EXPECT_EQ(tensor_packed_w_ref.at(row, col), tensor_packed_w.at(row, col)) << "prepacked weights mismatch at [" << row << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; } } std::vector packed_scales_ref(meta_shape.product()); MatrixRef tensor_packed_s_ref = - Base::ShouldRearrangeMeta ? make_MatrixRef(packed_scales_ref, meta_shape) - : tensor_scale; - if (Base::ShouldRearrangeMeta) { + make_MatrixRef(packed_scales_ref, meta_shape); + if constexpr(Base::ShouldRearrangeMeta) { onnxruntime::test::sm80_prepack_quant_scales_ref( rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); + } else { + for (int col = 0; col < tensor_packed_s_ref.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_s_ref.shape()[0]; ++row) { + tensor_packed_s_ref.at(row, col) = tensor_scale.at(row, col); + } + } } std::vector packed_scales(meta_shape.product()); @@ -291,7 +188,7 @@ void testPrepack(int rows, int columns, bool has_offset = true) { EXPECT_EQ(tensor_packed_s_ref.at(row, col), tensor_packed_s.at(row, col)) << "prepacked scales mismatch at [" << row << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; } } @@ -299,11 +196,20 @@ void testPrepack(int rows, int columns, bool has_offset = true) { if (has_offset) { std::vector packed_zp_ref(meta_shape.product()); MatrixRef tensor_packed_zp_ref = - Base::ShouldRearrangeMeta ? 
make_MatrixRef(packed_zp_ref, meta_shape) - : tensor_offset; - if (Base::ShouldRearrangeMeta) { - onnxruntime::test::sm80_prepack_quant_offsets_ref( + make_MatrixRef(packed_zp_ref, meta_shape); + if constexpr(Base::ShouldRearrangeMeta) { + onnxruntime::test::sm80_expand_prepack_quant_offsets_ref( rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); + } else { + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + uint8_t pair01 = tensor_offset.at(row / 2, col); + tensor_packed_zp_ref.at(row, col) = pair01 & 0xf; + if (row + 1 < meta_shape[0]) { + tensor_packed_zp_ref.at(row + 1, col) = pair01 >> 4; + } + } + } } std::vector packed_zp(meta_shape.product()); @@ -316,7 +222,7 @@ void testPrepack(int rows, int columns, bool has_offset = true) { EXPECT_EQ(tensor_packed_zp_ref.at(row, col), tensor_packed_zp.at(row, col)) << "prepacked offsets mismatch at [" << row << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; } } @@ -332,9 +238,9 @@ TEST(BlkQ4_GEMM, PrepackSm80Test) { } testPrepack(32, 32); - testPrepack(32, 32, false); + testPrepack(32, 32); testPrepack(32, 32); - testPrepack(32, 32, false); + testPrepack(32, 32); testPrepack(32, 64); testPrepack(32, 128); testPrepack(32, 256); @@ -342,9 +248,9 @@ TEST(BlkQ4_GEMM, PrepackSm80Test) { testPrepack(128, 32); testPrepack(256, 32); testPrepack(256, 256); - testPrepack(32, 128, false); - testPrepack(128, 32, false); - testPrepack(256, 256, false); + testPrepack(32, 128); + testPrepack(128, 32); + testPrepack(256, 256); testPrepack(32, 64); testPrepack(32, 128); testPrepack(32, 256); @@ -352,9 +258,9 @@ TEST(BlkQ4_GEMM, PrepackSm80Test) { testPrepack(128, 32); testPrepack(256, 32); testPrepack(256, 256); - testPrepack(32, 128, false); - testPrepack(128, 32, false); - testPrepack(256, 256, false); + testPrepack(32, 128); + testPrepack(128, 32); + testPrepack(256, 256); } TEST(BlkQ4_GEMM, Sm80Test) { diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu index 733e88da9fc89..69c929d446ce4 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -11,6 +11,10 @@ * well with gtest headers. 
*/ +#include +#include +#include + #include "core/mickey/blk_q4/f16_gemm_sm80.h" #include "cutlass/util/host_tensor.h" @@ -149,6 +153,10 @@ template< bool small_m, bool has_offsets> void run_blkq4_gemm(int m, int n, int k) { + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution<> dis(0, 8192); using ElementDequant = cutlass::half_t; using QuantBlocking = @@ -173,23 +181,38 @@ void run_blkq4_gemm(int m, int n, int k) { using LayoutInputQScale = typename GemmRunner::LayoutInputQScale; const cutlass::gemm::GemmCoord problem_size = {m, n, k}; + const auto q_weight_shape = cutlass::make_Coord(problem_size.k()/2, problem_size.n()); + const auto meta_shape = cutlass::make_Coord(problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn); + + // + // Generate quantized and dequantizeed input matrix B [K, N] + // + static_assert(std::is_same::value); + std::vector q_weights; + std::vector q_scales; + std::vector q_zp; + std::vector dequants; + onnxruntime::cuda::test::blkq4_weights_gen( + problem_size.k(), problem_size.n(), dequants, q_weights, q_scales, q_zp); + + using PrepackT = onnxruntime::cuda::BlockwiseQuantization< + ElementDequant, + block_size, + 4, + column_wise_blocking>; + + std::vector packed_w(q_weight_shape.product()); + PrepackT::prepack_weights(problem_size.k(), problem_size.n(), q_weights, packed_w); + std::vector packed_scales(meta_shape.product()); + PrepackT::prepack_quant_scales(problem_size.k(), problem_size.n(), q_scales, packed_scales); + std::vector packed_zp; + if constexpr (has_offsets) { + packed_zp.resize(meta_shape.product()); + PrepackT::prepack_quant_offsets(problem_size.k(), problem_size.n(), q_zp, packed_zp); + } - // Initialize tensors using CUTLASS helper functions cutlass::HostTensor tensor_a( problem_size.mk()); // <- Create matrix A with dimensions M x K - - // Create weight matrix with dimensions K x N. - // Actual weight type is int4, we use ElementW = uint8 to avoid possible compilation - // troubles. Since the layout is column major, we are packing 2 weights in a column - // into one int8 - cutlass::HostTensor tensor_weight( - {problem_size.k()/2, problem_size.n()}); - // Create weight quantization scale and offset with dimensions K x N - cutlass::HostTensor tensor_scale( - {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); - cutlass::HostTensor tensor_offset( - {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); - cutlass::HostTensor tensor_c( problem_size.mn()); // <- Create matrix C with dimensions M x N cutlass::HostTensor tensor_d( @@ -203,14 +226,6 @@ void run_blkq4_gemm(int m, int n, int k) { ElementInputA(4), ElementInputA(-4), 2); // <- Fill matrix A on host with uniform-distribution random data - if constexpr (has_offsets) { - cutlass::reference::host::TensorFillRandomUniform( - tensor_offset.host_view(), - 1, - ElementQOffset(0), - ElementQOffset(15), - 0); // <- Fill weight offsets on host with uniform-distribution random data - } cutlass::reference::host::TensorFillRandomUniform( tensor_c.host_view(), 1, @@ -221,188 +236,52 @@ void run_blkq4_gemm(int m, int n, int k) { tensor_d.host_view()); // <- fill matrix D on host with zeros // - // For testing quantization and dequantization, it is not straight - // forward to avoid flaky tests due to rounding errors. The way we - // try to achieve this is to: - // 1. Generate a set of quantized weights, scales and offsets - // 2. 
Dequantize the weights - // 3. Quantize the dequantized weights - // 4. Compare the dequantied-and-then-quantized weights with - // the original quantized weights - // - // Random filling of the initial values are key to get this right. - // For weights, we must ensure each block gets a full range of - // values, i.e. must contain 0 and 15. And for scales, they must - // all be positive. + // Copy data from host to GPU... // + thrust::device_vector d_packed_w(packed_w); + cutlass::TensorRef ref_W( + reinterpret_cast(d_packed_w.data().get()), + LayoutInputWPack::packed({problem_size.k()/2, problem_size.n()/2})); - int v = 7; - for (int c = 0; c < tensor_weight.extent()[1]; c++) { - for (int r = 0; r < tensor_weight.extent()[0]; ++r) { - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - - tensor_weight.at({r, c}) = ElementW((v1 << 4) | v0); - } - } - - for (int c = 0; c < tensor_scale.extent()[1]; c++) { - for (int r = 0; r < tensor_scale.extent()[0]; ++r) { - int f = (((c * v + r + v / 3 ) % 63) + 1); - v += 41; - int m = (c * v + r + v / 8 ) % 4; - tensor_scale.at({r, c}) = ElementQScale(static_cast(f) / static_cast(1 << (2 + m))); - } - } + thrust::device_vector d_packed_scales(packed_scales); + cutlass::TensorRef ref_scales( + d_packed_scales.data().get(), LayoutInputQScale::packed(meta_shape)); -// // Fill tensor_weight with the patterned data, so that we can use -// // print to make sure the layout matches after loaded to registers -// int loop_val = 0; -// int offset = 3; -// for (int col_tile = 0; col_tile < tensor_weight.extent().column()/8; ++col_tile) { -// for (int row_tile = 0; row_tile < tensor_weight.extent().row()/4; ++row_tile) { -// for (int col = 0; col < 8; ++col) { -// for (int row = 0; row < 4; ++row) { -// auto weight_cord = cutlass::make_Coord(row_tile * 4 + row, col_tile * 8 + col); -// auto val = (loop_val + offset) % 256; -// tensor_weight.at(weight_cord) = ElementW(val); -// loop_val++; -// if (loop_val == 256) { -// loop_val = 0; -// offset += 11; -// } -// } -// } -// } -// } -// for (int col = 0; col < tensor_scale.extent().column(); ++col){ -// int c = col * QuantBlocking::kColumn; -// for (int row = 0; row < tensor_scale.extent().row(); ++row){ -// int r = row * QuantBlocking::kRow; -// auto weight_cord = cutlass::make_Coord(r/2, c); -// int w = 0; -// if (r % 2 == 0) { -// w = int(tensor_weight.at(weight_cord) & 0x0f); -// } else { -// w = int(tensor_weight.at(weight_cord) >> 4); -// } -// tensor_scale.at({row, col}) = w; -// #ifdef USE_QUANT_OFFSET -// tensor_offset.at({row, col}) = ElementQOffset(w); -// #endif -// } -// } - - // int fill_val = -512; - // int factor = 1; - // for (int col = 0; col < tensor_scale.extent().column(); ++col){ - // for (int row = 0; row < tensor_scale.extent().row(); ++row){ - // tensor_scale.at({row, col}) = ElementQScale((float)fill_val * float(factor)); - // fill_val++; - // if (fill_val == 512) { - // fill_val = -512; - // factor += 1; - // } - // } - // } - - // std::cout << "Matrix Weight:\n" << tensor_weight.host_view() << "\n"; - - // Prepacking weight matrix and quantization meta data ... 
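The new device setup above follows one pattern throughout: prepack into a host std::vector, copy it to the GPU with thrust::device_vector, and wrap the raw device pointer in a cutlass::TensorRef with a packed layout. A generic sketch of that pattern, with a hypothetical helper name and a plain ColumnMajor layout standing in for the patch's layout aliases:

#include <vector>
#include <thrust/device_vector.h>
#include "cutlass/layout/matrix.h"
#include "cutlass/tensor_ref.h"

// Illustrative: upload a host buffer and view it as a packed column-major CUTLASS tensor.
// The device_vector must outlive the returned TensorRef, hence it is passed in by reference.
template <typename Element>
cutlass::TensorRef<Element, cutlass::layout::ColumnMajor> upload_as_column_major(
    const std::vector<Element>& host,
    thrust::device_vector<Element>& device,
    int rows, int cols) {
  device = host;  // host -> device copy
  return {device.data().get(),                                  // raw device pointer
          cutlass::layout::ColumnMajor::packed({rows, cols})};  // contiguous column-major stride
}

Compared with the cutlass::HostTensor buffers this replaces, the prepacked operands no longer need a host-side mirror, so there is no sync_device() step for the quantized weight, scale, or offset buffers.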
- - cutlass::HostTensor tensor_weight_prepacked( - cutlass::make_Coord(problem_size.k(), problem_size.n()/2)); - onnxruntime::test::sm80_prepack_weights_ref( - problem_size.k(), problem_size.n(), - make_ConstMatrixRef(tensor_weight), - make_MatrixRef(tensor_weight_prepacked)); - - cutlass::HostTensor tensor_scale_prepacked( - {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); - cutlass::HostTensor tensor_offset_prepacked( - {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); - - auto scale_ref = make_ConstMatrixRef(tensor_scale); - onnxruntime::test::sm80_prepack_quant_scales_ref( - problem_size.k(), problem_size.n(), scale_ref, - make_MatrixRef(tensor_scale_prepacked)); - if constexpr (has_offsets) { - auto offset_ref = make_ConstMatrixRef(tensor_offset); - onnxruntime::test::sm80_prepack_quant_offsets_ref( - problem_size.k(), problem_size.n(), offset_ref, - make_MatrixRef(tensor_offset_prepacked)); - } + thrust::device_vector d_packed_zp(packed_zp); + cutlass::TensorRef ref_zp( + d_packed_zp.data().get(), LayoutInputQScale::packed(meta_shape)); - // Copy data from host to GPU... tensor_a.sync_device(); - tensor_weight_prepacked.sync_device(); - tensor_scale_prepacked.sync_device(); - if constexpr (has_offsets) { - tensor_offset_prepacked.sync_device(); - } tensor_c.sync_device(); tensor_d.sync_device(); - cutlass::TensorRef ref_W( - reinterpret_cast(tensor_weight_prepacked.device_data()), - LayoutInputWPack::packed({problem_size.k()/2, problem_size.n()/2})); // run GEMM cutlass::Status status; if constexpr (has_offsets){ status = GemmRunner::run( nullptr, problem_size, tensor_a.device_ref(), ref_W, - tensor_scale_prepacked.device_ref(), tensor_offset_prepacked.device_ref(), + ref_scales, ref_zp, tensor_c.device_ref(), tensor_d.device_ref()); } else { status = GemmRunner::run( nullptr, problem_size, tensor_a.device_ref(), ref_W, - tensor_scale_prepacked.device_ref(), + ref_scales, tensor_c.device_ref(), tensor_d.device_ref()); } ORT_ENFORCE(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status)); - // Preparing reference kernel arguments - // Dequantizing weights and running reference kernel - + // Running reference kernel using ElementInputB = ElementInputA; using LayoutInputB = cutlass::layout::ColumnMajor; - cutlass::HostTensor tensor_b( - problem_size.kn()); // <- Create dequantized matrix B with dimensions K x N + thrust::device_vector d_dequants(dequants); + cutlass::TensorRef ref_B( + d_dequants.data().get(), LayoutInputB::packed(problem_size.kn())); cutlass::HostTensor tensor_ref_d( problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from // reference kernel - // Dequantize weights and save into matrix B for reference - for (int col = 0; col < tensor_b.extent().column(); ++col){ - for (int row = 0; row < tensor_b.extent().row(); ++row) { - auto weight_cord = cutlass::make_Coord(row/2, col); - auto scale_cord = cutlass::make_Coord(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); - const uint8_t offset = has_offsets ? 
tensor_offset.at(scale_cord) : 8; - int w = 0; - if (row % 2 == 0) { - w = int(tensor_weight.at(weight_cord) & 0x0f) - offset; - } else { - w = int(tensor_weight.at(weight_cord) >> 4) - offset; - } - auto scale = tensor_scale.at(scale_cord); - tensor_b.at({row, col}) = scale * float(w); - } - } cutlass::reference::host::TensorFill( tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros - - tensor_b.sync_device(); tensor_ref_d.sync_device(); // Initialize alpha and beta for dot product computation @@ -416,7 +295,7 @@ void run_blkq4_gemm(int m, int n, int k) { problem_size, alpha, tensor_a.device_ref(), - tensor_b.device_ref(), + ref_B, beta, tensor_c.device_ref(), tensor_ref_d.device_ref());
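For context on the final comparison: the reference path computes D_ref = alpha * A * B + beta * C, where B is the dequantized weight matrix, and the quantized kernel's D is then checked against it. A plain host-side equivalent of that reference computation, written out only to spell out the math and assuming column-major float buffers (the patch itself runs the CUTLASS device reference GEMM on the uploaded dequants buffer):

#include <vector>

// Illustrative host reference: D = alpha * A * B + beta * C, accumulating in float.
// A is M x K, B is K x N (dequantized weights), C and D are M x N; all column-major.
inline void reference_gemm(int M, int N, int K, float alpha, float beta,
                           const std::vector<float>& A, const std::vector<float>& B,
                           const std::vector<float>& C, std::vector<float>& D) {
  for (int n = 0; n < N; ++n) {
    for (int m = 0; m < M; ++m) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m + k * M] * B[k + n * K];  // column-major indexing
      }
      D[m + n * M] = alpha * acc + beta * C[m + n * M];
    }
  }
}

Because dequants is exactly what the quantized path should reconstruct internally, any significant mismatch between D and D_ref points at the prepack or dequantization logic rather than at rounding in the generated weights themselves.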