Skip to content

Commit

Permalink
Merge branch 'main' into Cjian/build.py-cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jchen351 committed Jan 18, 2024
2 parents fafe56d + 63dd605 commit 1bc451a
Show file tree
Hide file tree
Showing 41 changed files with 217 additions and 95 deletions.
1 change: 0 additions & 1 deletion build_arm64x.bat
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

setlocal
set PATH=C:\Program Files\Git\usr\bin;%PATH%
set LINK_REPRO_NAME=/mylink.rsp

rem Requires a Python install to be available in your PATH
python "%~dp0\tools\ci_build\build.py" --arm64 --buildasx --build_dir "%~dp0\build\arm64-x" %*
Expand Down
5 changes: 5 additions & 0 deletions cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ if (onnxruntime_DISABLE_RTTI)
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/GR->" "$<$<COMPILE_LANGUAGE:CXX>:/we4541>")
else()
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:-fno-rtti>")
if (onnxruntime_USE_WEBNN)
# Avoid unboundTypeError for WebNN EP since unbound type names are illegal with RTTI disabled
# in Embind API, relevant issue: https://github.com/emscripten-core/emscripten/issues/7001
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:-DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0>")
endif()
endif()
else()
#MSVC RTTI flag /GR is not added to CMAKE_CXX_FLAGS by default. But, anyway VC++2019 treats "/GR" default on.
Expand Down
5 changes: 4 additions & 1 deletion cmake/onnxruntime_webassembly.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,10 @@ else()
endif()

if (onnxruntime_USE_WEBNN)
set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT")
set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT")
if (onnxruntime_DISABLE_RTTI)
set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -fno-rtti -DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0")
endif()
endif()

# Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions.
Expand Down
86 changes: 67 additions & 19 deletions onnxruntime/python/tools/quantization/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# license information.
# --------------------------------------------------------------------------
import abc
import copy
import itertools
import os
import uuid
Expand All @@ -21,6 +22,48 @@
from .quant_utils import apply_plot, load_model_with_shape_infer, smooth_distribution


def rel_entr(pk: np.ndarray, qk: np.ndarray) -> np.ndarray:
    """
    Elementwise relative entropy, a numpy-only port of `scipy.special.rel_entr`.
    See https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.rel_entr.html#scipy.special.rel_entr.

    For every pair of elements (p, q) the result is:

    * ``p * log(p / q)`` when ``p > 0`` and ``q > 0``,
    * ``0`` when ``p == 0`` and ``q >= 0``,
    * ``inf`` in every other case.

    Args:
        pk: array of values for the first distribution.
        qk: array of values for the second distribution, broadcastable to ``pk.shape``.

    Returns:
        An array with the shape and dtype of ``pk``.
    """
    positive = (pk > 0) & (qk > 0)
    zero = (pk == 0) & (qk >= 0)

    # The raw formula is evaluated everywhere and only kept where both inputs are
    # positive, so the log(0) / division-by-zero warnings raised for the other
    # entries are irrelevant and silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        values = pk * np.log(pk / qk)

    res = np.empty(pk.shape, dtype=pk.dtype)
    # The cases are mutually exclusive: the zero case must not be overwritten
    # by the `inf` fallback.
    res[...] = np.where(positive, values, np.where(zero, 0.0, np.inf))
    return res


def entropy(
    pk: np.ndarray,
    qk: np.ndarray,
    base: Optional[float] = None,
    axis: int = 0,
) -> np.ndarray:
    """
    Simplified version of entropy: computes the relative entropy
    ``sum(pk * log(pk / qk))`` along `axis`.
    Source: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html.
    This avoids taking a dependency on scipy just for this function.

    Args:
        pk: first distribution; converted to float32 and normalized to sum to 1 along `axis`.
        qk: second distribution (required); converted to float32, broadcast against `pk`
            and normalized to sum to 1 along `axis`.
        base: logarithmic base of the result; `None` keeps the natural logarithm.
        axis: axis along which the sum is taken.

    Returns:
        The relative entropy as float32.

    Raises:
        AssertionError: if `base` is not positive or `qk` is None.
    """
    # f-string so that the offending value shows up in the error message.
    assert base is None or base > 0, f"base={base} must be a positive number or `None`."
    assert qk is not None, "qk is None"

    pk = np.asarray(pk).astype(np.float32)
    pk = 1.0 * pk / np.sum(pk, axis=axis, keepdims=True)

    qk = np.asarray(qk).astype(np.float32)
    pk, qk = np.broadcast_arrays(pk, qk)
    qk = 1.0 * qk / np.sum(qk, axis=axis, keepdims=True)
    vec = rel_entr(pk, qk)

    s = np.sum(vec, axis=axis)
    if base is not None:
        # Convert from nats to the requested logarithmic base.
        s /= np.log(base)
    return s.astype(pk.dtype)


