Skip to content

Commit

Permalink
Augment blockwise quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
chenfucn committed Oct 25, 2023
1 parent 35ecce4 commit b5d03d4
Show file tree
Hide file tree
Showing 6 changed files with 1,163 additions and 13 deletions.
257 changes: 244 additions & 13 deletions onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,221 @@
namespace onnxruntime {
namespace contrib {

/**
 * @brief Compile-time description of a 2-D shape (rows x columns).
 *
 * @tparam Row_    rows of a matrix
 * @tparam Column_ columns of a matrix
 */
template <
    int Row_,    ///< rows of a matrix
    int Column_  ///< columns of a matrix
    >
struct Shape2D {
  // constexpr static data members are implicitly inline since C++17, so they
  // can be ODR-used (e.g. bound to a const reference) without needing a
  // separate out-of-class definition.
  static constexpr int kRow = Row_;              ///< rows of a matrix
  static constexpr int kColumn = Column_;        ///< columns of a matrix
  static constexpr int kCount = Row_ * Column_;  ///< total number of elements in a matrix
};

/**
 * @brief Rectify min/max from a set of weights, and convert to scale and zero point
 *        for quantization
 *
 * The range is first widened so that it always contains 0. The scale is the
 * width of that range divided by the largest quantized value (2^qbits - 1),
 * and the zero point is -min/scale, clamped to [0, 2^qbits - 1] and rounded.
 * When the widened range is empty (scale of 0) the zero point ends up as 0.
 *
 * @tparam ScaleT type of scale, usually floating point of various bits
 * @tparam qbits  number of int bits used for zero point value
 * @param[in]  min   smallest value found in the block
 * @param[in]  max   largest value found in the block
 * @param[out] scale receives the quantization step
 * @param[out] zp    receives the zero point
 */
template <typename ScaleT, int qbits>
FORCEINLINE
void range2scalezp(float min, float max, ScaleT& scale, uint8_t& zp) {
  constexpr int kQuantMax = (1 << qbits) - 1;
  constexpr float kQuantMaxFp = static_cast<float>(kQuantMax);

  // The representable interval must include zero.
  const float lo = std::min(min, 0.0f);
  const float hi = std::max(max, 0.0f);

  const float step = (hi - lo) / kQuantMax;

  // With a zero step both bounds are 0, so falling back to `lo` yields 0.
  const float zp_real = (step != 0.0f) ? (0.f - lo / step) : lo;

  if (zp_real < 0.0f) {
    zp = 0;
  } else if (zp_real > kQuantMaxFp) {
    zp = kQuantMax;
  } else {
    zp = static_cast<uint8_t>(roundf(zp_real));
  }
  scale = static_cast<ScaleT>(step);
}

/**
 * @brief Convert the min/max of a set of weights to a scale for quantization
 *        without a zero point.
 *
 * The bound with the larger magnitude (min on a tie) is divided by
 * -2^(qbits-1), so that bound divided by the resulting scale equals
 * -2^(qbits-1). Note that the scale is therefore negative whenever the
 * selected bound is positive.
 *
 * @tparam ScaleT type of scale, usually floating point of various bits
 * @tparam qbits  number of bits of a quantized element
 * @param[in]  min   smallest value found in the block
 * @param[in]  max   largest value found in the block
 * @param[out] scale receives the computed scale
 */
template <typename ScaleT, int qbits>
FORCEINLINE void range2scale(float min, float max, ScaleT& scale) {
  constexpr int mid_v = 1 << (qbits - 1);
  constexpr float mid_fp = static_cast<float>(-mid_v);

  // Keep the signed bound that has the larger absolute value.
  max = fabsf(max) > fabsf(min) ? max : min;

  scale = static_cast<ScaleT>(max / mid_fp);
}

Check warning on line 74 in onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 You don't need a ; after a } [readability/braces] [4] Raw Output: onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h:74: You don't need a ; after a } [readability/braces] [4]


/**
* @brief Blockwise quantization methods
* @tparam ElementT source data type, e.g. fp16
* @tparam block_size number of elemenets quantized together
* @tparam qbits number of bits in each quantized element
* @tparam Columnwise true: elements in a block come from one single column
* false: elements in a block come from one single row
*/
template<
typename ElementT,
int32_t block_size,
int32_t qbits,
bool Columnwise>
struct BlockwiseQuantizer {
using QuantBlk = std::conditional_t<Columnwise, Shape2D<block_size, 1>, Shape2D<1, block_size>>;
using ThreadBlk = std::conditional_t<Columnwise, Shape2D<block_size, 8>, Shape2D<8, block_size>>;

static constexpr int zp_max = (1 << qbits) - 1;
static constexpr float zp_max_fp = static_cast<float>(zp_max);

/**
* @brief Quantized a Matrix shape [rows, columns], resulting quantized
* and packed data are stored in column major (transposed)
* @param[out] dst pointer to the quantized weights, column major: [columns, rows]
* @param[out] scale pointer to the scales, column major: [columns/QuantBlk::kColumn, rows/QuantBlk::kRow]
* @param[out] zero_points pointer to the zero points, same shape as scale
* @param[in] src pointer to the source matrix, row major: [rows, columns]
* @param rows
* @param columns
* @param leadingDimension stride of the source matrix, i.e. distance from one row to the next
*/
void quantizeAndTranspose(
uint8_t* dst,
ElementT* scales,
uint8_t* zero_points,
const ElementT* src,
int32_t rows,
int32_t columns,
int32_t leadingDimension,
onnxruntime::concurrency::ThreadPool* thread_pool) {

Check warning on line 117 in onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Redundant blank line at the start of a code block should be deleted. [whitespace/blank_line] [2] Raw Output: onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h:117: Redundant blank line at the start of a code block should be deleted. [whitespace/blank_line] [2]
// To support other qbits, need to add bit packing code for
// storing to dst and zero points
static_assert(qbits == 4, "Only 4b block quantization is supported!");

// Thread partitioning
const auto thrd_row_blks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow;
const auto thrd_col_blks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn;
const auto total_thrd_blks = thrd_row_blks * thrd_col_blks;

const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow;
const auto col_blks = (columns + QuantBlk::kColumn - 1) / QuantBlk::kColumn;
const auto total_quant_blks = row_blks * col_blks;

std::vector<uint8_t> zero_points_tmp; // to avoid race condition
(void)zero_points_tmp;
uint8_t* zero_points_tmp_ptr = zero_points;
if (zero_points != nullptr) {
zero_points_tmp.resize(total_quant_blks, 0);
zero_points_tmp_ptr = zero_points_tmp.data();
}

concurrency::ThreadPool::TryBatchParallelFor(
thread_pool,
total_thrd_blks,
[&](ptrdiff_t block_idx) {
int32_t r_blk_idx = static_cast<int32_t>(block_idx / thrd_col_blks);
int32_t c_blk_idx = static_cast<int32_t>(block_idx % thrd_col_blks);

int32_t r = r_blk_idx * ThreadBlk::kRow;
int32_t c = c_blk_idx * ThreadBlk::kColumn;

int32_t r_end = std::min(r + ThreadBlk::kRow, rows);
int32_t c_end = std::min(c + ThreadBlk::kColumn, columns);

float min = std::numeric_limits<float>::max();
float max = -min;
if constexpr (Columnwise) {
for (int32_t j = c; j < c_end; ++j) {
for (int32_t i = r; i < r_end; ++i) {
const float v = static_cast<float>(src[i * leadingDimension + j]);
if (v < min) min = v;
if (v > max) max = v;
}
// store scale and zero point at matrix position (i/QuantBlk::kRow, j/QuantBlk::kColumn)
// it's packed as column major, so the index is (column * stride + row)
// where stride = row_blks
// over all it's (j/QuantBlk::kColumn) * (rows/QuantBlk::kRow) + (i/QuantBlk::kRow)
const int32_t meta_idx = (j / QuantBlk::kColumn) * row_blks + (r / QuantBlk::kRow);
if (zero_points == nullptr) {
range2scale<ElementT, qbits>(min, max, scales[meta_idx]);
}
else {

Check warning on line 169 in onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 An else should appear on the same line as the preceding } [whitespace/newline] [4] Raw Output: onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h:169: An else should appear on the same line as the preceding } [whitespace/newline] [4]

Check warning on line 169 in onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 If an else has a brace on one side, it should have it on both [readability/braces] [5] Raw Output: onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h:169: If an else has a brace on one side, it should have it on both [readability/braces] [5]
range2scalezp<ElementT, qbits>(min, max, scales[meta_idx], zero_points_tmp_ptr[meta_idx]);
}
}
} else { // Row-wise
for (int32_t i = r; i < r_end; ++i) {
for (int32_t j = c; j < c_end; ++j) {
const float v = static_cast<float>(src[i * leadingDimension + j]);
if (v < min) min = v;
if (v > max) max = v;
}
// store scale and zero point at matrix position (i/QuantBlk::kRow, j/QuantBlk::kColumn)
// it's packed as column major, so the index is (column * row_blks + row)
// over all it's (j/QuantBlk::kColumn) * row_blks + (i/QuantBlk::kRow)
const int32_t meta_idx = (c / QuantBlk::kColumn) * row_blks + (i / QuantBlk::kRow);
if (zero_points == nullptr) {
range2scale<ElementT, qbits>(min, max, scales[meta_idx]);
} else {
range2scalezp<ElementT, qbits>(min, max, scales[meta_idx], zero_points_tmp_ptr[meta_idx]);
}
}
}

for (int32_t j = c; j < c_end; ++j) {
const int32_t meta_col = j / QuantBlk::kColumn;
for (int32_t i = r; i < r_end; i += 2) {
const int32_t meta_row = i / QuantBlk::kRow;
const int32_t meta_idx = meta_col * row_blks + meta_row;
const float scale = static_cast<float>(scales[meta_idx]);
const float reciprocal_scale = scale ? 1.0f / scale : 0.0f;
const int zp = zero_points == nullptr ? 8 : zero_points_tmp_ptr[meta_idx];

const float v0 = static_cast<float>(src[i * leadingDimension + j]);
const float v1 = (i + 1) < r_end ? static_cast<float>(src[(i + 1) * leadingDimension + j]) : 0.0f;

const uint8_t vi0 = (uint8_t)std::min(zp_max_fp, std::max(0.0f, roundf(v0 * reciprocal_scale + zp)));
const uint8_t vi1 = (uint8_t)std::min(zp_max_fp, std::max(0.0f, roundf(v1 * reciprocal_scale + zp)));

// !! 4b specific code
dst[i / 2] = (vi0 & 0xf) | (vi1 << 4);
}
}
},
0);

if (zero_points != nullptr) { // compact zero points
for (int col_idx = 0; col_idx < col_blks; col_idx++) {
auto* col_ptr = zero_points + col_idx * (row_blks / 2);
auto* src_col_ptr = zero_points_tmp_ptr + col_idx * row_blks;
for (int row_idx = 0; row_idx < row_blks; row_idx += 2) {
const auto zp0 = src_col_ptr[row_idx];
const auto zp1 = (row_idx + 1 < row_blks) ? src_col_ptr[row_idx + 1] : 0;
// !! 4b specific code
col_ptr[row_idx / 2] = ((zp0 & 0xf) | (zp1 << 4));
}
}
}

Check warning on line 226 in onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Redundant blank line at the end of a code block should be deleted. [whitespace/blank_line] [3] Raw Output: onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h:226: Redundant blank line at the end of a code block should be deleted. [whitespace/blank_line] [3]
}

