Augment blockwise quantization #18101

Merged: 9 commits, Oct 30, 2023
114 changes: 114 additions & 0 deletions onnxruntime/core/mlas/inc/mlas_q4.h
@@ -229,3 +229,117 @@
const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool
);


////////////////////////////////////////////////////////////
// Blockwise quantization and dequantization where quantization
// parameters are packed into separate buffers.
//

/**
* @brief For quantization type <T, block_size, columnwise>, and
* matrix shape [rows, columns], compute the shape of the
* quantization parameter matrix [meta_rows, meta_cols]
*/
template <typename T>
void
MlasBlockwiseQuantMetaShape(
int block_size,
bool columnwise,
int rows,
int columns,
int& meta_rows,
int& meta_cols
);
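The header does not show how the parameter shape is derived. A minimal sketch of the likely ceiling-division logic, assuming one scale/zero-point entry per block of `block_size` elements (the function name and layout here are illustrative, not the MLAS implementation):

```cpp
// Hypothetical sketch: columnwise blocks run down a column, so the number of
// parameter rows is ceil(rows / block_size); rowwise blocks run along a row,
// so the number of parameter columns is ceil(columns / block_size).
void BlockwiseQuantMetaShape(int block_size, bool columnwise,
                             int rows, int columns,
                             int& meta_rows, int& meta_cols) {
    if (columnwise) {
        meta_rows = (rows + block_size - 1) / block_size;  // ceil division
        meta_cols = columns;
    } else {
        meta_rows = rows;
        meta_cols = (columns + block_size - 1) / block_size;
    }
}
```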


/**
* @brief For quantization type <T, block_size, columnwise>, and
* matrix shape [rows, columns], compute the shape of the
* quantized matrix [q_rows, q_cols]. The quantized matrix
* is in column major layout, with bits packed on the column.
*
* @tparam T
* @param block_size
* @param columnwise
* @param rows
* @param columns
* @param q_rows
* @param q_cols
*/
template <typename T>
void
MlasBlockwiseQuantizedShape(
int block_size,
bool columnwise,
int rows,
int columns,
int& q_rows,
int& q_cols
);
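For intuition, a sketch of one plausible packed shape for the columnwise case, assuming 4-bit elements packed two per byte along each column (a guess at the layout the comment describes, not the MLAS code):

```cpp
// Hypothetical sketch for the columnwise 4-bit case: the quantized matrix is
// column major with two 4-bit codes per byte along a column, so each column
// occupies ceil(rows / 2) bytes.
void BlockwiseQuantizedShape4Bit(int rows, int columns,
                                 int& q_rows, int& q_cols) {
    q_rows = (rows + 1) / 2;  // bytes per column
    q_cols = columns;
}
```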



/**
 * @brief Blockwise 4-bit quantization. The quantized elements and the
 * quantization parameters (scales, zero points) are packed into
 * separate matrices, all in column major layout for faster access
 * during subsequent matrix multiplication.
*
* @tparam ElementT type of the input matrix element, usually floating point
*
* @param dst points to the quantized matrix, shape [rows, columns] column major
* @param scales points to the scales matrix, column major
* @param zero_points points to the zero_points matrix, column major
* @param src points to the floating point matrix, to be quantized, row major shape [rows, columns]
 * @param block_size number of elements in a quantization block; elements in the same block share one scale and zero point
* @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row
* @param rows
* @param columns
* @param leading_dimension
* @param thread_pool
*/
template <typename ElementT>
void
MlasQuantizeBlockwise(
uint8_t* dst,
ElementT* scales,
uint8_t* zero_points,
const ElementT* src,
int block_size,
bool columnwise,
int rows,
int columns,
int leading_dimension,
MLAS_THREADPOOL* thread_pool
);
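The quantization math for a single block can be sketched as follows: derive a scale and zero point from the block's min/max (widened to include zero so 0.0f is exact), then map each float to a 4-bit code in [0, 15]. This is an illustrative sketch under those assumptions, not the MLAS kernel; the two-codes-per-byte packing is omitted for clarity.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical sketch of asymmetric 4-bit quantization for one block of
// `count` elements. Each element becomes round(x / scale + zero_point),
// clamped to the 16 representable levels.
void QuantizeBlock4Bit(const float* src, int count,
                       float& scale, uint8_t& zero_point, uint8_t* dst) {
    float mn = src[0], mx = src[0];
    for (int i = 1; i < count; ++i) {
        mn = std::min(mn, src[i]);
        mx = std::max(mx, src[i]);
    }
    // Widen the range to include zero so that 0.0f quantizes exactly.
    mn = std::min(mn, 0.0f);
    mx = std::max(mx, 0.0f);
    scale = (mx - mn) / 15.0f;  // 4 bits -> 16 levels
    float zp = scale != 0.0f ? -mn / scale : 0.0f;
    zero_point = static_cast<uint8_t>(
        std::lroundf(std::clamp(zp, 0.0f, 15.0f)));
    for (int i = 0; i < count; ++i) {
        float q = scale != 0.0f ? src[i] / scale + zero_point
                                : static_cast<float>(zero_point);
        dst[i] = static_cast<uint8_t>(
            std::lroundf(std::clamp(q, 0.0f, 15.0f)));
    }
}
```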



/**
 * @brief Blockwise 4-bit dequantization. The quantized elements and the
 * quantization parameters (scales, zero points) come from separate
 * matrices packed in column major layout. Output is a floating point
 * matrix in column major layout for faster access during subsequent
 * matrix multiplication.
*
* @tparam ElementT type of the dequantized matrix element, usually floating point
* @param dst points to dequantized matrix shape [rows, columns] column major
* @param src points to quantized matrix, column major
* @param scales points to quantization scales, column major
* @param zero_points points to quantization zero points, column major
 * @param block_size number of elements in a quantization block; elements in the same block share one scale and zero point
* @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row
* @param rows
* @param columns
* @param thread_pool
*/
template <typename ElementT>
void
MlasDequantizeBlockwise(
ElementT* dst,
const uint8_t* src,
const ElementT* scales,
const uint8_t* zero_points,
int block_size,
bool columnwise,
int rows,
int columns,
MLAS_THREADPOOL* thread_pool
);
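Dequantization is the inverse affine mapping, x ≈ scale * (q - zero_point), applied with each block's parameters. A small sketch, assuming low-nibble-first packing of two 4-bit codes per byte (the nibble order here is an assumption, not taken from the MLAS source):

```cpp
#include <cstdint>

// Hypothetical sketch: unpack one byte into two 4-bit codes (low nibble
// first) and apply the inverse affine mapping with the block's scale and
// zero point.
void DequantizePackedByte(uint8_t packed, float scale, uint8_t zero_point,
                          float* out2) {
    out2[0] = scale * ((packed & 0x0F) - static_cast<int>(zero_point));
    out2[1] = scale * ((packed >> 4) - static_cast<int>(zero_point));
}
```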

17 changes: 17 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -1069,6 +1069,23 @@
const std::function<void(std::ptrdiff_t tid)>& Work
);


/**
* @brief Distribute many iterations of work over a thread pool if supported.
 * This function is for small workloads in non-performance-critical situations.
*
* @param ThreadPool [IN] Optional thread pool. Ignored when using OpenMP
 * @param Iterations [IN] Total number of iterations
 * @param Work [IN] Logic for computing a single iteration, given an iteration index in [0, Iterations)
*/
void
MlasTryBatchParallel(
    MLAS_THREADPOOL* ThreadPool,
const std::ptrdiff_t Iterations,
const std::function<void(std::ptrdiff_t tid)>& Work
);
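The serial fallback of this dispatch pattern is straightforward: when no thread pool is supplied (or the workload is trivially small), invoke `Work` once per iteration index in order. A minimal sketch of that fallback path only; the real function forwards to the ONNX Runtime thread pool when one is available:

```cpp
#include <cstddef>
#include <functional>

// Hypothetical sketch of the no-thread-pool fallback: run all iterations
// serially, passing each iteration index to Work.
void TryBatchParallelSerial(std::ptrdiff_t Iterations,
                            const std::function<void(std::ptrdiff_t)>& Work) {
    for (std::ptrdiff_t i = 0; i < Iterations; ++i) {
        Work(i);
    }
}
```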


inline
ptrdiff_t
MlasGetMaximumThreadCount(