class TensorData:
_allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"])
_floats = frozenset(["avg", "std", "lowest", "highest", "hist_edges"])
Expand Down Expand Up @@ -708,8 +751,8 @@ def collect_absolute_value(self, name_to_arr):
min_value = np.min(data_arr_np)
max_value = np.max(data_arr_np)
else:
min_value = 0
max_value = 0
min_value = np.array(0, dtype=data_arr_np.dtype)
max_value = np.array(0, dtype=data_arr_np.dtype)

data_arr_np = np.absolute(data_arr_np) # only consider absolute value

Expand All @@ -725,6 +768,8 @@ def collect_absolute_value(self, name_to_arr):
old_histogram = self.histogram_dict[tensor]
old_min = old_histogram[2]
old_max = old_histogram[3]
assert hasattr(old_min, "dtype"), f"old_min should be a numpy array but is {type(old_min)}"
assert hasattr(old_max, "dtype"), f"old_min should be a numpy array but is {type(old_max)}"
old_hist = old_histogram[0]
old_hist_edges = old_histogram[1]
temp_amax = np.max(data_arr_np)
Expand Down Expand Up @@ -757,7 +802,7 @@ def collect_value(self, name_to_arr):
min_value = np.array(0, dtype=data_arr.dtype)
max_value = np.array(0, dtype=data_arr.dtype)

threshold = max(abs(min_value), abs(max_value))
threshold = np.array(max(abs(min_value), abs(max_value)), dtype=data_arr.dtype)

