Skip to content

Commit

Permalink
fix some stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
edgchen1 committed Nov 4, 2023
1 parent 448c4e5 commit 90f2ab5
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
if (MlasIsSQNBitGemmAvailable(nbits_, block_size_)) {
// number of bytes or elements between adjacent matrices
size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes;
- MlasBlockwiseQuantizedBufferSizes(nbits_, block_size_, /* columnwise */ true,
+ MlasBlockwiseQuantizedBufferSizes(static_cast<int>(nbits_), static_cast<int>(block_size_), /* columnwise */ true,
static_cast<int>(K), static_cast<int>(N),
b_data_matrix_stride_in_bytes, b_scale_matrix_stride,
&b_zero_point_matrix_stride_in_bytes);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/q4_dq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ MlasBlockwiseQuantizedBufferSizes(
{

Check warning on line 805 in onnxruntime/core/mlas/lib/q4_dq.cpp

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/mlas/lib/q4_dq.cpp#L805

{ should almost always be at the end of the previous line [whitespace/braces] [4]
Raw output
onnxruntime/core/mlas/lib/q4_dq.cpp:805:  { should almost always be at the end of the previous line  [whitespace/braces] [4]
q_data_size_in_bytes = q_scale_num_elements = 0;
if (q_zero_point_size_in_bytes) {
- q_zero_point_size_in_bytes = 0;
+ *q_zero_point_size_in_bytes = 0;
}

if (qbits == 4) {
Expand Down
14 changes: 9 additions & 5 deletions onnxruntime/test/mlas/unittest/test_blockq4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,17 @@ class MlasBlockwiseQdqTest : public MlasTestBase {

int meta_rows;
int meta_cols;
- MlasBlockwiseQuantMetaShape<float>(block_size, columnwise, rows, columns, meta_rows, meta_cols);
+ MlasBlockwiseQuantMetaShape<float, 4>(block_size, columnwise, rows, columns, meta_rows, meta_cols);

int q_rows;
int q_cols;
- MlasBlockwiseQuantizedShape<float>(block_size, columnwise, rows, columns, q_rows, q_cols);
+ MlasBlockwiseQuantizedShape<float, 4>(block_size, columnwise, rows, columns, q_rows, q_cols);

- uint8_t* elements = InputElements.GetBuffer(q_rows * q_cols, true);
+ size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes;
+ MlasBlockwiseQuantizedBufferSizes(4, block_size, columnwise, rows, columns,
+ q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes);
+
+ uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true);

int v = 7;
for (int c = 0; c < columns; c++) {
Expand All @@ -70,8 +74,8 @@ class MlasBlockwiseQdqTest : public MlasTestBase {
}
}

- float* scales = InputScales.GetBuffer(meta_rows * meta_cols);
- uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true);
+ float* scales = InputScales.GetBuffer(q_scale_size);
+ uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(q_zp_size_in_bytes, true);
if (zp) {
for (int c = 0; c < meta_cols; c++) {
for (int r = 0; r < meta_rows; r += 2) {
Expand Down

0 comments on commit 90f2ab5

Please sign in to comment.