Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into r1
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Jul 17, 2024
2 parents 86f297f + 6c7562b commit 4430e9b
Show file tree
Hide file tree
Showing 45 changed files with 447 additions and 87 deletions.
2 changes: 1 addition & 1 deletion cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ function(onnxruntime_set_compile_flags target_name)
# Enable warning
target_compile_options(${target_name} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options -Wall>" "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:-Wall>")
target_compile_options(${target_name} PRIVATE "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:-Wextra>")
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "IBMClang")
#external/protobuf/src/google/protobuf/arena.h:445:18: error: unused parameter 'p'
target_compile_options(${target_name} PRIVATE "-Wno-unused-parameter")
endif()
Expand Down
3 changes: 3 additions & 0 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
set(gtest_disable_pthreads ON)
endif()
if (${CMAKE_SYSTEM_NAME} MATCHES "AIX")
set(gtest_disable_pthreads ON CACHE BOOL "gtest_disable_pthreads" FORCE)
endif()
set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
if (IOS OR ANDROID)
# on mobile platforms the absl flags class dumps the flag names (assumably for binary size), which breaks passing
Expand Down
45 changes: 35 additions & 10 deletions cmake/onnxruntime.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ foreach(f ${ONNXRUNTIME_PROVIDER_NAMES})
list(APPEND SYMBOL_FILES "${ONNXRUNTIME_ROOT}/core/providers/${f}/symbols.txt")
endforeach()

if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX")
add_custom_command(OUTPUT ${SYMBOL_FILE} ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c
COMMAND ${Python_EXECUTABLE} "${REPO_ROOT}/tools/ci_build/gen_def.py"
--version_file "${ONNXRUNTIME_ROOT}/../VERSION_NUMBER" --src_root "${ONNXRUNTIME_ROOT}"
Expand All @@ -66,6 +67,7 @@ add_custom_command(OUTPUT ${SYMBOL_FILE} ${CMAKE_CURRENT_BINARY_DIR}/generated_s
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})

add_custom_target(onnxruntime_generate_def ALL DEPENDS ${SYMBOL_FILE} ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c)
endif()
if(WIN32)
onnxruntime_add_shared_library(onnxruntime
${SYMBOL_FILE}
Expand Down Expand Up @@ -99,13 +101,21 @@ elseif(onnxruntime_BUILD_APPLE_FRAMEWORK)
# Note: The PUBLIC_HEADER and VERSION properties for the 'onnxruntime' target will be set later in this file.
)
else()
onnxruntime_add_shared_library(onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c)
if(${CMAKE_SYSTEM_NAME} MATCHES "AIX")
onnxruntime_add_shared_library(onnxruntime ${ONNXRUNTIME_ROOT}/core/session/onnxruntime_c_api.cc)
else()
onnxruntime_add_shared_library(onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c )
endif()
if (onnxruntime_USE_CUDA)
set_property(TARGET onnxruntime APPEND_STRING PROPERTY LINK_FLAGS " -Xlinker -rpath=\\$ORIGIN")
endif()
endif()

add_dependencies(onnxruntime onnxruntime_generate_def ${onnxruntime_EXTERNAL_DEPENDENCIES})
if(${CMAKE_SYSTEM_NAME} MATCHES "AIX")
add_dependencies(onnxruntime ${onnxruntime_EXTERNAL_DEPENDENCIES})
else()
add_dependencies(onnxruntime onnxruntime_generate_def ${onnxruntime_EXTERNAL_DEPENDENCIES})
endif()
target_include_directories(onnxruntime PRIVATE ${ONNXRUNTIME_ROOT} PUBLIC "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime>")

