Skip to content

Commit

Permalink
fix test failures in training pipeline
Browse files (browse the repository at this point in the history)
  • Loading branch information
yufenglee committed Oct 9, 2023
1 parent 64f5aaf commit 1a8f99e
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 9 deletions.
1 change: 0 additions & 1 deletion onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from onnxruntime.capi._pybind_state import get_device # noqa: F401
from onnxruntime.capi._pybind_state import get_version_string # noqa: F401
from onnxruntime.capi._pybind_state import has_collective_ops # noqa: F401
from onnxruntime.capi._pybind_state import quantize_matmul_4bits # noqa: F401
from onnxruntime.capi._pybind_state import set_default_logger_severity # noqa: F401
from onnxruntime.capi._pybind_state import set_default_logger_verbosity # noqa: F401
from onnxruntime.capi._pybind_state import set_seed # noqa: F401
Expand Down
16 changes: 10 additions & 6 deletions onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <vector>

Check warning on line 8 in onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h#L8

Found C++ system header after other header. Should be: dequantize_blockwise.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h:8:  Found C++ system header after other header. Should be: dequantize_blockwise.h, c system, c++ system, other.  [build/include_order] [4]

#include "core/common/safeint.h"
#include "core/framework/float16.h"
#include "core/platform/threadpool.h"
#include <iostream>

Check warning on line 13 in onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h#L13

Found C++ system header after other header. Should be: dequantize_blockwise.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h:13:  Found C++ system header after other header. Should be: dequantize_blockwise.h, c system, c++ system, other.  [build/include_order] [4]
Expand Down Expand Up @@ -45,10 +46,11 @@ void QuantizeBlockwise(
int32_t k_block_idx = static_cast<int32_t>(block_idx % block_per_K);
int32_t k = k_block_idx * block_size;
BlockwiseQuantBlock<T, block_size, bits>* blob_ptr = dst_blob + block_idx;
size_t offset = SafeInt<size_t>(k) * N + n;
if (nullptr != zero_points_tmp_ptr) {
blob_ptr->quant(src + k * N + n, scale[block_idx], zero_points_tmp_ptr[block_idx], k, K, N);
blob_ptr->quant(src + offset, scale[block_idx], zero_points_tmp_ptr[block_idx], k, K, N);
} else {
blob_ptr->quant(src + k * N + n, scale[block_idx], k, K, N);
blob_ptr->quant(src + offset, scale[block_idx], k, K, N);
}
},
0);
Expand Down Expand Up @@ -119,17 +121,19 @@ void DequantizeBlockwise(
int32_t k_block_idx = static_cast<int32_t>(task_idx % block_per_K);
int32_t k = k_block_idx * block_size;
const BlockwiseQuantBlock<T, block_size, bits>* blob_ptr = src_blob + task_idx;
size_t offset = SafeInt<size_t>(n) * K + k;
if (nullptr != zero_points) {
// if bits >= 4
if constexpr (bits > 4) { // zero point is stored with a byte
blob_ptr->dequant(dst + n * K + k, scale[task_idx], zero_points[task_idx], k, K);
blob_ptr->dequant(dst + offset, scale[task_idx], zero_points[task_idx], k, K);
} else { // zero points is stored with 4bits
uint8_t zp = zero_points[task_idx / 2];
zp = (task_idx & 1) ? (zp >> 4) : (zp & 0xf);
blob_ptr->dequant(dst + n * K + k, scale[task_idx], zp, k, K);
blob_ptr->dequant(dst + offset, scale[task_idx], zp, k, K);
}
} else {
blob_ptr->dequant(dst + n * K + k, scale[task_idx], k, K);
}
else {

Check warning on line 135 in onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h#L135

An else should appear on the same line as the preceding } [whitespace/newline] [4]
Raw output
onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h:135:  An else should appear on the same line as the preceding }  [whitespace/newline] [4]

Check warning on line 135 in onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h#L135

If an else has a brace on one side, it should have it on both [readability/braces] [5]
Raw output
onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h:135:  If an else has a brace on one side, it should have it on both  [readability/braces] [5]
blob_ptr->dequant(dst + offset, scale[task_idx], k, K);
}
},
0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import onnx
from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto

import onnxruntime as ort
from onnxruntime.capi._pybind_state import quantize_matmul_4bits

from .onnx_model import ONNXModel
from .quant_utils import attribute_to_kwarg
Expand Down Expand Up @@ -62,7 +62,7 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray:
packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype)
zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
ort.quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric)
quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric)

return (packed, scales, zero_point)

Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/test/python/quantization/test_op_matmul_4bits.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import tempfile
import unittest
from importlib.util import find_spec
from pathlib import Path
from typing import Dict, Tuple, Union

Expand Down Expand Up @@ -136,6 +137,9 @@ def quant_test(
else:
raise exception

@unittest.skipIf(
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
)
def test_quantize_matmul_int4_symmetric(self):
np.random.seed(13)

Expand All @@ -144,6 +148,9 @@ def test_quantize_matmul_int4_symmetric(self):
data_reader = self.input_feeds(1, {"input": [100, 52]})
self.quant_test(model_fp32_path, data_reader, 32, True)

@unittest.skipIf(
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
)
def test_quantize_matmul_int4_offsets(self):
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
self.construct_model_matmul(model_fp32_path, symmetric=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# --------------------------------------------------------------------------

import unittest
from importlib.util import find_spec

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -96,6 +97,9 @@ def quantize_blockwise_4bits_target(matrix_float: npt.ArrayLike, block_size: int


class TestQuantizeBlockwise4Bits(unittest.TestCase):
@unittest.skipIf(
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
)
def test_quantize_blockwise_4bits(self):
for rows, cols in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]:
for block_size in [16, 32, 64, 128]:
Expand Down

0 comments on commit 1a8f99e

Please sign in to comment.