if tensor in self.histogram_dict:
old_histogram = self.histogram_dict[tensor]
Expand Down Expand Up @@ -809,7 +854,7 @@ def merge_histogram(self, old_histogram, data_arr, new_min, new_max, new_thresho
def compute_collection_result(self):
if not self.histogram_dict or len(self.histogram_dict) == 0:
raise ValueError("Histogram has not been collected. Please run collect() first.")
print(f"Finding optimal threshold for each tensor using {self.method} algorithm ...")
print(f"Finding optimal threshold for each tensor using {self.method!r} algorithm ...")

if self.method == "entropy":
return self.compute_entropy()
Expand Down Expand Up @@ -938,7 +983,14 @@ def compute_distribution(self):
assert avg_coef.dtype != np.float64
assert std_coef.dtype != np.float64
assert hist_edges.dtype != np.float64
thresholds_dict[tensor] = TensorData(avg=avg_coef, std=std_coef, hist=hist, hist_edges=hist_edges)
thresholds_dict[tensor] = TensorData(
avg=avg_coef,
std=std_coef,
hist=hist,
hist_edges=hist_edges,
lowest=hist_edges.min(),
highest=hist_edges.max(),
)

# Plot histogram for debug only
if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"):
Expand All @@ -952,18 +1004,15 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
`q` is a truncated version of the original distribution.
Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
"""
import copy

from scipy.stats import entropy

hist = histogram[0]
hist_edges = histogram[1]
num_bins = hist.size
zero_bin_index = num_bins // 2
num_half_quantized_bin = num_quantized_bins // 2

dtype = histogram[1].dtype
kl_divergence = np.zeros(zero_bin_index - num_half_quantized_bin + 1)
thresholds = [(0, 0) for i in range(kl_divergence.size)]
thresholds = [(np.array(0, dtype=dtype), np.array(0, dtype=dtype)) for i in range(kl_divergence.size)]

# <------------ num bins ---------------->
# <--- quantized bins ---->
Expand All @@ -983,10 +1032,7 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
start_index = zero_bin_index - i
end_index = zero_bin_index + i + 1 if (zero_bin_index + i + 1) <= num_bins else num_bins

thresholds[i - num_half_quantized_bin] = (
float(hist_edges[start_index]),
float(hist_edges[end_index]),
)
thresholds[i - num_half_quantized_bin] = (hist_edges[start_index], hist_edges[end_index])

sliced_distribution = copy.deepcopy(hist[start_index:end_index])

Expand Down Expand Up @@ -1020,15 +1066,15 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):

norm = sum(nonzeros[start:end])
if norm != 0:
q[start:end] = float(quantized_bins[index]) / float(norm)
q[start:end] = quantized_bins[index] / norm

p = smooth_distribution(p)
q = smooth_distribution(q)

if isinstance(q, np.ndarray):
kl_divergence[i - num_half_quantized_bin] = entropy(p, q)
if p is None or q is None:
div = np.array(np.inf, dtype=dtype)
else:
kl_divergence[i - num_half_quantized_bin] = float("inf")
div = np.array(entropy(p, q), dtype=dtype)
kl_divergence[i - num_half_quantized_bin] = div

min_kl_divergence_idx = np.argmin(kl_divergence)
optimal_threshold = thresholds[min_kl_divergence_idx]
Expand All @@ -1038,6 +1084,8 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
optimal_threshold = (min_value, optimal_threshold[1])
if optimal_threshold[1] > max_value:
optimal_threshold = (optimal_threshold[0], max_value)
assert hasattr(optimal_threshold[0], "dtype")
assert hasattr(optimal_threshold[1], "dtype")
return optimal_threshold


Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/python/tools/quantization/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def smooth_distribution(p, eps=0.0001):

if not n_nonzeros:
# raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
return -1
return None
eps1 = eps * float(n_zeros) / float(n_nonzeros)
assert eps1 < 1.0, "n_zeros=%d, n_nonzeros=%d, eps1=%f" % (
n_zeros,
Expand Down
21 changes: 16 additions & 5 deletions onnxruntime/python/tools/transformers/large_model_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,24 +224,35 @@ def fetch_onnx_inputs_outputs_name(
if not num_of_past_key:
num_of_past_key = model.config.num_hidden_layers

onnx_inp_names = ("input_ids", "attention_mask")
# filter out constant inputs
onnx_inp_names = tuple(
[torch_input_names[i] for i in range(len(torch_input_names)) if isinstance(onnx_inputs[i], torch.Tensor)]
)
assert (
"input_ids" in onnx_inp_names and "attention_mask" in onnx_inp_names
), "input_ids and attention_mask must be existed in inputs"
onnx_out_names = ("logits",)
onnx_dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
"attention_mask": {0: "batch_size", 1: "seq_len"},
}
# add dynamic dimensions for the unknown inputs
for idx, name in enumerate(onnx_inp_names):
if name not in onnx_dynamic_axes:
unknown_dims = {i: f"{idx}__unknown_dims__{i}" for i in range(onnx_inputs[idx].dim())}
onnx_dynamic_axes[name] = unknown_dims
if input_with_past:
for i in range(num_of_past_key):
onnx_inp_names += (f"present_key.{i}",)
onnx_inp_names += (f"present_values.{i}",)
onnx_inp_names += (f"past_key_values.{i}.key",)
onnx_inp_names += (f"past_key_values.{i}.value",)

onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis
onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis

if with_past or input_with_past:
for i in range(num_of_past_key):
onnx_out_names += (f"past_key.{i}",)
onnx_out_names += (f"past_values.{i}",)
onnx_out_names += (f"present.{i}.key",)
onnx_out_names += (f"present.{i}.value",)
onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis
onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis

Expand Down
12 changes: 11 additions & 1 deletion onnxruntime/test/providers/qnn/simple_op_htp_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,8 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) {
// Check the Onnx skeleton file is generated
EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
// Check the Qnn context cache binary file is generated
EXPECT_TRUE(std::filesystem::exists("qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"));
std::string qnn_ctx_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin";
EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin));

// 2nd run loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
Expand All @@ -837,6 +838,10 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) {
QDQTolerance(),
logging::Severity::kERROR,
context_binary_file);

// Clean up
ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
ASSERT_EQ(std::remove(qnn_ctx_bin.c_str()), 0);
}

// Run QDQ model on HTP 2 times
Expand Down Expand Up @@ -898,6 +903,9 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) {
ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast<int>(qnn_ctx_model_data.size())));
// Verify the return status with code INVALID_GRAPH
ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH);

// Clean up
ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
}

// Run QDQ model on HTP with 2 inputs
Expand Down Expand Up @@ -955,6 +963,8 @@ TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) {
QDQTolerance(),
logging::Severity::kERROR,
context_binary_file);
// Clean up
ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
}

TEST_F(QnnHTPBackendTests, QuantAccuracyTest) {
Expand Down
Loading

0 comments on commit 1bc451a

Please sign in to comment.