Augment blockwise quantization (microsoft#18101)
### Description
Augment blockwise 4-bit quantization -- plain CPU implementation

### Motivation and Context

Allow column-wise or row-wise blocks. Experiments show that row-wise
quantization of LLM weight matrices achieves better precision.

Added tests for quantization and dequantization code.
chenfucn authored and kleiti committed Mar 22, 2024
1 parent 730f86d commit 46ec9d9
Showing 5 changed files with 1,023 additions and 0 deletions.
114 changes: 114 additions & 0 deletions onnxruntime/core/mlas/inc/mlas_q4.h
@@ -229,3 +229,117 @@ MlasQ8Q4GemmBatch(
const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool
);


////////////////////////////////////////////////////////////
// Blockwise quantization and dequantization where quantization
// parameters are packed into separate buffers.
//

/**
* @brief For quantization type <T, block_size, columnwise>, and
* matrix shape [rows, columns], compute the shape of the
* quantization parameter matrix [meta_rows, meta_cols]
*/
template <typename T>
void
MlasBlockwiseQuantMetaShape(
int block_size,
bool columnwise,
int rows,
int columns,
int& meta_rows,
int& meta_cols
);

/**
* @brief For quantization type <T, block_size, columnwise>, and
* matrix shape [rows, columns], compute the shape of the
* quantized matrix [q_rows, q_cols]. The quantized matrix
* is in column major layout, with bits packed on the column.
*
 * @tparam T          type of the quantization parameters (scales, zero points)
 * @param block_size  number of elements that share one set of quantization parameters
 * @param columnwise  true for column-wise blocks, false for row-wise blocks
 * @param rows        number of rows of the source matrix
 * @param columns     number of columns of the source matrix
 * @param q_rows      output, number of rows of the quantized matrix
 * @param q_cols      output, number of columns of the quantized matrix
*/
template <typename T>
void
MlasBlockwiseQuantizedShape(
int block_size,
bool columnwise,
int rows,
int columns,
int& q_rows,
int& q_cols
);
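The two shape helpers above can be illustrated with a stand-alone sketch. The names mirror the declarations, but this is an assumption-driven reading of the documented semantics (ceiling division for the per-block parameters, two 4-bit values packed per byte along the column), not the MLAS implementation; `BlockwiseQuantMetaShape` and `BlockwiseQuantizedShape` here are hypothetical free functions.

```cpp
// Hedged sketch of the documented shape arithmetic; NOT the MLAS internals.
void BlockwiseQuantMetaShape(int block_size, bool columnwise,
                             int rows, int columns,
                             int& meta_rows, int& meta_cols) {
    if (columnwise) {
        // One scale/zero-point per block_size elements down each column.
        meta_rows = (rows + block_size - 1) / block_size;
        meta_cols = columns;
    } else {
        // Row-wise blocks: one parameter set per block_size elements along a row.
        meta_rows = rows;
        meta_cols = (columns + block_size - 1) / block_size;
    }
}

void BlockwiseQuantizedShape(int block_size, bool columnwise,
                             int rows, int columns,
                             int& q_rows, int& q_cols) {
    (void)block_size;
    (void)columnwise;
    // "Bits packed on the column": two 4-bit codes per byte down each column,
    // so the byte matrix has half the rows (rounded up) and the same columns.
    q_rows = (rows + 1) / 2;
    q_cols = columns;
}
```

For example, a 64x16 weight matrix with `block_size = 32` and column-wise blocks needs a 2x16 parameter matrix, while row-wise blocks need 64x1.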


/**
* @brief Blockwise 4 bits quantization, resulting elements and quantization
* parameters (scales, zero points) are packed into separate matrices
* all in column major layout for faster access during subsequent matrix
* multiplication.
*
* @tparam ElementT type of the input matrix element, usually floating point
*
* @param dst points to the quantized matrix, shape [rows, columns] column major
* @param scales points to the scales matrix, column major
* @param zero_points points to the zero_points matrix, column major
* @param src points to the floating point matrix, to be quantized, row major shape [rows, columns]
* @param block_size size of the block to quantize, elements from the same block share the same scale and zero point
* @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row
 * @param rows              number of rows of the source matrix
 * @param columns           number of columns of the source matrix
 * @param leading_dimension leading dimension of the source matrix
 * @param thread_pool       optional thread pool for parallel execution
*/
template <typename ElementT>
void
MlasQuantizeBlockwise(
uint8_t* dst,
ElementT* scales,
uint8_t* zero_points,
const ElementT* src,
int block_size,
bool columnwise,
int rows,
int columns,
int leading_dimension,
MLAS_THREADPOOL* thread_pool
);
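The per-block quantization scheme the declaration describes (one shared scale and zero point per block) can be sketched in scalar form. `QuantizeBlock4b` is a hypothetical name for illustration; the range handling (forcing 0 into the representable range, clamping codes to [0, 15]) is a common convention for asymmetric 4-bit quantization and is assumed here, not taken from the MLAS kernel.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative scalar sketch of asymmetric 4-bit quantization for one block.
void QuantizeBlock4b(const float* src, int count,
                     float& scale, uint8_t& zero_point,
                     std::vector<uint8_t>& quantized) {
    // Include 0 in the range so the zero point itself is representable.
    float mn = std::min(0.0f, *std::min_element(src, src + count));
    float mx = std::max(0.0f, *std::max_element(src, src + count));
    scale = (mx - mn) / 15.0f;  // 4 bits -> 16 quantization levels
    float zp_f = (scale == 0.0f) ? 0.0f : -mn / scale;
    zero_point = static_cast<uint8_t>(
        std::lround(std::min(15.0f, std::max(0.0f, zp_f))));
    quantized.resize(count);
    for (int i = 0; i < count; ++i) {
        // Round to the nearest code, then clamp to the 4-bit range.
        float q = std::round(src[i] / (scale == 0.0f ? 1.0f : scale)) +
                  static_cast<float>(zero_point);
        quantized[i] = static_cast<uint8_t>(std::min(15.0f, std::max(0.0f, q)));
    }
}
```

The real `MlasQuantizeBlockwise` additionally packs two such codes per byte in column-major order and distributes blocks over the thread pool.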


/**
* @brief Blockwise 4 bits dequantization, quantized elements and quantization
* parameters (scales, zero points) are from separate matrices packed
* in column major layout. Output is a floating point matrix in column
* major layout for faster access during subsequent matrix multiplication.
*
* @tparam ElementT type of the dequantized matrix element, usually floating point
* @param dst points to dequantized matrix shape [rows, columns] column major
* @param src points to quantized matrix, column major
* @param scales points to quantization scales, column major
* @param zero_points points to quantization zero points, column major
* @param block_size size of the block to quantize, elements from the same block share the same scale and zero point
* @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row
 * @param rows        number of rows of the output matrix
 * @param columns     number of columns of the output matrix
 * @param thread_pool optional thread pool for parallel execution
*/
template <typename ElementT>
void
MlasDequantizeBlockwise(
ElementT* dst,
const uint8_t* src,
const ElementT* scales,
const uint8_t* zero_points,
int block_size,
bool columnwise,
int rows,
int columns,
MLAS_THREADPOOL* thread_pool
);
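The inverse direction can be sketched for a single column. The nibble order (low nibble holds the even row, high nibble the odd row) is an assumption made for illustration; the actual MLAS packing order may differ, and `DequantizeColumn4b` is a hypothetical helper, not the library kernel.

```cpp
#include <cstdint>
#include <vector>

// Hedged sketch: unpack a column of 4-bit codes (two per byte) and apply the
// shared scale/zero point. Low-nibble-first order is an assumption.
std::vector<float> DequantizeColumn4b(const uint8_t* packed, int rows,
                                      float scale, uint8_t zero_point) {
    std::vector<float> out(rows);
    for (int r = 0; r < rows; ++r) {
        uint8_t byte = packed[r / 2];
        uint8_t q = (r % 2 == 0) ? static_cast<uint8_t>(byte & 0x0F)
                                 : static_cast<uint8_t>(byte >> 4);
        // value = scale * (code - zero_point)
        out[r] = scale * (static_cast<int>(q) - static_cast<int>(zero_point));
    }
    return out;
}
```

With `scale = 1` and `zero_point = 0`, the bytes `{0x21, 0x43}` unpack to the column `{1, 2, 3, 4}` under this assumed nibble order.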
17 changes: 17 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -1077,6 +1077,23 @@ MlasTrySimpleParallel(
const std::function<void(std::ptrdiff_t tid)>& Work
);


/**
 * @brief Distribute many iterations of work over a thread pool if supported.
 *        This function is for small workloads in non-performance-critical situations.
 *
 * @param ThreadPool [IN] Optional thread pool. Ignored when using OpenMP
 * @param Iterations [IN] Total number of iterations
 * @param Work       [IN] Worker function, invoked once per iteration with the iteration index
*/
void
MlasTryBatchParallel(
MLAS_THREADPOOL* ThreadPool,
const std::ptrdiff_t Iterations,
const std::function<void(std::ptrdiff_t tid)>& Work
);
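The contract above implies a simple serial fallback when no thread pool is available: invoke the worker once per iteration index. This sketch shows that assumed fallback behavior only; it is not the MLAS implementation, and `TryBatchParallelSerial` is a hypothetical name.

```cpp
#include <cstddef>
#include <functional>

// Serial fallback sketch: one worker invocation per iteration index,
// matching the documented per-iteration contract of the declaration above.
void TryBatchParallelSerial(std::ptrdiff_t iterations,
                            const std::function<void(std::ptrdiff_t)>& work) {
    for (std::ptrdiff_t tid = 0; tid < iterations; ++tid) {
        work(tid);
    }
}
```

A caller would capture its shared state in the closure, e.g. quantizing one block of the weight matrix per iteration index.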


inline
ptrdiff_t
MlasGetMaximumThreadCount(
