Skip to content

Commit

Permalink
Quantization tool: support float 8 with MatMul, support float 16 weig…
Browse files Browse the repository at this point in the history
…hts (#18043)

### Description

Whenever a QuantizeLinear or DequantizeLinear node is added, the type of
the weights before being quantized must be known to create the scale with
the expected type. Another option would be to add many CastLike operators,
but that would push the burden to the onnxruntime optimizer.

The PR tries to avoid changing the signature. To do so, it modifies the
scale computation to use a numpy array to store the result rather than a
python float. The numpy array must be of the same type as the weights
to quantize.

The PR adds many `assert` statements to check that the type of the scale
is not a python type or a float64. They were added to make sure all the
code follows the same logic. These lines were kept for the first review.

DequantizeLinear and QuantizeLinear cannot be tested with onnx==1.15: PR
onnx/onnx#5709, which fixes shape inference, is missing.
PR onnx/onnx#5473, which is needed to support QLinearMatMul with
float 16, is missing as well. That explains why some tests are
disabled with float 16.

### Motivation and Context

The current quantization tool assumes every weight is float 32. For
large models such as LLAMA, the weights are usually float 16. The
quantization tool needs to be able to quantize such weights.
  • Loading branch information
xadupre authored and mszhanyi committed Jan 15, 2024
1 parent d2f35a0 commit ffe4df6
Show file tree
Hide file tree
Showing 18 changed files with 1,107 additions and 237 deletions.
95 changes: 65 additions & 30 deletions onnxruntime/python/tools/quantization/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,17 @@

class TensorData:
_allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"])
_floats = frozenset(["avg", "std", "lowest", "highest", "hist_edges"])

def __init__(self, **kwargs):
    """Validate every keyword argument and store it as an attribute.

    Each key must belong to ``TensorData._allowed``. Keys that also belong
    to ``TensorData._floats`` must carry a ``dtype`` attribute equal to
    ``np.float16`` or ``np.float32``.

    Raises:
        ValueError: if a key is not allowed, if a float-typed value has no
            ``dtype`` attribute, or if its dtype is neither float16 nor
            float32.
    """
    for k, v in kwargs.items():
        if k not in TensorData._allowed:
            raise ValueError(f"Unexpected value {k!r} not in {TensorData._allowed}.")

        # Only the keys listed in _floats are subject to the dtype checks.
        must_be_float = k in TensorData._floats
        if must_be_float and not hasattr(v, "dtype"):
            raise ValueError(f"Unexpected type {type(v)} for k={k!r}")
        if must_be_float and v.dtype not in (np.float16, np.float32):
            raise ValueError(f"Unexpected dtype {v.dtype} for k={k!r}")

        setattr(self, k, v)

@property
Expand Down Expand Up @@ -171,7 +177,7 @@ def select_tensors_to_calibrate(self, model: ModelProto):
initializer = {init.name for init in model.graph.initializer}

tensors_to_calibrate = set()
tensor_type_to_calibrate = {TensorProto.FLOAT}
tensor_type_to_calibrate = {TensorProto.FLOAT, TensorProto.FLOAT16}

for node in model.graph.node:
if not self.op_types_to_calibrate or node.op_type in self.op_types_to_calibrate:
Expand Down Expand Up @@ -284,7 +290,17 @@ def add_reduce_min_max(tensor_name, reduce_op_name):
)

self.model.graph.node.extend([reduce_node, reshape_node])
self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, TensorProto.FLOAT, [1]))
value_infos = {vi.name: vi for vi in self.model.graph.value_info}
value_infos.update({o.name: o for o in self.model.graph.output})
value_infos.update({i.name: i for i in self.model.graph.input})
if tensor_name in value_infos:
onnx_type = value_infos[tensor_name].type.tensor_type.elem_type
else:
raise ValueError(
f"Unable to guess tensor type for tensor {tensor_name!r}, "
f"running shape inference before quantization may resolve this issue."
)
self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, onnx_type, [1]))

