diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index 65b48a3009e72..bc6423e026b78 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -229,3 +229,92 @@ MlasQ8Q4GemmBatch( const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool ); + + +//////////////////////////////////////////////////////////// +// Blockwise quantization and dequantization where quantization +// parameters are packed into seperate buffers. +// + +/** + * @brief For quantization type , and + * matrix shape [rows, columns], compute the shape of the + * quantization parameter matrix [meta_rows, meta_cols] +*/ +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + + +/** + * @brief Blockwise 4 bits quantization, resulting elements and quantization + * parameters (scales, zero points) are packed into seperate matrices + * all in column major layout for faster access during subsequent matrix + * multiplication. + * + * @tparam ElementT type of the input matrix element, usually floating point + * + * @param dst points to the quantized matrix, shape [rows, columns] column major + * @param scales points to the scales matrix, column major + * @param zero_points points to the zero_points matrix, column major + * @param src points to the floating point matrix, to be quantized, row major shape [rows, columns] + * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point + * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row + * @param rows + * @param columns + * @param leading_dimension + * @param thread_pool +*/ +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + ElementT* scales, + uint8_t* zero_points, + const ElementT* src, + int32_t block_size, + bool columnwise, + int32_t rows, + int32_t columns, + int32_t leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + + +/** + * @brief Blockwise 4 bits dequantization, quantized elements and quantization + * parameters (scales, zero points) are from seperate matrices packed + * in column major layout. Output is a floating point matrix in column + * major layout for faster access during subsequent matrix multiplication. + * + * @tparam ElementT type of the dequantized matrix element, usually floating point + * @param dst points to dequantized matrix shape [rows, columns] column major + * @param src points to quantized matrix, column major + * @param scales points to quantization scales, column major + * @param zero_points points to quantization zero points, column major + * @param block_size size of the block to quantize, elements from the same block share the same scale and zero point + * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row + * @param rows + * @param columns + * @param thread_pool +*/ +template +void +MlasDequantizeBlockwise( + ElementT* dst, + const uint8_t* src, + const ElementT* scales, + const uint8_t* zero_points, + int32_t block_size, + bool columnwise, + int32_t rows, + int32_t columns, + MLAS_THREADPOOL* thread_pool + ); \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index b6ac4a1ca1d6c..38a17731f13a3 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1069,6 +1069,23 @@ MlasTrySimpleParallel( const std::function& Work ); + +/** + * @brief Distribute many iterations of work over a thread pool if supported. + * This function is for small workloads in non-performance critical situation. + * + * @param ThreadPool [IN] Optional thread pool. Ignored when using OpenMP + * @param Iterations [IN] Total number of iterations + * @param Work [IN] Logic for computing a range of iterations [begin, end) + */ +void +MlasTryBatchParallel( + MLAS_THREADPOOL * ThreadPool, + const std::ptrdiff_t Iterations, + const std::function& Work + ); + + inline ptrdiff_t MlasGetMaximumThreadCount( diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 85c0d13006126..edb753dcf5cb2 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -294,3 +294,584 @@ MlasQ4GemmUnPackB( return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb); } } + + + +/*************************************************************** + * The quantization format that pack data and quantization + * parameters into separate buffers. + */ + + +template < + int Row_, ///< rows of a matrix + int Column_ ///< columns of a matrix + > +struct Shape2D { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix +}; + + +template +struct BitsTraits { + static_assert(qbits <= 8, "Only BitsTraits are for small number of bits!"); + + static constexpr int kBits = qbits; + static constexpr int kMax = (1 << qbits) - 1; + static constexpr int kMid = 1 << (qbits - 1); + static constexpr float kMaxFp = static_cast(kMax); + + // number of qbit elements to pack into whole bytes + static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); +}; + + +/** + * @brief Rectify min/max from a set of weights, and convert to scale and zero point + * for quantization + * @tparam ScaleT type of scale, usually floating point of various bits + * @tparam qbits number of int bits used for zero point value + * @param[in] min + * @param[in] max + * @param[out] scale + * @param[out] zp + */ +template +MLAS_FORCEINLINE +void +range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) +{ + constexpr int zp_max = BitsTraits::kMax; + constexpr float zp_max_fp = BitsTraits::kMaxFp; + + min = std::min(min, 0.0f); + max = std::max(max, 0.0f); + + float scale_f = (max - min) / zp_max; + + float zero_point_fp = min; + if (scale_f != 0.0f) { + zero_point_fp = 0.f - min / scale_f; + } + + if (zero_point_fp < 0.0f) { + zp = 0; + } else if (zero_point_fp > zp_max_fp) { + zp = zp_max; + } else { + zp = (uint8_t)roundf(zero_point_fp); + } + scale = static_cast(scale_f); +} + +template +MLAS_FORCEINLINE +void +range2scale(float min, float max, ScaleT& scale) +{ + constexpr int mid_v = BitsTraits::kMid; + constexpr float mid_fp = static_cast(-mid_v); + + max = fabsf(max) > fabsf(min) ? max : min; + + scale = static_cast(max / mid_fp); +}; + + +/** + * @brief Blockwise quantization methods + * @tparam ElementT source data type, e.g. fp32/fp16 + * @tparam block_size number of elemenets quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +struct BlockwiseQuantizer { + // To support other qbits, need to add bit packing code for + // storing to dst and zero points + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + + using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; + using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; + + static void quantizeMetaShape(int rows, int columns, int& meta_rows, int& meta_cols) { + meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + meta_cols = (columns + QuantBlk::kColumn - 1) / QuantBlk::kColumn; + } + + /** + * @brief Quantized a Matrix shape [rows, columns], resulting quantized + * and packed data are stored in column major (transposed) + * @param[out] dst pointer to the quantized weights, column major: [columns, rows] + * @param[out] scale pointer to the scales, column major: [columns/QuantBlk::kColumn, rows/QuantBlk::kRow] + * @param[out] zero_points pointer to the zero points, same shape as scale + * @param[in] src pointer to the source matrix, row major: [rows, columns] + * @param rows + * @param columns + * @param leadingDimension stride of the source matrix, i.e. distance from one row to the next + */ + static void quantizeAndTranspose( + uint8_t* dst, + ElementT* scales, + uint8_t* zero_points, + const ElementT* src, + int32_t rows, + int32_t columns, + int32_t leadingDimension, + MLAS_THREADPOOL* thread_pool) + { + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + MlasTryBatchParallel( + thread_pool, total_thrd_blks, + [&](ptrdiff_t block_idx) { + uint8_t zp_bytes[BitsTraits::kPackSize]{0,0}; + + const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); + const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); + + const int32_t r = r_blk_idx * ThreadBlk::kRow; + const int32_t c = c_blk_idx * ThreadBlk::kColumn; + + const int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + const int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + const int meta_row = r / QuantBlk::kRow; + const int meta_col = c / QuantBlk::kColumn; + + if constexpr (Columnwise) { + static_assert(ThreadBlk::kColumn == 1, "Internal computation error!"); + int row_idx = r; + for (int kpack = 0; kpack < BitsTraits::kPackSize && row_idx < r_end; + kpack++) { + float min = std::numeric_limits::max(); + float max = -min; + for (int idx = 0; idx < QuantBlk::kRow && row_idx < r_end; + idx++, row_idx++) { + const float v = static_cast(src[row_idx * leadingDimension + c]); + if (v < min) min = v; + if (v > max) max = v; + } + + // store scale and zero point at matrix position + const int32_t meta_idx = meta_col * row_blks + meta_row + kpack; + if (zero_points == nullptr) { + range2scale(min, max, scales[meta_idx]); + } else { + range2scalezp(min, max, scales[meta_idx], + zp_bytes[kpack]); + } + } + + } else { // Row-wise + static_assert(ThreadBlk::kRow == BitsTraits::kPackSize, + "Internal computation error!"); + for (int32_t i = r; i < r_end; ++i) { + float min = std::numeric_limits::max(); + float max = -min; + for (int32_t j = c; j < c_end; ++j) { + const float v = static_cast(src[i * leadingDimension + j]); + if (v < min) min = v; + if (v > max) max = v; + } + // store scale and zero point at matrix position (i/QuantBlk::kRow, + // j/QuantBlk::kColumn) it's packed as column major, so the index is (column + // * row_blks + row) over all it's (j/QuantBlk::kColumn) * row_blks + + // (i/QuantBlk::kRow) + const int32_t meta_idx = meta_col * row_blks + (i / QuantBlk::kRow); + if (zero_points == nullptr) { + range2scale(min, max, scales[meta_idx]); + } else { + range2scalezp(min, max, scales[meta_idx], + zp_bytes[i-r]); + } + } + } + + // !! 4b specific code as we need to pack 2 4b numbers into one byte + if (zero_points != nullptr) { + const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2; + zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); + } + + for (int32_t j = c; j < c_end; ++j) { + const int32_t meta_c = j / QuantBlk::kColumn; + for (int32_t i = r; i < r_end; i += 2) { + const int32_t meta_r = i / QuantBlk::kRow; + const float scale = static_cast(scales[meta_c * row_blks + meta_r]); + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + const int zp_pair = + (zero_points == nullptr) + ? 0x88 + : zero_points[meta_c * ((row_blks + 1) / 2) + meta_r / 2]; + const int zp = (meta_r & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); + + const float v0 = static_cast(src[i * leadingDimension + j]); + const uint8_t vi0 = + (uint8_t)std::min(BitsTraits::kMaxFp, + std::max(0.0f, roundf(v0 * reciprocal_scale + zp))); + + uint8_t vi1 = 0; + if (i + 1 < r_end) { + float reciprocal_scale1 = reciprocal_scale; + int zp1 = zp; + if constexpr (QuantBlk::kRow == 1) { + const float scale1 = + static_cast(scales[meta_c * row_blks + meta_r + 1]); + reciprocal_scale1 = scale1 ? 1.0f / scale1 : 0.0f; + zp1 = (zp_pair >> 4) & 0xf; + } + const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); + vi1 = (uint8_t)std::min( + BitsTraits::kMaxFp, + std::max(0.0f, roundf(v1 * reciprocal_scale1 + zp1))); + } + + // !! 4b specific code + dst[j * ((rows + 1) / 2) + i / 2] = (vi0 & 0xf) | (vi1 << 4); + } + } + }); + } + + /** + * @brief Dequantize a column major quantized matrix, and store the result in a column major + * matrix for use in GEMM + * @param[out] dst pointer to the dequantized matrix, column major: [columns, rows] + * @param[in] weights pointer to the quantized weights, column major: [columns, rows] + * @param[in] scales pointer to the scales of quantized blocks, column major: + * [columns/QuantBlk::kColumn, rows/QuantBlk::kRow] + * @param[in] zero_points pointer to the zero points of quantized blocks, same shape as + * scales + * @param[in] rows + * @param[in] columns + */ + static void dequantize( + ElementT* dst, + const uint8_t* weights, + const ElementT* scales, + const uint8_t* zero_points, + int32_t rows, + int32_t columns, + MLAS_THREADPOOL* thread_pool) + { + // Thread partitioning + const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; + + const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + MlasTryBatchParallel( + thread_pool, total_thrd_blks, + [&](ptrdiff_t block_idx) { + int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); + int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); + + int32_t r = r_blk_idx * ThreadBlk::kRow; + int32_t c = c_blk_idx * ThreadBlk::kColumn; + + int32_t r_end = std::min(r + ThreadBlk::kRow, rows); + int32_t c_end = std::min(c + ThreadBlk::kColumn, columns); + + for (int32_t j = c; j < c_end; ++j) { + const int32_t meta_col = j / QuantBlk::kColumn; + + // !! 4b specific code + // the whole loop is 4b specific due to sub 8 bit packing + // and unpacking. We can potentially make this qbits generic + // by wraping the packing/unpacking code like cutlass::Array + for (int32_t i = r; i < r_end; i += 2) { + const int32_t meta_row = i / QuantBlk::kRow; + + const float scale0 = + static_cast(scales[meta_col * row_blks + meta_row]); + + const int zp_pair = + (zero_points == nullptr) + ? 0x88 + : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; + const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); + + const uint8_t vi0 = weights[j * ((rows + 1) / 2) + i / 2] & 0xf; + const float v0 = (static_cast(vi0) - zp0) * scale0; + + dst[j * rows + i] = static_cast(v0); + if ((i + 1) < r_end) { + float scale1 = scale0; + int zp1 = zp0; + if constexpr (QuantBlk::kRow == 1) { + scale1 = + static_cast(scales[meta_col * row_blks + meta_row + 1]); + zp1 = (zp_pair >> 4) & 0xf; + } + const uint8_t vi1 = weights[j * ((rows + 1) / 2) + i / 2] >> 4; + const float v1 = (static_cast(vi1) - zp1) * scale1; + dst[j * rows + (i + 1)] = static_cast(v1); + } + } + } + }); + } +}; + + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ) +{ + switch (block_size) { + case 16: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } + break; + } + case 32: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape( + rows, columns, meta_rows, meta_cols); + } + break; + } + case 64: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } + break; + } + case 128: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } + break; + } + case 256: { + if (columnwise) { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } else { + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + meta_cols); + } + break; + } + default: + meta_rows = 0; + meta_cols = 0; + break; + } +} + + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + + +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + T* scales, + uint8_t* zero_points, + const T* src, + int32_t block_size, + bool columnwise, + int32_t rows, + int32_t columns, + int32_t leading_dimension, + MLAS_THREADPOOL* thread_pool + ) +{ + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } else { + BlockwiseQuantizer::quantizeAndTranspose( + dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. + break; + } +} + +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + float* scales, + uint8_t* zero_points, + const float* src, + int32_t block_size, + bool columnwise, + int32_t rows, + int32_t columns, + int32_t leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + + +template +void +MlasDequantizeBlockwise( + T* dst, + const uint8_t* src, + const T* scales, + const uint8_t* zero_points, + int32_t block_size, + bool columnwise, + int32_t rows, + int32_t columns, + MLAS_THREADPOOL* thread_pool + ) +{ + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } + break; + case 32: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } + break; + case 64: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } + break; + case 128: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + rows, columns, thread_pool); + } + break; + case 256: + if (columnwise) { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, rows, + columns, thread_pool); + } else { + BlockwiseQuantizer::dequantize(dst, src, scales, zero_points, + rows, columns, thread_pool); + } + break; + default: + // Only block size 16, 32, 64, 128, 256 are supported. + break; + } +} + +template +void +MlasDequantizeBlockwise( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int32_t block_size, + bool columnwise, + int32_t rows, + int32_t columns, + MLAS_THREADPOOL* thread_pool + ); diff --git a/onnxruntime/core/mlas/lib/threading.cpp b/onnxruntime/core/mlas/lib/threading.cpp index ecdc5250ebf0e..dc5daf998d3be 100644 --- a/onnxruntime/core/mlas/lib/threading.cpp +++ b/onnxruntime/core/mlas/lib/threading.cpp @@ -93,3 +93,41 @@ MlasTrySimpleParallel( MLAS_THREADPOOL::TrySimpleParallelFor(ThreadPool, Iterations, Work); #endif } + + +void +MlasTryBatchParallel( + MLAS_THREADPOOL * ThreadPool, + const std::ptrdiff_t Iterations, + const std::function& Work) +{ + // + // Execute the routine directly if only one iteration is specified. + // + if (Iterations == 1) { + Work(0); + return; + } + +#if defined(BUILD_MLAS_NO_ONNXRUNTIME) + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + // + // Fallback to OpenMP or a serialized implementation. + // + + // + // Execute the routine for the specified number of iterations. + // + for (ptrdiff_t tid = 0; tid < Iterations; tid++) { + Work(tid); + } +#else + // + // Schedule the threaded iterations using the thread pool object. + // + + MLAS_THREADPOOL::TryBatchParallelFor(ThreadPool, Iterations, Work, 0); +#endif + +} \ No newline at end of file diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp new file mode 100644 index 0000000000000..6f77127f49afd --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -0,0 +1,194 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_blockq4.cpp + +Abstract: + + Tests for MLAS blockwise int4 quantization and dequantization code. + +--*/ + +#ifndef ORT_MINIMAL_BUILD + +#include "test_util.h" +#include "mlas_q4.h" + + +class MlasBlockwiseQdqTest : public MlasTestBase { + private: + MatrixGuardBuffer FpBuf; + MatrixGuardBuffer FpBuf2; + MatrixGuardBuffer InputElements; + MatrixGuardBuffer InputScales; + MatrixGuardBuffer InputOffsets; + MatrixGuardBuffer OutputElements; + MatrixGuardBuffer OutputScales; + MatrixGuardBuffer OutputOffsets; + + void Test(int rows, int columns, int block_size, bool columnwise, bool symmetric) { + float* dequant_buf = FpBuf.GetBuffer(rows * columns, true); + float* transposed = FpBuf2.GetBuffer(rows * columns, true); + + MLAS_THREADPOOL* threadpool_ptr = GetMlasThreadPool(); + + int meta_rows; + int meta_cols; + MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); + + uint8_t* elements = InputElements.GetBuffer(((rows + 1) / 2) * columns, true); + + int v = 7; + for (size_t c = 0; c < columns; c++) { + for (size_t r = 0; r < rows; r += 2) { + size_t idx = c * ((rows + 1) / 2) + r / 2; + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + elements[idx] = (v1 << 4) | v0; + } + } + + float* scales = InputScales.GetBuffer(meta_rows * meta_cols); + uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); + if (zp) { + for (size_t c = 0; c < meta_cols; c++) { + for (size_t r = 0; r < meta_rows; r += 2) { + size_t idx = c * ((meta_rows + 1) / 2) + r / 2; + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < meta_rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + zp[idx] = (v1 << 4) | v0; + } + } + } + + + MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, columnwise, rows, columns, threadpool_ptr); + + MlasTranspose(dequant_buf, transposed, columns, rows); + + uint8_t* o_elements = OutputElements.GetBuffer(((rows + 1) / 2) * columns); + float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols); + uint8_t* o_zp = symmetric ? nullptr : OutputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols); + + + MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, columnwise, rows, columns, columns, threadpool_ptr); + + for (size_t c = 0; c < columns; c++) { + for (size_t r = 0; r < rows; r+=2) { + size_t idx = c * ((rows + 1) / 2) + r / 2; + ASSERT_EQ(o_elements[idx], elements[idx]) << ", index=" << idx << ", [" << rows << "x" + << columns << "] block: " << block_size; + } + } + + for (size_t c = 0; c < meta_cols; c++) { + for (size_t r = 0; r < meta_rows; r++) { + size_t idx = c * meta_rows + r; + ASSERT_EQ(o_scales[idx], scales[idx]) << ", index=" << idx << ", [" << rows << "x" + << columns << "] block: " << block_size; + } + } + + if (symmetric) return; + for (size_t c = 0; c < meta_cols; c++) { + for (size_t r = 0; r < meta_rows; r += 2) { + size_t idx = c * ((meta_rows + 1) / 2) + r / 2; + ASSERT_EQ(o_zp[idx], zp[idx]) << ", index=" << idx << ", [" << rows << "x" + << columns << "] block: " << block_size; + } + } + + } + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("BlockQ4"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + Test(20, 1, 32, true, false); + Test(20, 1, 32, true, true); + Test(1, 20, 32, false, false); + Test(1, 20, 32, false, true); + Test(52, 1, 32, true, false); + Test(52, 1, 32, true, true); + Test(1, 52, 32, false, false); + Test(1, 52, 32, false, true); + Test(20, 3, 32, true, false); + Test(20, 3, 32, true, true); + Test(3, 20, 32, false, false); + Test(3, 20, 32, false, true); + Test(52, 3, 32, true, false); + Test(52, 3, 32, true, true); + Test(3, 52, 32, false, false); + Test(3, 52, 32, false, true); + Test(52, 3, 64, true, false); + Test(52, 3, 64, true, true); + Test(3, 52, 64, false, false); + Test(3, 52, 64, false, true); + Test(32 * 9 + 17, 41, 32, true, false); + Test(32 * 9 + 17, 41, 32, true, true); + Test(41, 32 * 9 + 17, 32, false, false); + Test(41, 32 * 9 + 17, 32, false, true); + Test(32 * 9 + 17, 41, 64, true, false); + Test(32 * 9 + 17, 41, 64, true, true); + Test(41, 32 * 9 + 17, 64, false, false); + Test(41, 32 * 9 + 17, 64, false, true); + Test(32 * 15 + 17, 63, 128, true, false); + Test(32 * 15 + 17, 63, 128, true, true); + Test(63, 32 * 15 + 17, 128, false, false); + Test(63, 32 * 15 + 17, 128, false, true); + + Test(256, 256, 32, true, false); + Test(256, 256, 32, true, true); + Test(256, 256, 32, false, false); + Test(256, 256, 32, false, true); + } + + MlasBlockwiseQdqTest() = default; +}; + +template <> +MlasBlockwiseQdqTest* MlasTestFixture::mlas_tester(nullptr); + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); + +#endif // ORT_MINIMAL_BUILD