diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h
index 65b48a3009e72..f3bc2a2434ab3 100644
--- a/onnxruntime/core/mlas/inc/mlas_q4.h
+++ b/onnxruntime/core/mlas/inc/mlas_q4.h
@@ -229,3 +229,117 @@ MlasQ8Q4GemmBatch(
     const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams,
     MLAS_THREADPOOL* ThreadPool
     );
+
+
+////////////////////////////////////////////////////////////
+// Blockwise quantization and dequantization where quantization
+// parameters are packed into separate buffers.
+//
+
+/**
+ * @brief For quantization type <T, block_size, columnwise>, and
+ *        matrix shape [rows, columns], compute the shape of the
+ *        quantization parameter matrix [meta_rows, meta_cols]
+*/
+template <typename T>
+void
+MlasBlockwiseQuantMetaShape(
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    int& meta_rows,
+    int& meta_cols
+    );
+
+/**
+ * @brief For quantization type <T, block_size, columnwise>, and
+ *        matrix shape [rows, columns], compute the shape of the
+ *        quantized matrix [q_rows, q_cols]. The quantized matrix
+ *        is in column major layout, with bits packed on the column.
+ *
+ * @tparam T
+ * @param block_size
+ * @param columnwise
+ * @param rows
+ * @param columns
+ * @param q_rows
+ * @param q_cols
+*/
+template <typename T>
+void
+MlasBlockwiseQuantizedShape(
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    int& q_rows,
+    int& q_cols
+    );
+
+
+/**
+ * @brief Blockwise 4 bits quantization. The resulting elements and quantization
+ *        parameters (scales, zero points) are packed into separate matrices,
+ *        all in column major layout for faster access during subsequent matrix
+ *        multiplication.
+ *
+ * @tparam ElementT type of the input matrix element, usually floating point
+ *
+ * @param dst points to the quantized matrix, shape [rows, columns] column major
+ * @param scales points to the scales matrix, column major
+ * @param zero_points points to the zero_points matrix, column major
+ * @param src points to the floating point matrix to be quantized, row major, shape [rows, columns]
+ * @param block_size size of the block to quantize; elements from the same block share the same scale and zero point
+ * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row
+ * @param rows
+ * @param columns
+ * @param leading_dimension
+ * @param thread_pool
+*/
+template <typename ElementT>
+void
+MlasQuantizeBlockwise(
+    uint8_t* dst,
+    ElementT* scales,
+    uint8_t* zero_points,
+    const ElementT* src,
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    int leading_dimension,
+    MLAS_THREADPOOL* thread_pool
+    );
+
+
+/**
+ * @brief Blockwise 4 bits dequantization. Quantized elements and quantization
+ *        parameters (scales, zero points) are read from separate matrices packed
+ *        in column major layout. Output is a floating point matrix in column
+ *        major layout for faster access during subsequent matrix multiplication.
+ *
+ * @tparam ElementT type of the dequantized matrix element, usually floating point
+ * @param dst points to the dequantized matrix, shape [rows, columns] column major
+ * @param src points to the quantized matrix, column major
+ * @param scales points to the quantization scales, column major
+ * @param zero_points points to the quantization zero points, column major
+ * @param block_size size of the quantization block; elements from the same block share the same scale and zero point
+ * @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row
+ * @param rows
+ * @param columns
+ * @param thread_pool
+*/
+template <typename ElementT>
+void
+MlasDequantizeBlockwise(
+    ElementT* dst,
+    const uint8_t* src,
+    const ElementT* scales,
+    const uint8_t* zero_points,
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    MLAS_THREADPOOL* thread_pool
+    );
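
Reviewer note: a minimal round-trip sketch of the new public API, mirroring how the unit test below drives it. The buffer names and the `nullptr` thread pool are illustrative, not part of this change:

```cpp
// Quantize a row-major float matrix into packed 4-bit blocks, then
// dequantize it back. Buffer sizes come from the shape helpers.
#include <cstdint>
#include <vector>
#include "mlas_q4.h"

void RoundTrip(const float* src, int rows, int columns) {
    const int block_size = 32;
    const bool columnwise = true;

    int meta_rows, meta_cols;
    MlasBlockwiseQuantMetaShape<float>(block_size, columnwise, rows, columns, meta_rows, meta_cols);

    int q_rows, q_cols;
    MlasBlockwiseQuantizedShape<float>(block_size, columnwise, rows, columns, q_rows, q_cols);

    std::vector<uint8_t> quantized(static_cast<size_t>(q_rows) * q_cols);
    std::vector<float> scales(static_cast<size_t>(meta_rows) * meta_cols);
    // Zero points are 4-bit values, two packed per byte along each column.
    std::vector<uint8_t> zero_points((static_cast<size_t>(meta_rows) + 1) / 2 * meta_cols);

    MlasQuantizeBlockwise<float>(quantized.data(), scales.data(), zero_points.data(),
                                 src, block_size, columnwise, rows, columns,
                                 /*leading_dimension=*/columns, /*thread_pool=*/nullptr);

    std::vector<float> dequantized(static_cast<size_t>(rows) * columns);
    MlasDequantizeBlockwise<float>(dequantized.data(), quantized.data(), scales.data(),
                                   zero_points.data(), block_size, columnwise,
                                   rows, columns, /*thread_pool=*/nullptr);
}
```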
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index b6ac4a1ca1d6c..38a17731f13a3 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -1069,6 +1069,23 @@ MlasTrySimpleParallel(
     const std::function<void(std::ptrdiff_t tid)>& Work
     );
+
+/**
+ * @brief Distribute many iterations of work over a thread pool if supported.
+ *        This function is for small workloads in non-performance critical situations.
+ *
+ * @param ThreadPool [IN]          Optional thread pool. Ignored when using OpenMP
+ * @param Iterations [IN]          Total number of iterations
+ * @param Work [IN]                Logic for computing a single iteration, given the iteration index
+ */
+void
+MlasTryBatchParallel(
+    MLAS_THREADPOOL* ThreadPool,
+    const std::ptrdiff_t Iterations,
+    const std::function<void(std::ptrdiff_t tid)>& Work
+    );
+
+
 inline
 ptrdiff_t
 MlasGetMaximumThreadCount(
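
Reviewer note: a minimal sketch of the intended call pattern for this internal helper (the caller and per-iteration work are placeholders; assumes `mlasi.h` is included):

```cpp
// Illustrative caller (not part of this change): run one independent
// iteration per tile. MlasTryBatchParallel invokes Work(tid) for each
// tid in [0, Iterations), serially when no thread pool is available.
#include <cstddef>
#include "mlasi.h"

void ProcessTiles(MLAS_THREADPOOL* pool, float* tile_sums, std::ptrdiff_t n) {
    MlasTryBatchParallel(pool, n, [&](std::ptrdiff_t tid) {
        tile_sums[tid] *= 2.0f;  // placeholder per-iteration work
    });
}
```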
diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp
index 85c0d13006126..24a2212ba0714 100644
--- a/onnxruntime/core/mlas/lib/q4_dq.cpp
+++ b/onnxruntime/core/mlas/lib/q4_dq.cpp
@@ -294,3 +294,649 @@ MlasQ4GemmUnPackB(
         return MlasQ4GemmUnPackBImpl(FpData, PackedBuf, N, K, ldb);
     }
 }
+
+
+
+/***************************************************************
+ * The quantization format that packs data and quantization
+ * parameters into separate buffers.
+ */
+
+
+template <
+    int Row_,    ///< rows of a matrix
+    int Column_  ///< columns of a matrix
+    >
+struct Shape2D {
+    static int const kRow = Row_;              ///< rows of a matrix
+    static int const kColumn = Column_;        ///< columns of a matrix
+    static int const kCount = Row_ * Column_;  ///< total number of elements in a matrix
+};
+
+
+template <int qbits>
+struct BitsTraits {
+    static_assert(qbits <= 8, "BitsTraits only supports small numbers of bits!");
+
+    static constexpr int kBits = qbits;
+    static constexpr int kMax = (1 << qbits) - 1;
+    static constexpr int kMid = 1 << (qbits - 1);
+    static constexpr float kMaxFp = static_cast<float>(kMax);
+
+    // number of qbit elements to pack into whole bytes
+    static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0;
+    static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!");
+};
+
+
+/**
+ * @brief Rectify min/max from a set of weights, and convert to scale and zero point
+ *        for quantization.
+ * @tparam ScaleT type of the scale, usually a floating point type of various widths
+ * @tparam qbits  number of int bits used for the zero point value
+ * @param[in]   min
+ * @param[in]   max
+ * @param[out]  scale
+ * @param[out]  zp
+ */
+template <typename ScaleT, int qbits>
+MLAS_FORCEINLINE
+void
+range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp)
+{
+    constexpr int zp_max = BitsTraits<qbits>::kMax;
+    constexpr float zp_max_fp = BitsTraits<qbits>::kMaxFp;
+
+    min = std::min(min, 0.0f);
+    max = std::max(max, 0.0f);
+
+    float scale_f = (max - min) / zp_max;
+
+    float zero_point_fp = min;
+    if (scale_f != 0.0f) {
+        zero_point_fp = 0.f - min / scale_f;
+    }
+
+    if (zero_point_fp < 0.0f) {
+        zp = 0;
+    } else if (zero_point_fp > zp_max_fp) {
+        zp = zp_max;
+    } else {
+        zp = (uint8_t)roundf(zero_point_fp);
+    }
+    scale = static_cast<ScaleT>(scale_f);
+}
+
+template <typename ScaleT, int qbits>
+MLAS_FORCEINLINE
+void
+range2scale(float min, float max, ScaleT& scale)
+{
+    constexpr int mid_v = BitsTraits<qbits>::kMid;
+    constexpr float mid_fp = static_cast<float>(-mid_v);
+
+    max = fabsf(max) > fabsf(min) ? max : min;
+
+    scale = static_cast<ScaleT>(max / mid_fp);
+};
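
Reviewer note: a standalone trace of the asymmetric parameter math in `range2scalezp<float, 4>`, using an illustrative block range:

```cpp
// The block range is widened to include 0, the scale maps it onto [0, 15],
// and the zero point is the (rounded, clamped) quantized position of 0.0.
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    float mn = std::min(-0.9f, 0.0f);       // block minimum, forced <= 0
    float mx = std::max(2.1f, 0.0f);        // block maximum, forced >= 0
    const float scale = (mx - mn) / 15.0f;  // kMax = 2^4 - 1 = 15
    const float zp_f = -mn / scale;         // where 0.0 lands on [0, 15]
    const int zp = (int)std::roundf(std::clamp(zp_f, 0.0f, 15.0f));
    std::printf("scale=%.3f zp=%d\n", scale, zp);  // prints scale=0.200 zp=5
    return 0;
}
```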
+
+
+/**
+ * @brief Blockwise quantization methods
+ * @tparam ElementT    source data type, e.g. fp32/fp16
+ * @tparam block_size  number of elements quantized together
+ * @tparam qbits       number of bits in each quantized element
+ * @tparam Columnwise  true:  elements in a block come from one single column
+ *                     false: elements in a block come from one single row
+ */
+template <
+    typename ElementT,
+    int32_t block_size,
+    int32_t qbits,
+    bool Columnwise>
+struct BlockwiseQuantizer {
+    // To support other qbits, need to add bit packing code for
+    // storing to dst and zero points
+    static_assert(qbits == 4, "Only 4b block quantization is supported!");
+
+    using QuantBlk = std::conditional_t<Columnwise, Shape2D<block_size, 1>, Shape2D<1, block_size>>;
+    using ThreadBlk = Shape2D<QuantBlk::kRow * BitsTraits<qbits>::kPackSize, QuantBlk::kColumn>;
+
+    static
+    MLAS_FORCEINLINE
+    void quantizeMetaShape(int rows, int columns, int& meta_rows, int& meta_cols)
+    {
+        meta_rows = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow;
+        meta_cols = (columns + QuantBlk::kColumn - 1) / QuantBlk::kColumn;
+    }
+
+    static
+    MLAS_FORCEINLINE
+    void quantizedShape(int rows, int columns, int& q_rows, int& q_cols) {
+        int meta_rows;
+        int meta_cols;
+        quantizeMetaShape(rows, columns, meta_rows, meta_cols);
+
+        // quantized matrix is stored in column major, packed by column
+        q_rows = (meta_rows * QuantBlk::kRow * qbits + 7) / 8;
+        q_cols = meta_cols * QuantBlk::kColumn;
+    }
+
+    /**
+     * @brief Quantize a matrix of shape [rows, columns]; the resulting quantized
+     *        and packed data are stored in column major (transposed).
+     * @param[out] dst          pointer to the quantized weights, column major: [columns, rows]
+     * @param[out] scales       pointer to the scales, column major: [columns/QuantBlk::kColumn, rows/QuantBlk::kRow]
+     * @param[out] zero_points  pointer to the zero points, same shape as scales
+     * @param[in]  src          pointer to the source matrix, row major: [rows, columns]
+     * @param rows
+     * @param columns
+     * @param leadingDimension  stride of the source matrix, i.e. distance from one row to the next
+     */
+    static void quantizeAndTranspose(
+        uint8_t* dst,
+        ElementT* scales,
+        uint8_t* zero_points,
+        const ElementT* src,
+        int32_t rows,
+        int32_t columns,
+        int32_t leadingDimension,
+        MLAS_THREADPOOL* thread_pool)
+    {
+        // Thread partitioning
+        const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow;
+        const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn;
+        const auto total_thrd_blks = thrd_row_blks * thrd_col_blks;
+
+        const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow;
+
+        int q_rows, q_cols;
+        quantizedShape(rows, columns, q_rows, q_cols);
+
+        MlasTryBatchParallel(
+            thread_pool, total_thrd_blks,
+            [&](ptrdiff_t block_idx) {
+                uint8_t zp_bytes[BitsTraits<qbits>::kPackSize];
+                std::fill_n(zp_bytes, BitsTraits<qbits>::kPackSize, (uint8_t)8);
+
+                const int32_t r_blk_idx = static_cast<int32_t>(block_idx / thrd_col_blks);
+                const int32_t c_blk_idx = static_cast<int32_t>(block_idx % thrd_col_blks);
+
+                const int32_t r = r_blk_idx * ThreadBlk::kRow;
+                const int32_t c = c_blk_idx * ThreadBlk::kColumn;
+
+                const int32_t r_end = std::min(r + ThreadBlk::kRow, rows);
+                const int32_t c_end = std::min(c + ThreadBlk::kColumn, columns);
+
+                const int meta_row = r / QuantBlk::kRow;
+                const int meta_col = c / QuantBlk::kColumn;
+
+                // compute scale and zero point
+                for (int kpack = 0; kpack < BitsTraits<qbits>::kPackSize; kpack++) {
+
+                    // scan a single block to extract range [min, max]
+                    float min = std::numeric_limits<float>::max();
+                    float max = -min;
+                    const int row_start = r + kpack * QuantBlk::kRow;
+                    const int row_end = std::min(row_start + QuantBlk::kRow, r_end);
+                    for (int i = row_start; i < row_end; ++i) {
+                        for (int j = c; j < c_end; ++j) {
+                            const float v = static_cast<float>(src[i * leadingDimension + j]);
+                            if (v < min) min = v;
+                            if (v > max) max = v;
+                        }
+                    }
+
+                    // store scale and zero point at quant parameter matrix position
+                    if (row_start < row_end) {
+                        const int32_t meta_idx = meta_col * row_blks + meta_row + kpack;
+                        if (zero_points == nullptr) {
+                            range2scale<ElementT, qbits>(min, max, scales[meta_idx]);
+                        } else {
+                            range2scalezp<ElementT, qbits>(min, max, scales[meta_idx], zp_bytes[kpack]);
+                        }
+                    }
+                }
+
+                // !! 4b specific code as we need to pack 2 4b numbers into one byte
+                if (zero_points != nullptr) {
+                    const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2;
+                    zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4);
+                }
+
+                for (int32_t j = c; j < c_end; ++j) {
+                    const int32_t meta_c = j / QuantBlk::kColumn;
+                    for (int32_t i = r; i < r_end; i += 2) {
+                        const int32_t meta_r = i / QuantBlk::kRow;
+                        const float scale = static_cast<float>(scales[meta_c * row_blks + meta_r]);
+                        const float reciprocal_scale = scale ? 1.0f / scale : 0.0f;
+                        const int8_t zp = zp_bytes[meta_r & 1];
+                        const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1];
+
+                        const float v0 = static_cast<float>(src[i * leadingDimension + j]);
+                        const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp),
+                                                                0.0f, BitsTraits<qbits>::kMaxFp);
+
+                        uint8_t vi1 = (uint8_t)zp;
+                        if (i + 1 < r_end) {
+                            float reciprocal_scale1 = reciprocal_scale;
+                            if constexpr (QuantBlk::kRow == 1) {
+                                const float scale1 =
+                                    static_cast<float>(scales[meta_c * row_blks + meta_r + 1]);
+                                reciprocal_scale1 = scale1 ? 1.0f / scale1 : 0.0f;
+                            }
+                            const float v1 = static_cast<float>(src[(i + 1) * leadingDimension + j]);
+                            vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f,
+                                                      BitsTraits<qbits>::kMaxFp);
+                        }
+
+                        // !! 4b specific code
+                        dst[j * q_rows + i / 2] = (vi0 & 0xf) | (vi1 << 4);
+                    }
+                }
+            });
+    }
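
Reviewer note: the nibble layout produced by the final store above, shown standalone. The values are illustrative:

```cpp
// Element i of a column occupies the low nibble of byte i/2 when i is even,
// and the high nibble when i is odd.
#include <cassert>
#include <cstdint>

int main() {
    const uint8_t vi0 = 0x3;  // quantized value at row i (even)
    const uint8_t vi1 = 0xA;  // quantized value at row i + 1
    const uint8_t packed = (vi0 & 0xf) | (vi1 << 4);

    assert((packed & 0xf) == vi0);  // unpack low nibble
    assert((packed >> 4) == vi1);   // unpack high nibble
    return 0;
}
```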
+
+    /**
+     * @brief Dequantize a column major quantized matrix, and store the result in a column major
+     *        matrix for use in GEMM.
+     * @param[out] dst          pointer to the dequantized matrix, column major: [columns, rows]
+     * @param[in]  weights      pointer to the quantized weights, column major: [columns, rows]
+     * @param[in]  scales       pointer to the scales of quantized blocks, column major layout
+     * @param[in]  zero_points  pointer to the zero points of quantized blocks, packed column major
+     *                          like the scales
+     * @param[in]  rows
+     * @param[in]  columns
+     */
+    static void dequantize(
+        ElementT* dst,
+        const uint8_t* weights,
+        const ElementT* scales,
+        const uint8_t* zero_points,
+        int32_t rows,
+        int32_t columns,
+        MLAS_THREADPOOL* thread_pool)
+    {
+        // Thread partitioning
+        const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow;
+        const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn;
+        const auto total_thrd_blks = thrd_row_blks * thrd_col_blks;
+
+        const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow;
+
+        int q_rows, q_cols;
+        quantizedShape(rows, columns, q_rows, q_cols);
+
+        MlasTryBatchParallel(
+            thread_pool, total_thrd_blks,
+            [&](ptrdiff_t block_idx) {
+                int32_t r_blk_idx = static_cast<int32_t>(block_idx / thrd_col_blks);
+                int32_t c_blk_idx = static_cast<int32_t>(block_idx % thrd_col_blks);
+
+                int32_t r = r_blk_idx * ThreadBlk::kRow;
+                int32_t c = c_blk_idx * ThreadBlk::kColumn;
+
+                int32_t r_end = std::min(r + ThreadBlk::kRow, rows);
+                int32_t c_end = std::min(c + ThreadBlk::kColumn, columns);
+
+                for (int32_t j = c; j < c_end; ++j) {
+                    const int32_t meta_col = j / QuantBlk::kColumn;
+
+                    // !! 4b specific code
+                    // the whole loop is 4b specific due to sub 8 bit packing
+                    // and unpacking. We can potentially make this qbits generic
+                    // by wrapping the packing/unpacking code like cutlass::Array
+                    for (int32_t i = r; i < r_end; i += 2) {
+                        const int32_t meta_row = i / QuantBlk::kRow;
+
+                        const float scale0 =
+                            static_cast<float>(scales[meta_col * row_blks + meta_row]);
+
+                        const int zp_pair =
+                            (zero_points == nullptr)
+                                ? 0x88
+                                : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2];
+                        const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf);
+
+                        const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf;
+                        const float v0 = (static_cast<float>(vi0) - zp0) * scale0;
+
+                        dst[j * rows + i] = static_cast<ElementT>(v0);
+                        if ((i + 1) < r_end) {
+                            float scale1 = scale0;
+                            int zp1 = zp0;
+                            if constexpr (QuantBlk::kRow == 1) {
+                                scale1 =
+                                    static_cast<float>(scales[meta_col * row_blks + meta_row + 1]);
+                                zp1 = (zp_pair >> 4) & 0xf;
+                            }
+                            const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4;
+                            const float v1 = (static_cast<float>(vi1) - zp1) * scale1;
+                            dst[j * rows + (i + 1)] = static_cast<ElementT>(v1);
+                        }
+                    }
+                }
+            });
+    }
+};
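
Reviewer note: the per-element affine mapping applied by the loop above, shown standalone with illustrative values:

```cpp
// A 4-bit code q is mapped back to a float via v = (q - zero_point) * scale.
// When no zero point buffer is supplied, the midpoint 8 is used (hence the
// 0x88 pair above, one 8 in each nibble).
#include <cstdint>
#include <cstdio>

int main() {
    const uint8_t q = 0x3;     // 4-bit quantized code
    const int zero_point = 5;  // example zero point
    const float scale = 0.2f;  // example block scale
    const float v = (static_cast<float>(q) - zero_point) * scale;
    std::printf("dequantized: %f\n", v);  // prints -0.400000
    return 0;
}
```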
+
+
+template <typename T>
+void
+MlasBlockwiseQuantMetaShape(
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    int& meta_rows,
+    int& meta_cols
+    )
+{
+    switch (block_size) {
+        case 16: {
+            if (columnwise) {
+                BlockwiseQuantizer<T, 16, 4, true>::quantizeMetaShape(rows, columns, meta_rows, meta_cols);
+            } else {
+                BlockwiseQuantizer<T, 16, 4, false>::quantizeMetaShape(rows, columns, meta_rows, meta_cols);
+            }
+            break;
+        }
+        case 32: {
+            if (columnwise) {
+                BlockwiseQuantizer<T, 32, 4, true>::quantizeMetaShape(rows, columns, meta_rows, meta_cols);
+            } else {
+                BlockwiseQuantizer<T, 32, 4, false>::quantizeMetaShape(
+                    rows, columns, meta_rows, meta_cols);
+            }
+            break;
+        }
+        case 64: {
+            if (columnwise) {
+                BlockwiseQuantizer<T, 64, 4, true>::quantizeMetaShape(rows, columns, meta_rows,
+                                                                      meta_cols);
+            } else {
+                BlockwiseQuantizer<T, 64, 4, false>::quantizeMetaShape(rows, columns, meta_rows,
+                                                                       meta_cols);
+            }
+            break;
+        }
+        case 128: {
+            if (columnwise) {
+                BlockwiseQuantizer<T, 128, 4, true>::quantizeMetaShape(rows, columns, meta_rows,
+                                                                       meta_cols);
+            } else {
+                BlockwiseQuantizer<T, 128, 4, false>::quantizeMetaShape(rows, columns, meta_rows,
+                                                                        meta_cols);
+            }
+            break;
+        }
+        case 256: {
+            if (columnwise) {
+                BlockwiseQuantizer<T, 256, 4, true>::quantizeMetaShape(rows, columns, meta_rows,
+                                                                       meta_cols);
+            } else {
+                BlockwiseQuantizer<T, 256, 4, false>::quantizeMetaShape(rows, columns, meta_rows,
+                                                                        meta_cols);
+            }
+            break;
+        }
+        default:
+            meta_rows = 0;
+            meta_cols = 0;
+            break;
+    }
+}
+
+
+
+template <typename T>
+void
+MlasBlockwiseQuantizedShape(
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    int& q_rows,
+    int& q_cols
+    )
+{
+    switch (block_size) {
+        case 16: {
+            if (columnwise) {
+                BlockwiseQuantizer<T, 16, 4, true>::quantizedShape(rows, columns, q_rows, q_cols);
+            } else {
+                BlockwiseQuantizer<T, 16, 4, false>::quantizedShape(rows, columns, q_rows, q_cols);
+            }
+            break;
+        }
+        case 32: {
+            if (columnwise) {
+                BlockwiseQuantizer<T, 32, 4, true>::quantizedShape(rows, columns, q_rows, q_cols);
+            } else {
+                BlockwiseQuantizer<T, 32, 4, false>::quantizedShape(
+                    rows, columns, q_rows, q_cols);
+            }
+            break;
+        }
+        case 64: {
+            if (columnwise) {
+                BlockwiseQuantizer<T, 64, 4, true>::quantizedShape(rows, columns, q_rows, q_cols);
+            } else {
+                BlockwiseQuantizer<T, 64, 4, false>::quantizedShape(rows, columns, q_rows, q_cols);
+            }
+            break;
+        }
+        case 128: {
+            if (columnwise) {
+                BlockwiseQuantizer<T, 128, 4, true>::quantizedShape(rows, columns, q_rows, q_cols);
+            } else {
+                BlockwiseQuantizer<T, 128, 4, false>::quantizedShape(rows, columns, q_rows, q_cols);
+            }
+            break;
+        }
+        case 256: {
+            if (columnwise) {
+                BlockwiseQuantizer<T, 256, 4, true>::quantizedShape(rows, columns, q_rows, q_cols);
+            } else {
+                BlockwiseQuantizer<T, 256, 4, false>::quantizedShape(rows, columns, q_rows, q_cols);
+            }
+            break;
+        }
+        default:
+            q_rows = 0;
+            q_cols = 0;
+            break;
+    }
+}
+
+
+template
+void
+MlasBlockwiseQuantMetaShape<float>(
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    int& meta_rows,
+    int& meta_cols
+    );
+
+template
+void
+MlasBlockwiseQuantizedShape<float>(
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    int& q_rows,
+    int& q_cols
+    );
+
+
+template <typename T>
+void
+MlasQuantizeBlockwise(
+    uint8_t* dst,
+    T* scales,
+    uint8_t* zero_points,
+    const T* src,
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    int leading_dimension,
+    MLAS_THREADPOOL* thread_pool
+    )
+{
+    switch (block_size) {
+        case 16:
+            if (columnwise) {
+                BlockwiseQuantizer<T, 16, 4, true>::quantizeAndTranspose(
+                    dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
+            } else {
+                BlockwiseQuantizer<T, 16, 4, false>::quantizeAndTranspose(
+                    dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
+            }
+            break;
+
+        case 32:
+            if (columnwise) {
+                BlockwiseQuantizer<T, 32, 4, true>::quantizeAndTranspose(
+                    dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
+            } else {
+                BlockwiseQuantizer<T, 32, 4, false>::quantizeAndTranspose(
+                    dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
+            }
+            break;
+
+        case 64:
+            if (columnwise) {
+                BlockwiseQuantizer<T, 64, 4, true>::quantizeAndTranspose(
+                    dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
+            } else {
+                BlockwiseQuantizer<T, 64, 4, false>::quantizeAndTranspose(
+                    dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
+            }
+            break;
+
+        case 128:
+            if (columnwise) {
+                BlockwiseQuantizer<T, 128, 4, true>::quantizeAndTranspose(
+                    dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
+            } else {
+                BlockwiseQuantizer<T, 128, 4, false>::quantizeAndTranspose(
+                    dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
+            }
+            break;
+
+        case 256:
+            if (columnwise) {
+                BlockwiseQuantizer<T, 256, 4, true>::quantizeAndTranspose(
+                    dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
+            } else {
+                BlockwiseQuantizer<T, 256, 4, false>::quantizeAndTranspose(
+                    dst, scales, zero_points, src, rows, columns, leading_dimension, thread_pool);
+            }
+            break;
+
+        default:
+            // Only block sizes 16, 32, 64, 128, and 256 are supported.
+            break;
+    }
+}
+
+template
+void
+MlasQuantizeBlockwise<float>(
+    uint8_t* dst,
+    float* scales,
+    uint8_t* zero_points,
+    const float* src,
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    int leading_dimension,
+    MLAS_THREADPOOL* thread_pool
+    );
+
+
+template <typename T>
+void
+MlasDequantizeBlockwise(
+    T* dst,
+    const uint8_t* src,
+    const T* scales,
+    const uint8_t* zero_points,
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    MLAS_THREADPOOL* thread_pool
+    )
+{
+    switch (block_size) {
+        case 16:
+            if (columnwise) {
+                BlockwiseQuantizer<T, 16, 4, true>::dequantize(dst, src, scales, zero_points, rows,
+                                                               columns, thread_pool);
+            } else {
+                BlockwiseQuantizer<T, 16, 4, false>::dequantize(dst, src, scales, zero_points, rows,
+                                                                columns, thread_pool);
+            }
+            break;
+        case 32:
+            if (columnwise) {
+                BlockwiseQuantizer<T, 32, 4, true>::dequantize(dst, src, scales, zero_points, rows,
+                                                               columns, thread_pool);
+            } else {
+                BlockwiseQuantizer<T, 32, 4, false>::dequantize(dst, src, scales, zero_points, rows,
+                                                                columns, thread_pool);
+            }
+            break;
+        case 64:
+            if (columnwise) {
+                BlockwiseQuantizer<T, 64, 4, true>::dequantize(dst, src, scales, zero_points, rows,
+                                                               columns, thread_pool);
+            } else {
+                BlockwiseQuantizer<T, 64, 4, false>::dequantize(dst, src, scales, zero_points, rows,
+                                                                columns, thread_pool);
+            }
+            break;
+        case 128:
+            if (columnwise) {
+                BlockwiseQuantizer<T, 128, 4, true>::dequantize(dst, src, scales, zero_points, rows,
+                                                                columns, thread_pool);
+            } else {
+                BlockwiseQuantizer<T, 128, 4, false>::dequantize(dst, src, scales, zero_points,
+                                                                 rows, columns, thread_pool);
+            }
+            break;
+        case 256:
+            if (columnwise) {
+                BlockwiseQuantizer<T, 256, 4, true>::dequantize(dst, src, scales, zero_points, rows,
+                                                                columns, thread_pool);
+            } else {
+                BlockwiseQuantizer<T, 256, 4, false>::dequantize(dst, src, scales, zero_points,
+                                                                 rows, columns, thread_pool);
+            }
+            break;
+        default:
+            // Only block sizes 16, 32, 64, 128, and 256 are supported.
+            break;
+    }
+}
+
+template
+void
+MlasDequantizeBlockwise<float>(
+    float* dst,
+    const uint8_t* src,
+    const float* scales,
+    const uint8_t* zero_points,
+    int block_size,
+    bool columnwise,
+    int rows,
+    int columns,
+    MLAS_THREADPOOL* thread_pool
+    );
diff --git a/onnxruntime/core/mlas/lib/threading.cpp b/onnxruntime/core/mlas/lib/threading.cpp
index ecdc5250ebf0e..dc5daf998d3be 100644
--- a/onnxruntime/core/mlas/lib/threading.cpp
+++ b/onnxruntime/core/mlas/lib/threading.cpp
@@ -93,3 +93,41 @@ MlasTrySimpleParallel(
     MLAS_THREADPOOL::TrySimpleParallelFor(ThreadPool, Iterations, Work);
 #endif
 }
+
+
+void
+MlasTryBatchParallel(
+    MLAS_THREADPOOL* ThreadPool,
+    const std::ptrdiff_t Iterations,
+    const std::function<void(std::ptrdiff_t tid)>& Work)
+{
+    //
+    // Execute the routine directly if only one iteration is specified.
+    //
+    if (Iterations == 1) {
+        Work(0);
+        return;
+    }
+
+#if defined(BUILD_MLAS_NO_ONNXRUNTIME)
+    MLAS_UNREFERENCED_PARAMETER(ThreadPool);
+
+    //
+    // Fall back to a serialized implementation.
+    //
+
+    //
+    // Execute the routine for the specified number of iterations.
+    //
+    for (ptrdiff_t tid = 0; tid < Iterations; tid++) {
+        Work(tid);
+    }
+#else
+    //
+    // Schedule the threaded iterations using the thread pool object.
+    //
+
+    MLAS_THREADPOOL::TryBatchParallelFor(ThreadPool, Iterations, Work, 0);
+#endif
+
+}
\ No newline at end of file
diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp
new file mode 100644
index 0000000000000..6f06e0f2eead8
--- /dev/null
+++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp
@@ -0,0 +1,208 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    test_blockq4.cpp
+
+Abstract:
+
+    Tests for MLAS blockwise int4 quantization and dequantization code.
+
+--*/
+
+#ifndef ORT_MINIMAL_BUILD
+
+#include "test_util.h"
+#include "mlas_q4.h"
+
+class MlasBlockwiseQdqTest : public MlasTestBase {
+ private:
+  MatrixGuardBuffer<float> FpBuf;
+  MatrixGuardBuffer<float> FpBuf2;
+  MatrixGuardBuffer<uint8_t> InputElements;
+  MatrixGuardBuffer<float> InputScales;
+  MatrixGuardBuffer<uint8_t> InputOffsets;
+  MatrixGuardBuffer<uint8_t> OutputElements;
+  MatrixGuardBuffer<float> OutputScales;
+  MatrixGuardBuffer<uint8_t> OutputOffsets;
+
+  void Test(int rows, int columns, int block_size, bool columnwise, bool symmetric) {
+    float* dequant_buf = FpBuf.GetBuffer(rows * columns, true);
+    float* transposed = FpBuf2.GetBuffer(rows * columns, true);
+
+    MLAS_THREADPOOL* threadpool_ptr = GetMlasThreadPool();
+
+    int meta_rows;
+    int meta_cols;
+    MlasBlockwiseQuantMetaShape<float>(block_size, columnwise, rows, columns, meta_rows, meta_cols);
+
+    int q_rows;
+    int q_cols;
+    MlasBlockwiseQuantizedShape<float>(block_size, columnwise, rows, columns, q_rows, q_cols);
+
+    uint8_t* elements = InputElements.GetBuffer(q_rows * q_cols, true);
+
+    int v = 7;
+    for (int c = 0; c < columns; c++) {
+      for (int r = 0; r < rows; r += 2) {
+        int idx = c * q_rows + r / 2;
+        uint8_t v0 = static_cast<uint8_t>(v);
+        v = (v + 5) % 16;
+        if (v == 11 || v == 7 || v == 3) {
+          // making the cycle 13 instead of 16, avoiding same values in a row
+          v = (v + 5) % 16;
+        }
+        uint8_t v1 = 0;
+        if (r + 1 < rows) {
+          v1 = static_cast<uint8_t>(v);
+          v = (v + 5) % 16;
+          if (v == 11 || v == 7 || v == 3) {
+            // making the cycle 13 instead of 16, avoiding same values in a row
+            v = (v + 5) % 16;
+          }
+        }
+
+        elements[idx] = (v1 << 4) | v0;
+      }
+    }
+
+    float* scales = InputScales.GetBuffer(meta_rows * meta_cols);
+    uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true);
+    if (zp) {
+      for (int c = 0; c < meta_cols; c++) {
+        for (int r = 0; r < meta_rows; r += 2) {
+          int idx = c * ((meta_rows + 1) / 2) + r / 2;
+          uint8_t v0 = static_cast<uint8_t>(v);
+          v = (v + 5) % 16;
+          if (v == 11 || v == 7 || v == 3) {
+            // making the cycle 13 instead of 16, avoiding same values in a row
+            v = (v + 5) % 16;
+          }
+          uint8_t v1 = 0;
+          if (r + 1 < meta_rows) {
+            v1 = static_cast<uint8_t>(v);
+            v = (v + 5) % 16;
+            if (v == 11 || v == 7 || v == 3) {
+              // making the cycle 13 instead of 16, avoiding same values in a row
+              v = (v + 5) % 16;
+            }
+          }
+          zp[idx] = (v1 << 4) | v0;
+        }
+      }
+    }
+
+    MlasDequantizeBlockwise<float>(dequant_buf, elements, scales, zp, block_size, columnwise, rows, columns, threadpool_ptr);
+
+    MlasTranspose(dequant_buf, transposed, columns, rows);
+
+    uint8_t* o_elements = OutputElements.GetBuffer(q_rows * q_cols, true);
+    float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols);
+    uint8_t* o_zp = symmetric ? nullptr : OutputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true);
+
+    MlasQuantizeBlockwise<float>(o_elements, o_scales, o_zp, transposed, block_size, columnwise, rows, columns, columns, threadpool_ptr);
+
+    for (int c = 0; c < columns; c++) {
+      for (int r = 0; r < rows; r += 2) {
+        int idx = c * q_rows + r / 2;
+        ASSERT_EQ(o_elements[idx] & 0xf, elements[idx] & 0xf)
+            << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns
+            << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise;
+        if (r + 1 < rows) {
+          ASSERT_EQ(o_elements[idx] >> 4, elements[idx] >> 4)
+              << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns
+              << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise;
+        }
+      }
+    }
+
+    for (int c = 0; c < meta_cols; c++) {
+      for (int r = 0; r < meta_rows; r++) {
+        int idx = c * meta_rows + r;
+        ASSERT_EQ(o_scales[idx], scales[idx])
+            << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns
+            << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise;
+      }
+    }
+
+    if (symmetric) return;
+    for (int c = 0; c < meta_cols; c++) {
+      for (int r = 0; r < meta_rows; r += 2) {
+        int idx = c * ((meta_rows + 1) / 2) + r / 2;
+        ASSERT_EQ(o_zp[idx] & 0xf, zp[idx] & 0xf)
+            << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns
+            << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise;
+        if (r + 1 < meta_rows) {
+          ASSERT_EQ(o_zp[idx] >> 4, zp[idx] >> 4)
+              << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns
+              << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise;
+        }
+      }
+    }
+  }
+
+ public:
+  static const char* GetTestSuiteName() {
+    static const std::string suite_name("BlockQ4");
+    return suite_name.c_str();
+  }
+
+  void ExecuteShort(void) override {
+    Test(20, 1, 32, true, false);
+    Test(20, 1, 32, true, true);
+    Test(1, 20, 32, false, false);
+    Test(1, 20, 32, false, true);
+    Test(52, 1, 32, true, false);
+    Test(52, 1, 32, true, true);
+    Test(1, 52, 32, false, false);
+    Test(1, 52, 32, false, true);
+    Test(20, 3, 32, true, false);
+    Test(20, 3, 32, true, true);
+    Test(3, 20, 32, false, false);
+    Test(3, 20, 32, false, true);
+    Test(52, 3, 32, true, false);
+    Test(52, 3, 32, true, true);
+    Test(3, 52, 32, false, false);
+    Test(3, 52, 32, false, true);
+    Test(52, 3, 64, true, false);
+    Test(52, 3, 64, true, true);
+    Test(3, 52, 64, false, false);
+    Test(3, 52, 64, false, true);
+    Test(32 * 9 + 17, 41, 32, true, false);
+    Test(32 * 9 + 17, 41, 32, true, true);
+    Test(41, 32 * 9 + 17, 32, false, false);
+    Test(41, 32 * 9 + 17, 32, false, true);
+    Test(32 * 9 + 17, 41, 64, true, false);
+    Test(32 * 9 + 17, 41, 64, true, true);
+    Test(41, 32 * 9 + 17, 64, false, false);
+    Test(41, 32 * 9 + 17, 64, false, true);
+    Test(32 * 15 + 17, 63, 128, true, false);
+    Test(32 * 15 + 17, 63, 128, true, true);
+    Test(63, 32 * 15 + 17, 128, false, false);
+    Test(63, 32 * 15 + 17, 128, false, true);
+
+    Test(256, 256, 32, true, false);
+    Test(256, 256, 32, true, true);
+    Test(256, 256, 32, false, false);
+    Test(256, 256, 32, false, true);
+  }
+
+  MlasBlockwiseQdqTest() = default;
+};
+
+template <>
+MlasBlockwiseQdqTest* MlasTestFixture<MlasBlockwiseQdqTest>::mlas_tester(nullptr);
+
+static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
+  size_t count = 0;
+  if (is_short_execute) {
+    count += MlasDirectShortExecuteTests<MlasBlockwiseQdqTest>::RegisterShortExecute();
+  }
+  return count;
+});
+
+#endif  // ORT_MINIMAL_BUILD