for tensor in tensors:
add_reduce_min_max(tensor, "ReduceMin")
Expand Down Expand Up @@ -364,24 +380,18 @@ def compute_data(self) -> TensorsData:

pairs = []
for i in range(0, len(added_output_names), 2):
min_value = 0
max_value = 0
if self.moving_average:
min_value_array = np.mean(merged_added_output_dict[added_output_names[i]], axis=0)
max_value_array = np.mean(merged_added_output_dict[added_output_names[i + 1]], axis=0)
else:
min_value_array = min(merged_added_output_dict[added_output_names[i]])
max_value_array = max(merged_added_output_dict[added_output_names[i + 1]])
if isinstance(min_value_array, int) or min_value_array.size > 0:
min_value = float(min_value_array)
if isinstance(max_value_array, int) or max_value_array.size > 0:
max_value = float(max_value_array)
min_value_array = np.min(merged_added_output_dict[added_output_names[i]], axis=0)
max_value_array = np.max(merged_added_output_dict[added_output_names[i + 1]], axis=0)

if self.symmetric:
max_absolute_value = max(abs(min_value), abs(max_value))
max_absolute_value = max(np.abs(min_value_array), np.abs(max_value_array))
pairs.append(tuple([-max_absolute_value, max_absolute_value]))
else:
pairs.append(tuple([min_value, max_value]))
pairs.append(tuple([min_value_array, max_value_array]))