Check warning on line 228 in onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Redundant blank line at the end of a code block should be deleted. [whitespace/blank_line] [3] Raw Output: onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h:228: Redundant blank line at the end of a code block should be deleted. [whitespace/blank_line] [3]
};



template <typename T, int32_t block_size, int32_t bits>
void QuantizeBlockwise(
uint8_t* dst, // shape: [ N, block_per_K, block_blob_size ]
Expand Down Expand Up @@ -80,19 +295,35 @@ void QuantizeBlockwise(
int32_t K,
onnxruntime::concurrency::ThreadPool* thread_pool) {
ORT_ENFORCE(bits == 4, "only 4 bits is supported now");

if (16 == block_size) {
QuantizeBlockwise4Bits(16);
} else if (32 == block_size) {
QuantizeBlockwise4Bits(32);
} else if (64 == block_size) {
QuantizeBlockwise4Bits(64);
} else if (128 == block_size) {
QuantizeBlockwise4Bits(128);
} else if (256 == block_size) {
QuantizeBlockwise4Bits(256);
} else {
ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported.");
switch (block_size) {
case 16: {
BlockwiseQuantizer<T, 16, 4, true> quantizer;
quantizer.quantizeAndTranspose(dst, scale, zero_points, src, K, N, N, thread_pool);
break;
}
case 32: {
BlockwiseQuantizer<T, 32, 4, true> quantizer;

Check warning on line 305 in onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Tab found; better to use spaces [whitespace/tab] [1] Raw Output: onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h:305: Tab found; better to use spaces [whitespace/tab] [1]
quantizer.quantizeAndTranspose(dst, scale, zero_points, src, K, N, N, thread_pool);

Check warning on line 306 in onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Tab found; better to use spaces [whitespace/tab] [1] Raw Output: onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h:306: Tab found; better to use spaces [whitespace/tab] [1]
break;
}
case 64: {
BlockwiseQuantizer<T, 64, 4, true> quantizer;
quantizer.quantizeAndTranspose(dst, scale, zero_points, src, K, N, N, thread_pool);
break;
}
case 128: {
BlockwiseQuantizer<T, 128, 4, true> quantizer;
quantizer.quantizeAndTranspose(dst, scale, zero_points, src, K, N, N, thread_pool);
break;
}
case 256: {
BlockwiseQuantizer<T, 256, 4, true> quantizer;
quantizer.quantizeAndTranspose(dst, scale, zero_points, src, K, N, N, thread_pool);
break;
}
default:
ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported.");
break;
}
}

Expand Down
89 changes: 89 additions & 0 deletions onnxruntime/core/mlas/inc/mlas_q4.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,92 @@ MlasQ8Q4GemmBatch(
const MLAS_Q8Q4_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool
);


////////////////////////////////////////////////////////////
// Blockwise quantization and dequantization where quantization
// parameters are packed into separate buffers.

Check warning on line 236 in onnxruntime/core/mlas/inc/mlas_q4.h

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "seperate" is a misspelling of "separate" Raw Output: ./onnxruntime/core/mlas/inc/mlas_q4.h:236:30: "seperate" is a misspelling of "separate"
//

/**
* @brief For quantization type <T, block_size, columnwise>, and
* matrix shape [rows, columns], compute the shape of the
* quantization parameter matrix [meta_rows, meta_cols]
*
* @tparam T part of the quantization type; NOTE(review): presumably the element type of the matrix - confirm
* @param block_size part of the quantization type; NOTE(review): presumably the number of elements per quantization block - confirm
* @param columnwise part of the quantization type; NOTE(review): presumably selects whether a block runs along a column or a row - confirm
* @param rows number of rows of the matrix
* @param columns number of columns of the matrix
* @param[out] meta_rows receives the number of rows of the quantization parameter matrix
* @param[out] meta_cols receives the number of columns of the quantization parameter matrix
*/
template <typename T>
void
MlasBlockwiseQuantMetaShape(
int block_size,
bool columnwise,
int rows,
int columns,
int& meta_rows,
int& meta_cols
);


/**
* @brief Blockwise 4 bits quantization, resulting elements and quantization
* parameters (scales, zero points) are packed into separate matrices

Check warning on line 258 in onnxruntime/core/mlas/inc/mlas_q4.h

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "seperate" is a misspelling of "separate" Raw Output: ./onnxruntime/core/mlas/inc/mlas_q4.h:258:59: "seperate" is a misspelling of "separate"
* all in column major layout for faster access during subsequent matrix
* multiplication.
*
* @tparam ElementT type of the input matrix element, usually floating point
*
* @param dst points to the quantized matrix, shape [rows, columns] column major
* @param scales points to the scales matrix, column major
* @param zero_points points to the zero_points matrix, column major
* @param src points to the floating point matrix, to be quantized, row major shape [rows, columns]
* @param block_size size of the block to quantize, elements from the same block share the same scale and zero point
* @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row
* @param rows number of rows of the source matrix
* @param columns number of columns of the source matrix
* @param leading_dimension NOTE(review): presumably the stride of src, i.e. the distance in elements from one row to the next - confirm against the implementation
* @param thread_pool thread pool handed to the implementation; NOTE(review): whether it may be null is not visible here - confirm
*/
template <typename ElementT>
void
MlasQuantizeBlockwise(
uint8_t* dst,
ElementT* scales,
uint8_t* zero_points,
const ElementT* src,
int32_t block_size,
bool columnwise,
int32_t rows,
int32_t columns,
int32_t leading_dimension,
MLAS_THREADPOOL* thread_pool
);


/**
* @brief Blockwise 4 bits dequantization, quantized elements and quantization
* parameters (scales, zero points) are from separate matrices packed

Check warning on line 293 in onnxruntime/core/mlas/inc/mlas_q4.h

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "seperate" is a misspelling of "separate" Raw Output: ./onnxruntime/core/mlas/inc/mlas_q4.h:293:52: "seperate" is a misspelling of "separate"
* in column major layout. Output is a floating point matrix in column
* major layout for faster access during subsequent matrix multiplication.
*
* @tparam ElementT type of the dequantized matrix element, usually floating point
* @param dst points to dequantized matrix shape [rows, columns] column major
* @param src points to quantized matrix, column major
* @param scales points to quantization scales, column major
* @param zero_points points to quantization zero points, column major
* @param block_size size of the block to quantize, elements from the same block share the same scale and zero point
* @param columnwise true when elements in a block are from the same column, false when elements in a block are from the same row
* @param rows number of rows of the dequantized matrix
* @param columns number of columns of the dequantized matrix
* @param thread_pool thread pool handed to the implementation; NOTE(review): whether it may be null is not visible here - confirm
*/
template <typename ElementT>
void
MlasDequantizeBlockwise(
ElementT* dst,
const uint8_t* src,
const ElementT* scales,
const uint8_t* zero_points,
int32_t block_size,
bool columnwise,
int32_t rows,
int32_t columns,
MLAS_THREADPOOL* thread_pool
);
17 changes: 17 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,23 @@ MlasTrySimpleParallel(
const std::function<void(std::ptrdiff_t tid)>& Work
);


/**
* @brief Distribute many iterations of work over a thread pool if supported.
* This function is for small workloads in non-performance critical situation.
*
* @param ThreadPool [IN] Optional thread pool. Ignored when using OpenMP
* @param Iterations [IN] Total number of iterations
* @param Work [IN] Logic for computing a range of iterations [begin, end)
*
* NOTE(review): the description of Work mentions a range [begin, end), but the
* callback type declared below takes a single std::ptrdiff_t argument (named
* tid). Presumably it is invoked once per iteration index - confirm against
* the implementation and adjust the wording above accordingly.
*/
void
MlasTryBatchParallel(
MLAS_THREADPOOL * ThreadPool,
const std::ptrdiff_t Iterations,
const std::function<void(std::ptrdiff_t tid)>& Work
);


inline
ptrdiff_t
MlasGetMaximumThreadCount(
Expand Down
Loading

0 comments on commit b5d03d4

Please sign in to comment.