diff --git a/include/onnxruntime/core/framework/int4.h b/include/onnxruntime/core/framework/int4.h
index 228c1e4e872de..aff365dc9738f 100644
--- a/include/onnxruntime/core/framework/int4.h
+++ b/include/onnxruntime/core/framework/int4.h
@@ -84,11 +84,21 @@ struct Int4x2Base {
return (num_int4_elems + 1) / 2;
}
+ /// <summary>
+ /// Copy a source buffer of 4-bit elements (packed) into a destination buffer of 8-bit elements (unpacked).
+ /// </summary>
+ /// <param name="dst">Destination buffer to store unpacked 8-bit elements</param>
+ /// <param name="src">Source buffer with 4-bit elements</param>
+ /// <returns>True on success</returns>
static bool Unpack(gsl::span<UnpackedType> dst, gsl::span<const Int4x2Base<Signed>> src) {
if (CalcNumInt4Pairs(dst.size()) != src.size()) {
return false;
}
+ if (src.empty()) {
+ return true;
+ }
+
for (size_t i = 0; i < dst.size(); i++) {
size_t r = i >> 1; // i / 2;
size_t c = i & 0x1; // i % 2;
@@ -98,11 +108,21 @@ struct Int4x2Base {
return true;
}
+ /// <summary>
+ /// Copy a source buffer of 8-bit elements (unpacked) into a destination buffer of 4-bit elements (packed).
+ /// </summary>
+ /// <param name="dst">Destination buffer to store packed 4-bit elements</param>
+ /// <param name="src">Source buffer with 8-bit elements</param>
+ /// <returns>True on success</returns>
static bool Pack(gsl::span<Int4x2Base<Signed>> dst, gsl::span<const UnpackedType> src) {
- if (src.empty() || (CalcNumInt4Pairs(src.size()) != dst.size())) {
+ if (CalcNumInt4Pairs(src.size()) != dst.size()) {
return false;
}
+ if (src.empty()) {
+ return true;
+ }
+
size_t src_i = 0;
size_t dst_i = 0;
@@ -116,6 +136,20 @@ struct Int4x2Base {
return true;
}
+
+ /// <summary>
+ /// Returns hierarchical indices for a packed int4 element from the given element index.
+ ///
+ /// Usage:
+ /// Int4x2* data = ...;
+ /// auto indices = GetTensorElemIndices(3); // 4th int4 element
+ /// int8_t elem = data[indices.first].GetElem(indices.second);
+ /// </summary>
+ /// <param name="index">Index of 4-bit element</param>
+ /// <returns>Pair of indices: the packed Int4x2 pair that stores the element and the element's position (0 or 1) within that pair</returns>
+ static inline std::pair<size_t, size_t> GetTensorElemIndices(size_t index) {
+ return {index >> 1, index & 0x1};
+ }
};
using Int4x2 = Int4x2Base<true>;
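
For context, and not part of the patch itself: a minimal usage sketch of how the helpers added above compose. The include paths, the onnxruntime namespace placement, and the availability of gsl::make_span are assumptions; the logic relies only on the members visible in the hunks above.

// Illustrative sketch only (assumed include paths and namespace onnxruntime).
#include <array>
#include <cassert>
#include <gsl/gsl>
#include "core/framework/int4.h"

void Int4RoundTripSketch() {
  using onnxruntime::Int4x2;

  // 5 logical int4 elements occupy CalcNumInt4Pairs(5) == 3 packed pairs.
  const std::array<int8_t, 5> unpacked_src = {-8, -1, 0, 3, 7};
  std::array<Int4x2, 3> packed{};

  // Pack 8-bit values into 4-bit pairs; fails only on a size mismatch.
  bool ok = Int4x2::Pack(gsl::make_span(packed), gsl::make_span(unpacked_src));
  assert(ok);

  // Random access to the 4th logical element without unpacking the whole buffer.
  auto indices = Int4x2::GetTensorElemIndices(3);
  assert(packed[indices.first].GetElem(indices.second) == 3);

  // Unpack back into 8-bit storage.
  std::array<int8_t, 5> unpacked_dst{};
  ok = Int4x2::Unpack(gsl::make_span(unpacked_dst), gsl::make_span(packed));
  assert(ok && unpacked_dst[4] == 7);
}
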
diff --git a/onnxruntime/core/framework/print_tensor_statistics_utils.h b/onnxruntime/core/framework/print_tensor_statistics_utils.h
index fd036114f3e76..65360674e88d0 100644
--- a/onnxruntime/core/framework/print_tensor_statistics_utils.h
+++ b/onnxruntime/core/framework/print_tensor_statistics_utils.h
@@ -79,6 +79,33 @@ void PrintCommonStats(const T* data, size_t count) {
PrintValue(max);
}
+#define DEF_PRINT_COMMON_STATS_INT4(INT4_TYPE) \
+ template <> \
+ inline void PrintCommonStats(const INT4_TYPE* data, size_t count) { \
+ using UnpackedType = typename INT4_TYPE::UnpackedType; \
+ UnpackedType min = data[0].GetElem(0); \
+ UnpackedType max = min; \
+ for (size_t i = 1; i < count; i++) { \
+ auto indices = INT4_TYPE::GetTensorElemIndices(i); \
+ auto value = data[indices.first].GetElem(indices.second); \
+ if (value > max) { \
+ max = value; \
+ } \
+ if (value < min) { \
+ min = value; \
+ } \
+ } \
+ \
+ std::cout << "Min="; \
+ PrintValue(min); \
+ \
+ std::cout << ",Max="; \
+ PrintValue(max); \
+ }
+
+DEF_PRINT_COMMON_STATS_INT4(Int4x2)
+DEF_PRINT_COMMON_STATS_INT4(UInt4x2)
+
template <typename T>
void PrintHalfStats(const T* data, size_t count) {
float min = data[0].ToFloat();
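
For context, and not part of the patch: a hedged sketch of calling the macro-generated PrintCommonStats specialization directly. The include path, the onnxruntime::utils namespace placement, and the exact text emitted by PrintValue are assumptions; the Min/Max scan follows the macro body above.

// Illustrative sketch only: `count` is the number of logical 4-bit elements,
// so the unused trailing nibble of the last pair is excluded from the scan.
#include <vector>
#include "core/framework/print_tensor_statistics_utils.h"

void PrintInt4StatsSketch() {
  using onnxruntime::Int4x2;

  // 7 logical values (1, -2, 3, -4, 5, -6, 7) plus one unused nibble.
  const std::vector<Int4x2> packed = {Int4x2(1, -2), Int4x2(3, -4), Int4x2(5, -6), Int4x2(7, 0)};

  // Expected to print something like "Min=-6,Max=7".
  onnxruntime::utils::PrintCommonStats<Int4x2>(packed.data(), 7);
}
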
diff --git a/onnxruntime/core/framework/print_tensor_utils.h b/onnxruntime/core/framework/print_tensor_utils.h
index 6bd4e2d3af3fd..b8c50a266b655 100644
--- a/onnxruntime/core/framework/print_tensor_utils.h
+++ b/onnxruntime/core/framework/print_tensor_utils.h
@@ -75,6 +75,29 @@ void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t
std::cout << std::endl;
}
+// INT4 - Print snippet of 2D tensor with shape (dim0, dim1)
+#define DEF_PRINT_CPU_TENSOR_SNIPPET_2D_INT4(INT4_TYPE) \
+ template <> \
+ inline void PrintCpuTensorSnippet(const INT4_TYPE* tensor, int64_t dim0, int64_t dim1, \
+ int64_t edge_items) { \
+ for (int64_t i = 0; i < dim0; i++) { \
+ SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \
+ auto indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ for (int64_t j = 1; j < dim1; j++) { \
+ SKIP_NON_EDGE_ITEMS_LAST_DIM(dim1, j, edge_items); \
+ std::cout << ", "; \
+ indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 + j)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ } \
+ std::cout << std::endl; \
+ } \
+ std::cout << std::endl; \
+ }
+
+DEF_PRINT_CPU_TENSOR_SNIPPET_2D_INT4(Int4x2)
+DEF_PRINT_CPU_TENSOR_SNIPPET_2D_INT4(UInt4x2)
+
// Print snippet of 3D tensor with shape (dim0, dim1, dim2)
template <typename T>
void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t dim2, int64_t edge_items) {
@@ -95,6 +118,33 @@ void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t
std::cout << std::endl;
}
+// INT4 - Print snippet of 3D tensor with shape (dim0, dim1, dim2)
+#define DEF_PRINT_CPU_TENSOR_SNIPPET_3D_INT4(INT4_TYPE) \
+ template <> \
+ inline void PrintCpuTensorSnippet(const INT4_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2, \
+ int64_t edge_items) { \
+ for (int64_t i = 0; i < dim0; i++) { \
+ SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \
+ for (int64_t j = 0; j < dim1; j++) { \
+ SKIP_NON_EDGE_ITEMS(dim1, j, edge_items); \
+ auto indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ for (int64_t k = 1; k < dim2; k++) { \
+ SKIP_NON_EDGE_ITEMS_LAST_DIM(dim2, k, edge_items); \
+ std::cout << ", "; \
+ indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2 + k)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ } \
+ std::cout << std::endl; \
+ } \
+ std::cout << std::endl; \
+ } \
+ std::cout << std::endl; \
+ }
+
+DEF_PRINT_CPU_TENSOR_SNIPPET_3D_INT4(Int4x2)
+DEF_PRINT_CPU_TENSOR_SNIPPET_3D_INT4(UInt4x2)
+
// Print 2D tensor
template <typename T>
void PrintCpuTensorFull(const T* tensor, int64_t dim0, int64_t dim1) {
@@ -109,6 +159,26 @@ void PrintCpuTensorFull(const T* tensor, int64_t dim0, int64_t dim1) {
std::cout << std::endl;
}
+// INT4 - Print 2D tensor
+#define DEF_PRINT_CPU_TENSOR_FULL_2D_INT4(INT4_TYPE) \
+ template <> \
+ inline void PrintCpuTensorFull(const INT4_TYPE* tensor, int64_t dim0, int64_t dim1) { \
+ for (int64_t i = 0; i < dim0; i++) { \
+ auto indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ for (int64_t j = 1; j < dim1; j++) { \
+ std::cout << ", "; \
+ indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 + j)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ } \
+ std::cout << std::endl; \
+ } \
+ std::cout << std::endl; \
+ }
+
+DEF_PRINT_CPU_TENSOR_FULL_2D_INT4(Int4x2)
+DEF_PRINT_CPU_TENSOR_FULL_2D_INT4(UInt4x2)
+
// Print 3D tensor
template <typename T>
void PrintCpuTensorFull(const T* tensor, int64_t dim0, int64_t dim1, int64_t dim2) {
@@ -126,6 +196,29 @@ void PrintCpuTensorFull(const T* tensor, int64_t dim0, int64_t dim1, int64_t dim
std::cout << std::endl;
}
+// INT4 - Print 3D tensor
+#define DEF_PRINT_CPU_TENSOR_FULL_3D_INT4(INT4_TYPE) \
+ template <> \
+ inline void PrintCpuTensorFull(const INT4_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2) { \
+ for (int64_t i = 0; i < dim0; i++) { \
+ for (int64_t j = 0; j < dim1; j++) { \
+ auto indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ for (int64_t k = 1; k < dim2; k++) { \
+ std::cout << ", "; \
+ indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2 + k)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ } \
+ std::cout << std::endl; \
+ } \
+ std::cout << std::endl; \
+ } \
+ std::cout << std::endl; \
+ }
+
+DEF_PRINT_CPU_TENSOR_FULL_3D_INT4(Int4x2)
+DEF_PRINT_CPU_TENSOR_FULL_3D_INT4(UInt4x2)
+
template <typename T>
void PrintCpuTensor(const Tensor& tensor, int threshold = kDefaultSnippetThreshold, int edge_items = kDefaultSnippetEdgeItems) {
const auto& shape = tensor.Shape();
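
Also for context, not part of the patch: an equally hedged sketch of the 2D snippet specialization. dim0/dim1 describe the logical (unpacked) shape; the buffer stays packed and indexing goes through GetTensorElemIndices inside the macro above.

// Illustrative sketch only (assumed include path and namespace onnxruntime::utils).
#include <vector>
#include "core/framework/print_tensor_utils.h"

void PrintInt4SnippetSketch() {
  using onnxruntime::Int4x2;

  // A 2x3 int4 tensor (6 logical elements) stored as 3 packed pairs.
  const std::vector<Int4x2> packed = {Int4x2(1, 2), Int4x2(3, 4), Int4x2(5, 6)};

  // With edge_items >= the dim sizes nothing is elided; expected output:
  //   1, 2, 3
  //   4, 5, 6
  onnxruntime::utils::PrintCpuTensorSnippet<Int4x2>(packed.data(), /*dim0*/ 2, /*dim1*/ 3,
                                                    /*edge_items*/ 3);
}
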
diff --git a/onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc b/onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc
index 88bb3d70db312..17e26a57f5f3e 100644
--- a/onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc
+++ b/onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc
@@ -31,6 +31,34 @@ void VerifyTensorProtoFileData(const PathString& tensor_proto_path, gsl::span<const T> expected_data) {
EXPECT_EQ(gsl::span<const T>(actual_data), expected_data);
}
+
+template <bool Signed>
+void VerifyTensorProtoFileDataInt4(const PathString& tensor_proto_path,
+ gsl::span<const Int4x2Base<Signed>> expected_data,
+ gsl::span<const int64_t> shape) {
+ size_t num_elems = 1;
+ for (auto dim_val : shape) {
+ num_elems *= static_cast<size_t>(dim_val);
+ }
+
+ std::ifstream tensor_proto_stream{tensor_proto_path};
+
+ ONNX_NAMESPACE::TensorProto tensor_proto{};
+ ASSERT_TRUE(tensor_proto.ParseFromIstream(&tensor_proto_stream));
+
+ std::vector<Int4x2Base<Signed>> actual_data{};
+ actual_data.resize(expected_data.size());
+ ASSERT_STATUS_OK(utils::UnpackTensor(tensor_proto, Path{}, actual_data.data(), num_elems));
+
+ ASSERT_EQ(actual_data.size(), expected_data.size());
+
+ for (size_t i = 0; i < num_elems; i++) {
+ auto indices = Int4x2Base<Signed>::GetTensorElemIndices(i);
+ auto actual_val = actual_data[indices.first].GetElem(indices.second);
+ auto expected_val = expected_data[indices.first].GetElem(indices.second);
+ ASSERT_EQ(actual_val, expected_val);
+ }
+}
} // namespace
namespace env_vars = utils::debug_node_inputs_outputs_env_vars;
@@ -72,5 +100,53 @@ TEST(DebugNodeInputsOutputs, BasicFileOutput) {
tester.Run();
}
+// Test dumping input and output INT4 tensors to file.
+TEST(DebugNodeInputsOutputs, FileOutput_Int4) {
+ TemporaryDirectory temp_dir{ORT_TSTR("debug_node_inputs_outputs_utils_test")};
+ ScopedEnvironmentVariables scoped_env_vars{
+ EnvVarMap{
+ {env_vars::kDumpInputData, "1"},
+ {env_vars::kDumpOutputData, "1"},
+ {env_vars::kNameFilter, nullopt},
+ {env_vars::kOpTypeFilter, nullopt},
+ {env_vars::kDumpDataDestination, "files"},
+ {env_vars::kAppendRankToFileName, nullopt},
+ {env_vars::kOutputDir, ToUTF8String(temp_dir.Path())},
+ {env_vars::kDumpingDataToFilesForAllNodesIsOk, "1"},
+ }};
+
+ constexpr int8_t unused_val = 0;
+ const std::vector<int64_t> input_shape({5, 3});
+ const std::vector<Int4x2> input_vals = {Int4x2(1, 2), Int4x2(3, 4), Int4x2(5, 6), Int4x2(7, 8),
+ Int4x2(9, 10), Int4x2(11, 12), Int4x2(13, 14), Int4x2(15, unused_val)};
+
+ const std::vector<int64_t> perm = {1, 0};
+ const std::vector<int64_t> expected_shape({3, 5});
+ const std::vector<Int4x2> expected_vals = {Int4x2(1, 4), Int4x2(7, 10), Int4x2(13, 2), Int4x2(5, 8),
+ Int4x2(11, 14), Int4x2(3, 6), Int4x2(9, 12), Int4x2(15, unused_val)};
+
+ OpTester tester{"Transpose", 21, kOnnxDomain};
+ tester.AddAttribute("perm", perm);
+ tester.AddInput<Int4x2>("x", input_shape, input_vals);
+ tester.AddOutput<Int4x2>("y", expected_shape, expected_vals);
+
+ auto verify_file_data =
+ [&temp_dir, &input_vals, &expected_vals, &input_shape, &expected_shape](
+ const std::vector<OrtValue>& fetches,
+ const std::string& /*provider_type*/) {
+ ASSERT_EQ(fetches.size(), 1u);
+ // check it contains a tensor
+ fetches[0].Get<Tensor>();
+ VerifyTensorProtoFileDataInt4(temp_dir.Path() + ORT_TSTR("/x.tensorproto"), gsl::make_span(input_vals),
+ gsl::make_span(input_shape));
+ VerifyTensorProtoFileDataInt4(temp_dir.Path() + ORT_TSTR("/y.tensorproto"),
+ gsl::make_span(expected_vals), gsl::make_span(expected_shape));
+ };
+
+ tester.SetCustomOutputVerifier(verify_file_data);
+
+ tester.Run();
+}
+
} // namespace test
} // namespace onnxruntime