new_calibrate_tensors_range = TensorsData(CalibrationMethod.MinMax, dict(zip(calibrate_tensor_names, pairs)))
if self.calibrate_tensors_range:
Expand Down Expand Up @@ -679,36 +689,57 @@ def collect_absolute_value(self, name_to_arr):
Collect histogram on absolute value
"""
for tensor, data_arr in name_to_arr.items():
data_arr = np.asarray(data_arr) # noqa: PLW2901
data_arr = data_arr.flatten() # noqa: PLW2901
if data_arr.size > 0:
min_value = np.min(data_arr)
max_value = np.max(data_arr)
if isinstance(data_arr, list):
for arr in data_arr:
if not isinstance(arr, np.ndarray):
raise ValueError(f"Unexpected type {type(arr)} for tensor={tensor!r}")
dtypes = set(a.dtype for a in arr)
if len(dtypes) != 1:
raise ValueError(
f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
)
data_arr_np = np.asarray(data_arr)
elif not isinstance(data_arr, np.ndarray):
raise ValueError(f"Unexpected type {type(data_arr)} for tensor={tensor!r}")
else:
data_arr_np = data_arr
data_arr_np = data_arr_np.flatten()
if data_arr_np.size > 0:
min_value = np.min(data_arr_np)
max_value = np.max(data_arr_np)
else:
min_value = 0
max_value = 0

data_arr = np.absolute(data_arr) # only consider absolute value # noqa: PLW2901
data_arr_np = np.absolute(data_arr_np) # only consider absolute value

if tensor not in self.histogram_dict:
# first time it uses num_bins to compute histogram.
hist, hist_edges = np.histogram(data_arr, bins=self.num_bins)
hist, hist_edges = np.histogram(data_arr_np, bins=self.num_bins)
hist_edges = hist_edges.astype(data_arr_np.dtype)
assert (
data_arr_np.dtype != np.float64
), "only float32 or float16 is supported, every constant must be explicetly typed"
self.histogram_dict[tensor] = (hist, hist_edges, min_value, max_value)
else:
old_histogram = self.histogram_dict[tensor]
old_min = old_histogram[2]
old_max = old_histogram[3]
old_hist = old_histogram[0]
old_hist_edges = old_histogram[1]
temp_amax = np.max(data_arr)
temp_amax = np.max(data_arr_np)
if temp_amax > old_hist_edges[-1]:
# increase the number of bins
width = old_hist_edges[1] - old_hist_edges[0]
# NOTE: np.arange may create an extra bin after the one containing temp_amax
new_bin_edges = np.arange(old_hist_edges[-1] + width, temp_amax + width, width)
old_hist_edges = np.hstack((old_hist_edges, new_bin_edges))
hist, hist_edges = np.histogram(data_arr, bins=old_hist_edges)
hist, hist_edges = np.histogram(data_arr_np, bins=old_hist_edges)
hist_edges = hist_edges.astype(data_arr_np.dtype)
hist[: len(old_hist)] += old_hist
assert (
data_arr_np.dtype != np.float64
), "only float32 or float16 is supported, every constant must be explicetly typed"
self.histogram_dict[tensor] = (hist, hist_edges, min(old_min, min_value), max(old_max, max_value))

def collect_value(self, name_to_arr):
Expand All @@ -723,8 +754,8 @@ def collect_value(self, name_to_arr):
min_value = np.min(data_arr)
max_value = np.max(data_arr)
else:
min_value = 0
max_value = 0
min_value = np.array(0, dtype=data_arr.dtype)
max_value = np.array(0, dtype=data_arr.dtype)

threshold = max(abs(min_value), abs(max_value))

Expand Down Expand Up @@ -811,16 +842,16 @@ def compute_percentile(self):
idx_right = np.searchsorted(cdf, percentile / 100.0)

thresholds_dict[tensor] = (
-float(hist_edges[idx_right]),
float(hist_edges[idx_right]),
-np.array(hist_edges[idx_right], dtype=hist_edges.dtype),
np.array(hist_edges[idx_right], dtype=hist_edges.dtype),
)
else:
percent_to_cut_one_side = (100.0 - percentile) / 200.0
idx_right = np.searchsorted(cdf, 1.0 - percent_to_cut_one_side)
idx_left = np.searchsorted(cdf, percent_to_cut_one_side)
thresholds_dict[tensor] = (
float(hist_edges[idx_left]),
float(hist_edges[idx_right]),
np.array(hist_edges[idx_left], dtype=hist_edges.dtype),
np.array(hist_edges[idx_right], dtype=hist_edges.dtype),
)
min_value = histogram[2]
max_value = histogram[3]
Expand Down Expand Up @@ -868,19 +899,19 @@ def _avg_std(hist, hist_edges, power=1):
if power == 1:
avg = (hist * values).sum() / hist.sum()
std = ((hist * values**2).sum() / hist.sum() - avg**2) ** 0.5
return avg, std
return np.array(avg, dtype=hist_edges.dtype), np.array(std, dtype=hist_edges.dtype)
if int(power) == power and int(power) % 2 == 1:
avg = (hist * values**power).sum() / hist.sum()
std = ((hist * (values**power - avg) ** 2).sum() / hist.sum()) ** 0.5
return avg, std
return np.array(avg, dtype=hist_edges.dtype), np.array(std, dtype=hist_edges.dtype)

fact = np.abs(values) / values
fact[np.isnan(fact)] = 1
fact[np.isinf(fact)] = 1
values = np.abs(values) ** power * fact
avg = (hist * values).sum() / hist.sum()
std = ((hist * values**2).sum() / hist.sum() - avg**2) ** 0.5
return avg, std
return np.array(avg, dtype=hist_edges.dtype), np.array(std, dtype=hist_edges.dtype)

def compute_distribution(self):
if self.num_bins < 512:
Expand All @@ -897,12 +928,16 @@ def compute_distribution(self):
hist = histogram[0]
hist_edges = histogram[1]

assert hist_edges.dtype != np.float64
if self.scenario == "same":
avg_coef, std_coef = self._avg_std(hist, hist_edges, power=1)
elif self.scenario == "p3":
avg_coef, std_coef = self._avg_std(hist, hist_edges, power=1.0 / 3.0)
else:
raise ValueError("Invalid scenario. Must be in {'same', 'p3'}.")
assert avg_coef.dtype != np.float64
assert std_coef.dtype != np.float64
assert hist_edges.dtype != np.float64
thresholds_dict[tensor] = TensorData(avg=avg_coef, std=std_coef, hist=hist, hist_edges=hist_edges)

# Plot histogram for debug only
Expand Down
Loading

0 comments on commit ffe4df6

Please sign in to comment.