Augment blockwise quantization
chenfucn committed Oct 25, 2023
1 parent 35ecce4 commit 7c8b095
Showing 5 changed files with 919 additions and 0 deletions.
89 changes: 89 additions & 0 deletions onnxruntime/core/mlas/inc/mlas_q4.h
@@ -229,3 +229,92 @@ MlasQ8Q4GemmBatch(
const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool
);


////////////////////////////////////////////////////////////
// Blockwise quantization and dequantization where quantization
// parameters are packed into separate buffers.
//

/**
* @brief For quantization type <T, block_size, columnwise>, and
* matrix shape [rows, columns], compute the shape of the
* quantization parameter matrix [meta_rows, meta_cols]
*/
template <typename T>
void
MlasBlockwiseQuantMetaShape(
int block_size,
bool columnwise,
int rows,
int columns,
int& meta_rows,
int& meta_cols
);
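The shape computation described in the comment above can be sketched as follows. This is a minimal illustration, not the MLAS implementation: `SketchBlockwiseQuantMetaShape` is a hypothetical name, and the sketch assumes that one (scale, zero point) pair is stored per block, with a partial trailing block still getting its own entry.

```cpp
#include <cassert>

// Hypothetical helper (not part of MLAS): computes how many quantization
// parameter entries a [rows, columns] matrix needs, assuming one entry per
// block and rounding up for a partial trailing block.
inline void
SketchBlockwiseQuantMetaShape(
    int block_size,
    bool columnwise,
    int rows,
    int columns,
    int& meta_rows,
    int& meta_cols)
{
    assert(block_size > 0);
    if (columnwise) {
        // Blocks run down a column: each column is split into ceil(rows / block_size) blocks.
        meta_rows = (rows + block_size - 1) / block_size;
        meta_cols = columns;
    } else {
        // Blocks run along a row: each row is split into ceil(columns / block_size) blocks.
        meta_rows = rows;
        meta_cols = (columns + block_size - 1) / block_size;
    }
}
```

Under these assumptions, a 100 x 8 matrix with columnwise blocks of 32 elements needs a 4 x 8 parameter matrix, since the last block in each column holds only 4 elements.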


/**
* @brief Blockwise 4-bit quantization. The resulting elements and quantization
* parameters (scales, zero points) are packed into separate matrices,
* all in column major layout, for faster access during subsequent matrix
* multiplication.
*
* @tparam ElementT type of the input matrix element, usually floating point
*
* @param dst points to the quantized matrix, shape [rows, columns] column major
* @param scales points to the scales matrix, column major
* @param zero_points points to the zero_points matrix, column major
* @param src points to the floating point matrix, to be quantized, row major shape [rows, columns]
* @param block_size size of the block to quantize, elements from the same block share the same scale and zero point
* @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row
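* @param rows number of rows in the matrix
* @param columns number of columns in the matrix
* @param leading_dimension leading dimension of the row major src matrix, in elements
* @param thread_pool thread pool used to parallelize the quantization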
*/
template <typename ElementT>
void
MlasQuantizeBlockwise(
uint8_t* dst,
ElementT* scales,
uint8_t* zero_points,
const ElementT* src,
int32_t block_size,
bool columnwise,
int32_t rows,
int32_t columns,
int32_t leading_dimension,
MLAS_THREADPOOL* thread_pool
);
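To make the per-block arithmetic concrete, here is a minimal sketch of asymmetric 4-bit quantization for a single block. It is an illustration under stated assumptions, not the code added by this commit: all `Sketch*` names are hypothetical, the value range is widened to include zero, and two quantized values are packed per byte with the first element in the low nibble. It requires C++17 for `std::clamp`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical types and helpers for illustration only.
struct SketchBlockParams {
    float scale;
    std::uint8_t zero_point;
};

// Derive scale and zero point from the block's value range.
// Assumption: the range is widened to contain 0 so that 0.0f maps to an exact code.
inline SketchBlockParams
SketchComputeBlockParams(const float* block, std::size_t n)
{
    float lo = 0.0f;
    float hi = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        lo = std::min(lo, block[i]);
        hi = std::max(hi, block[i]);
    }
    const float scale = (hi - lo) / 15.0f;  // 4 bits give 16 levels: 0..15
    float zp = 0.0f;
    if (scale != 0.0f) {
        zp = std::clamp(std::round(-lo / scale), 0.0f, 15.0f);
    }
    return {scale, static_cast<std::uint8_t>(zp)};
}

// Quantize one block and pack two 4-bit codes per byte.
// Assumption: even-indexed elements go to the low nibble, odd-indexed to the high nibble.
inline std::vector<std::uint8_t>
SketchQuantizeBlock(const float* block, std::size_t n, SketchBlockParams& params)
{
    assert(block != nullptr || n == 0);
    params = SketchComputeBlockParams(block, n);
    std::vector<std::uint8_t> packed((n + 1) / 2, 0);
    for (std::size_t i = 0; i < n; ++i) {
        float q = static_cast<float>(params.zero_point);
        if (params.scale != 0.0f) {
            q = std::clamp(std::round(block[i] / params.scale) + q, 0.0f, 15.0f);
        }
        const std::uint8_t code = static_cast<std::uint8_t>(q);
        packed[i / 2] |= (i % 2 == 0) ? code : static_cast<std::uint8_t>(code << 4);
    }
    return packed;
}
```

A full blockwise quantizer would apply this to every block of the matrix and write the scale and zero point into the corresponding entry of the parameter matrices; how those entries are laid out and packed is specific to the library and is not shown here.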


/**
* @brief Blockwise 4-bit dequantization. The quantized elements and quantization
* parameters (scales, zero points) come from separate matrices packed
* in column major layout. Output is a floating point matrix in column
* major layout for faster access during subsequent matrix multiplication.
*
* @tparam ElementT type of the dequantized matrix element, usually floating point
* @param dst points to dequantized matrix shape [rows, columns] column major
* @param src points to quantized matrix, column major
* @param scales points to quantization scales, column major
* @param zero_points points to quantization zero points, column major
* @param block_size size of the quantization block, elements from the same block share the same scale and zero point
* @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row
* @param rows number of rows in the matrix
* @param columns number of columns in the matrix
* @param thread_pool thread pool used to parallelize the dequantization
*/
template <typename ElementT>
void
MlasDequantizeBlockwise(
ElementT* dst,
const uint8_t* src,
const ElementT* scales,
const uint8_t* zero_points,
int32_t block_size,
bool columnwise,
int32_t rows,
int32_t columns,
MLAS_THREADPOOL* thread_pool
);
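The inverse mapping for one block can be sketched as below. This is a hedged illustration rather than the MLAS code: `SketchDequantizeBlock` is a hypothetical name, and the sketch assumes two 4-bit codes per byte with the even-indexed element in the low nibble, and the usual affine formula value = scale * (code - zero_point).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of MLAS): expands n packed 4-bit codes that
// share a single scale and zero point back into floating point values.
inline std::vector<float>
SketchDequantizeBlock(
    const std::uint8_t* packed,
    std::size_t n,
    float scale,
    std::uint8_t zero_point)
{
    assert(packed != nullptr || n == 0);
    std::vector<float> out(n);
    for (std::size_t i = 0; i < n; ++i) {
        const std::uint8_t byte = packed[i / 2];
        // Assumption: low nibble holds the even-indexed element, high nibble the odd one.
        const int code = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
        out[i] = scale * static_cast<float>(code - static_cast<int>(zero_point));
    }
    return out;
}
```

Because every element in a block shares the same scale and zero point, the only per-element work is the nibble extraction and one multiply, which is why keeping the parameters in separate, contiguous matrices can be convenient for a later matrix multiplication.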

17 changes: 17 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -1069,6 +1069,23 @@ MlasTrySimpleParallel(
const std::function<void(std::ptrdiff_t tid)>& Work
);


/**
* @brief Distribute many iterations of work over a thread pool if supported.
* This function is intended for small workloads in non-performance-critical situations.
*
* @param ThreadPool [IN] Optional thread pool. Ignored when using OpenMP
* @param Iterations [IN] Total number of iterations
* @param Work [IN] Logic for computing a range of iterations [begin, end)
*/
void
MlasTryBatchParallel(
MLAS_THREADPOOL * ThreadPool,
const std::ptrdiff_t Iterations,
const std::function<void(std::ptrdiff_t tid)>& Work
);
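The calling convention can be illustrated with a small stand-in. This sketch is not the MLAS implementation and makes two assumptions: that `Work` is invoked once per index in [0, Iterations), as the `tid` parameter in the signature suggests (the comment's mention of a range is not modeled here), and that running everything on the calling thread is an acceptable fallback when no thread pool is available. `SketchTryBatchParallel` and `SketchSquareAll` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in (not MLAS code): runs work(tid) for every tid in
// [0, iterations) sequentially on the calling thread.
inline void
SketchTryBatchParallel(
    std::ptrdiff_t iterations,
    const std::function<void(std::ptrdiff_t tid)>& work)
{
    assert(iterations >= 0);
    for (std::ptrdiff_t tid = 0; tid < iterations; ++tid) {
        work(tid);
    }
}

// Example caller: each iteration writes only its own output slot, so the same
// callback would need no synchronization if iterations ran on different threads.
inline std::vector<int>
SketchSquareAll(const std::vector<int>& in)
{
    std::vector<int> out(in.size());
    SketchTryBatchParallel(
        static_cast<std::ptrdiff_t>(in.size()),
        [&](std::ptrdiff_t tid) {
            const std::size_t i = static_cast<std::size_t>(tid);
            out[i] = in[i] * in[i];
        });
    return out;
}
```

Writing callbacks so that each index touches disjoint data keeps them correct regardless of how a real thread pool chooses to schedule the iterations.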


inline
ptrdiff_t
MlasGetMaximumThreadCount(