target_compile_definitions(onnxruntime PRIVATE VER_MAJOR=${VERSION_MAJOR_PART})
Expand All @@ -118,7 +128,7 @@ target_compile_definitions(onnxruntime PRIVATE FILE_NAME=\"onnxruntime.dll\")
if(UNIX)
if (APPLE)
set(ONNXRUNTIME_SO_LINK_FLAG " -Xlinker -dead_strip")
else()
elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX")
set(ONNXRUNTIME_SO_LINK_FLAG " -Xlinker --version-script=${SYMBOL_FILE} -Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack")
endif()
else()
Expand All @@ -138,7 +148,7 @@ if (NOT WIN32)
else()
set_target_properties(onnxruntime PROPERTIES INSTALL_RPATH "@loader_path")
endif()
elseif (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
elseif (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-rpath='$ORIGIN'")
endif()
endif()
Expand Down Expand Up @@ -206,6 +216,10 @@ set(onnxruntime_INTERNAL_LIBRARIES
onnxruntime_flatbuffers
)

if (${CMAKE_SYSTEM_NAME} MATCHES "AIX")
list(APPEND onnxruntime_INTERNAL_LIBRARIES iconv)
endif()

if (onnxruntime_USE_EXTENSIONS)
list(APPEND onnxruntime_INTERNAL_LIBRARIES
onnxruntime_extensions
Expand All @@ -222,12 +236,23 @@ target_link_libraries(onnxruntime PRIVATE
)

set_property(TARGET onnxruntime APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_SO_LINK_FLAG} ${onnxruntime_DELAYLOAD_FLAGS})
set_target_properties(onnxruntime PROPERTIES
PUBLIC_HEADER "${ONNXRUNTIME_PUBLIC_HEADERS}"
LINK_DEPENDS ${SYMBOL_FILE}
VERSION ${ORT_VERSION}
FOLDER "ONNXRuntime"
)



if(${CMAKE_SYSTEM_NAME} MATCHES "AIX")
set_target_properties(onnxruntime PROPERTIES
PUBLIC_HEADER "${ONNXRUNTIME_PUBLIC_HEADERS}"
VERSION ${ORT_VERSION}
SOVERSION 1
FOLDER "ONNXRuntime")
else()
set_target_properties(onnxruntime PROPERTIES
PUBLIC_HEADER "${ONNXRUNTIME_PUBLIC_HEADERS}"
LINK_DEPENDS ${SYMBOL_FILE}
VERSION ${ORT_VERSION}
SOVERSION 1
FOLDER "ONNXRuntime")
endif()

install(TARGETS onnxruntime
EXPORT ${PROJECT_NAME}Targets
Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime_framework.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ add_dependencies(onnxruntime_framework ${onnxruntime_EXTERNAL_DEPENDENCIES})
# For the shared onnxruntime library, this is set in onnxruntime.cmake through CMAKE_SHARED_LINKER_FLAGS
# But our test files don't use the shared library so this must be set for them.
# For Win32 it generates an absolute path for shared providers based on the location of the executable/onnxruntime.dll
if (UNIX AND NOT APPLE AND NOT onnxruntime_MINIMAL_BUILD AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
if (UNIX AND NOT APPLE AND NOT onnxruntime_MINIMAL_BUILD AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-rpath='$ORIGIN'")
endif()

Expand Down
14 changes: 13 additions & 1 deletion cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -427,12 +427,24 @@ else()
)
if(COMPILES_P10)
check_cxx_source_compiles("
#ifdef _AIX
#define POWER_10 0x40000
#define POWER_10_ANDUP (POWER_10)
#include <sys/systemcfg.h>
#define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP)
int main() {
bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31);
return 0;
}
#else
#include <sys/auxv.h>
int main() {
unsigned long hwcap2 = getauxval(AT_HWCAP2);
bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1));
return 0;
}"
}
}
#endif"
HAS_P10_RUNTIME
)
if (HAS_P10_RUNTIME)
Expand Down
4 changes: 3 additions & 1 deletion cmake/onnxruntime_providers_cpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD
if(APPLE)
set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/shared/exported_symbols.lst")
elseif(UNIX)
set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/shared/version_script.lds -Xlinker --gc-sections")
if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX")
set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/shared/version_script.lds -Xlinker --gc-sections")
endif()
elseif(WIN32)
set_property(TARGET onnxruntime_providers_shared APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/shared/symbols.def")
set(ONNXRUNTIME_PROVIDERS_SHARED onnxruntime_providers_shared)
Expand Down
25 changes: 20 additions & 5 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
list(APPEND onnxruntime_perf_test_libs ${android_shared_libs})
endif()
if (${CMAKE_SYSTEM_NAME} MATCHES "AIX")
list(APPEND onnxruntime_perf_test_libs onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 gtest absl_failure_signal_handler absl_examine_stack absl_flags_parse absl_flags_usage absl_flags_usage_internal)
endif()
target_link_libraries(onnxruntime_perf_test PRIVATE ${onnxruntime_perf_test_libs} Threads::Threads)
if(WIN32)
target_link_libraries(onnxruntime_perf_test PRIVATE debug dbghelp advapi32)
Expand Down Expand Up @@ -1275,6 +1278,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
list(APPEND onnxruntime_shared_lib_test_LIBS ${android_shared_libs})
endif()

if (${CMAKE_SYSTEM_NAME} MATCHES "AIX")
list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2)
endif()

AddTest(DYN
TARGET onnxruntime_shared_lib_test
SOURCES ${onnxruntime_shared_lib_test_SRC} ${onnxruntime_unittest_main_src}
Expand Down Expand Up @@ -1510,7 +1517,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
if(UNIX)
if (APPLE)
set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-Xlinker -dead_strip")
else()
elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX")
set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.lds -Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack")
endif()
else()
Expand Down Expand Up @@ -1574,6 +1581,9 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
if (onnxruntime_USE_TENSORRT)
list(APPEND onnxruntime_customopregistration_test_LIBS ${TENSORRT_LIBRARY_INFER})
endif()
if (${CMAKE_SYSTEM_NAME} MATCHES "AIX")
list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp)
endif()
AddTest(DYN
TARGET onnxruntime_customopregistration_test
SOURCES ${onnxruntime_customopregistration_test_SRC} ${onnxruntime_unittest_main_src}
Expand Down Expand Up @@ -1608,7 +1618,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUI
if(UNIX)
if (APPLE)
set(ONNXRUNTIME_CUSTOM_OP_INVALID_LIB_LINK_FLAG "-Xlinker -dead_strip")
else()
elseif (NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX")
string(CONCAT ONNXRUNTIME_CUSTOM_OP_INVALID_LIB_LINK_FLAG
"-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_invalid_library/custom_op_library.lds "
"-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack")
Expand Down Expand Up @@ -1639,7 +1649,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUI
if(UNIX)
if (APPLE)
set(ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG "-Xlinker -dead_strip")
else()
elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX")
string(CONCAT ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG
"-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_get_const_input_test_library/custom_op_lib.lds "
"-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack")
Expand Down Expand Up @@ -1671,7 +1681,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUI
if(UNIX)
if (APPLE)
set(ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG "-Xlinker -dead_strip")
else()
elseif(NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX")
string(CONCAT ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG
"-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.lds "
"-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack")
Expand All @@ -1690,6 +1700,9 @@ if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten"
${ONNXRUNTIME_LOGGING_APIS_TEST_SRC_DIR}/test_logging_apis.cc)

set(onnxruntime_logging_apis_test_LIBS onnxruntime_common onnxruntime_test_utils)
if (${CMAKE_SYSTEM_NAME} MATCHES "AIX")
list(APPEND onnxruntime_logging_apis_test_LIBS onnxruntime_session onnxruntime_util onnxruntime_framework onnxruntime_common onnxruntime_graph onnxruntime_providers onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp)
endif()

if(NOT WIN32)
list(APPEND onnxruntime_logging_apis_test_LIBS nsync::nsync_cpp ${CMAKE_DL_LIBS})
Expand Down Expand Up @@ -1753,7 +1766,9 @@ if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD
if(APPLE)
set_property(TARGET test_execution_provider APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst")
elseif(UNIX)
set_property(TARGET test_execution_provider APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds -Xlinker --gc-sections -Xlinker -rpath=\\$ORIGIN")
if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "AIX")
set_property(TARGET test_execution_provider APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds -Xlinker --gc-sections -Xlinker -rpath=\\$ORIGIN")
endif()
elseif(WIN32)
set_property(TARGET test_execution_provider APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${REPO_ROOT}/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def")
else()
Expand Down
18 changes: 18 additions & 0 deletions cmake/patches/flatbuffers/flatbuffers.patch
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,21 @@ index 3987eac9..5e5462f1 100644
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLATBUFFERS_CXX_FLAGS} -Wno-error=stringop-overflow")
endif()
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
diff --git a/include/flatbuffers/flatbuffers.h b/include/flatbuffers/flatbuffers.h
index bc828a31..3d3effe8 100644
--- a/include/flatbuffers/flatbuffers.h
+++ b/include/flatbuffers/flatbuffers.h
@@ -213,7 +213,12 @@ inline const char * const *ElementaryTypeNames() {
// We're explicitly defining the signedness since the signedness of integer
// bitfields is otherwise implementation-defined and causes warnings on older
// GCC compilers.
-struct TypeCode {
+
+struct
+#if defined(_AIX) && defined(__clang__)
+__attribute__((packed))
+#endif
+TypeCode {
// ElementaryType
unsigned short base_type : 4;
// Either vector (in table) or array (in struct)
62 changes: 53 additions & 9 deletions onnxruntime/contrib_ops/cpu/murmur_hash3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
/* Modifications Copyright (c) Microsoft. */

#include "contrib_ops/cpu/murmur_hash3.h"
#include <memory>
#include <utility>

// Platform-specific functions and macros

Expand Down Expand Up @@ -60,11 +62,31 @@ inline uint64_t rotl64(uint64_t x, int8_t r) {
// handle aligned reads, do the conversion here

FORCE_INLINE uint32_t getblock(const uint32_t* p, int i) {
return p[i];
if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) {
return p[i];
} else {
const uint8_t* c = (const uint8_t*)&p[i];
return (uint32_t)c[0] |
(uint32_t)c[1] << 8 |
(uint32_t)c[2] << 16 |
(uint32_t)c[3] << 24;
}
}

FORCE_INLINE uint64_t getblock(const uint64_t* p, int i) {
return p[i];
if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) {
return p[i];
} else {
const uint8_t* c = (const uint8_t*)&p[i];
return (uint64_t)c[0] |
(uint64_t)c[1] << 8 |
(uint64_t)c[2] << 16 |
(uint64_t)c[3] << 24 |
(uint64_t)c[4] << 32 |
(uint64_t)c[5] << 40 |
(uint64_t)c[6] << 48 |
(uint64_t)c[7] << 56;
}
}

//-----------------------------------------------------------------------------
Expand Down Expand Up @@ -204,13 +226,35 @@ Status MurmurHash3::Compute(OpKernelContext* ctx) const {
int input_num_bytes = static_cast<int>(input_element_bytes);
ORT_ENFORCE(input_num_bytes % 4 == 0);
const auto input_end = input + input_count * input_num_bytes;
while (input != input_end) {
MurmurHash3_x86_32(input,
input_num_bytes,
seed_,
output);
input += input_num_bytes;
++output;

if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) {
while (input != input_end) {
MurmurHash3_x86_32(input,
input_num_bytes,
seed_,
output);
input += input_num_bytes;
++output;
}
} else {
// Big endian platform require byte swapping.
auto raw_data = std::make_unique<char[]>(input_num_bytes);
char* raw_data_ptr = raw_data.get();
while (input != input_end) {
memcpy(raw_data_ptr, input, input_num_bytes);
char* start_byte = raw_data_ptr;
char* end_byte = start_byte + input_num_bytes - 1;
for (size_t count = 0; count < static_cast<size_t>(input_num_bytes / 2); ++count) {
std::swap(*start_byte++, *end_byte--);
}

MurmurHash3_x86_32(raw_data_ptr,
input_num_bytes,
seed_,
output);
input += input_num_bytes;
++output;
}
}
}
return Status::OK();
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ void Dequantize4BitsKernelReOrder(
}
T* output_i = output + out_y * out_cols + out_x;
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
if constexpr (onnxruntime::endian::native == onnxruntime::endian::big) {
const uint8_t* c = (const uint8_t*)(&quant_value);
quant_value = (uint32_t)c[0] |
(uint32_t)c[1] << 8 |
(uint32_t)c[2] << 16 |
(uint32_t)c[3] << 24;
}
const int remain_x = std::min(8, out_cols - out_x);
const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx_x * 8) & (block_size - 1));
for (int i = 0; i < remain_x; i++) {
Expand Down
Loading

0 comments on commit 4430e9b

Please sign in to comment.