Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into fs-eire/allow-flexibl…
Browse files Browse the repository at this point in the history
…e-webgpu-backend-selection
  • Loading branch information
fs-eire committed Dec 20, 2023
2 parents a4d7d5e + 98510fb commit ae8fb82
Show file tree
Hide file tree
Showing 97 changed files with 27,035 additions and 1,359 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
name: Download prebuilt ONNX Runtime archive from build.rs
runs-on: ubuntu-latest
env:
ORT_RUST_STRATEGY=download
ORT_RUST_STRATEGY: download
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/rust-toolchain-setup
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/stale.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v8.0.0
- uses: actions/stale@v9.0.0
with:
# Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale
exempt-issue-labels: contributions welcome, feature request, regression
Expand All @@ -29,7 +29,7 @@ jobs:
# Label you want to apply to issues that have been inactive for the amount of time specified by days-before-issue-stale
stale-issue-label: "stale"
# Comment that you want to add to issues that are labeled by the actions/stale action
stale-issue-message: "This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details."
stale-issue-message: "This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details."
# Comment that you want to add to issues that are closed by the actions/stale action
close-issue-message: "This issue has been automatically closed due to inactivity. Please reactivate if further support is needed."
# If you never want this action to label PRs, set this value to -1
Expand Down
12 changes: 12 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON)
option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
Expand Down Expand Up @@ -1166,6 +1167,17 @@ if (onnxruntime_USE_DNNL)
add_compile_definitions(DNNL_OPENMP)
endif()

set(USE_JBLAS FALSE)
if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD)
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64")
add_compile_definitions(MLAS_JBLAS)
set(USE_JBLAS TRUE)
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64")
add_compile_definitions(MLAS_JBLAS)
set(USE_JBLAS TRUE)
endif()
endif()

# TVM EP
if (onnxruntime_USE_TVM)
if (NOT TARGET tvm)
Expand Down
8 changes: 4 additions & 4 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ if (onnxruntime_BUILD_UNIT_TESTS)
FetchContent_Declare(
googletest
URL ${DEP_URL_googletest}
FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest
URL_HASH SHA1=${DEP_SHA1_googletest}
FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest
)
endif()

Expand Down Expand Up @@ -124,7 +124,7 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE)
if(protoc_binary_SOURCE_DIR)
message("Use prebuilt protoc")
set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe)
set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
endif()
elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux")
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$")
Expand All @@ -140,7 +140,7 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE)
if(protoc_binary_SOURCE_DIR)
message("Use prebuilt protoc")
set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc)
set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
endif()
elseif ((CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin")
FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal})
Expand Down Expand Up @@ -281,7 +281,7 @@ if ((CPUINFO_SUPPORTED OR onnxruntime_USE_XNNPACK) AND NOT ANDROID)
pytorch_clog
URL ${DEP_URL_pytorch_cpuinfo}
URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo}
SOURCE_SUBDIR deps/clog
SOURCE_SUBDIR deps/clog
)
set(ONNXRUNTIME_CLOG_PROJ pytorch_clog)
set(ONNXRUNTIME_CLOG_TARGET_NAME clog)
Expand Down
16 changes: 14 additions & 2 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ endif()

set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)

function(add_jblas)
add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/jblas_gemm.cpp
)
set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
endfunction()

#TODO: set MASM flags properly
function(setup_mlas_source_for_windows)

Expand Down Expand Up @@ -200,7 +209,6 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/q4gemm_avx512.cpp
)
endif()

else()
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
Expand Down Expand Up @@ -566,7 +574,7 @@ else()
)
set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
endif()
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs})
Expand Down Expand Up @@ -604,6 +612,10 @@ else()
target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs})
endif()

if(USE_JBLAS)
add_jblas()
endif()

foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR})
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
Expand Down
95 changes: 47 additions & 48 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1373,56 +1373,55 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
target_link_libraries(compare_two_sessions PRIVATE ${GETOPT_LIB_WIDE} tdh Advapi32)
endif()

file(GLOB onnxruntime_mlas_test_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/mlas/unittest/*.h"
"${TEST_SRC_DIR}/mlas/unittest/*.cpp"
)
onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src})
if(MSVC)
target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd26409>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26409>")
target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd6326>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd6326>")
target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd26426>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26426>")
endif()
if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
set_target_properties(onnxruntime_mlas_test PROPERTIES
XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO"
if(NOT onnxruntime_target_platform STREQUAL "ARM64EC")
file(GLOB onnxruntime_mlas_test_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/mlas/unittest/*.h"
"${TEST_SRC_DIR}/mlas/unittest/*.cpp"
)
endif()
target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}
${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(onnxruntime_mlas_test PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo)
endif()
if(NOT WIN32)
target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS})
endif()
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs})
endif()

if(WIN32)
target_link_libraries(onnxruntime_mlas_test PRIVATE debug Dbghelp Advapi32)
endif()
if (onnxruntime_LINK_LIBATOMIC)
target_link_libraries(onnxruntime_mlas_test PRIVATE atomic)
endif()
target_link_libraries(onnxruntime_mlas_test PRIVATE Threads::Threads)

set_target_properties(onnxruntime_mlas_test PROPERTIES FOLDER "ONNXRuntimeTest")
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1")
else()
set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1")
onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src})
if(MSVC)
target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd26409>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26409>")
target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd6326>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd6326>")
target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd26426>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26426>")
endif()
endif()

if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
set_target_properties(onnxruntime_mlas_test PROPERTIES
XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO"
)
endif()
target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}
${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(onnxruntime_mlas_test PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo)
endif()
if(NOT WIN32)
target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS})
endif()
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs})
endif()
if(WIN32)
target_link_libraries(onnxruntime_mlas_test PRIVATE debug Dbghelp Advapi32)
endif()
if (onnxruntime_LINK_LIBATOMIC)
target_link_libraries(onnxruntime_mlas_test PRIVATE atomic)
endif()
target_link_libraries(onnxruntime_mlas_test PRIVATE Threads::Threads)
set_target_properties(onnxruntime_mlas_test PROPERTIES FOLDER "ONNXRuntimeTest")
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1")
else()
set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1")
endif()
endif()
endif()
# Training API Tests
# Disabling training_api_test_trainer. CXXOPT generates a ton of warnings because of which nuget pipeline is failing.
# TODO(askhade): Fix the warnings.
Expand Down
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2824,6 +2824,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>size of each input feature</dd>
<dt><tt>N</tt> : int (required)</dt>
<dd>size of each output feature</dd>
<dt><tt>accuracy_level</tt> : int</dt>
<dd>The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) (default unset). It is used to control how input A is quantized or downcast internally while doing computation, for example: 0 means input A will not be quantized or downcast while doing computation. 4 means input A can be quantized with the same block_size to int8 internally from type T1.</dd>
<dt><tt>bits</tt> : int (required)</dt>
<dd>number of bits used for weight quantization (default 4)</dd>
<dt><tt>block_size</tt> : int (required)</dt>
Expand Down
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ Do not modify directly.*
|Crop|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|CumSum|*in* x:**T**<br> *in* axis:**T2**<br> *out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int32), tensor(int64)|
|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int32), tensor(int64)|
|DFT|*in* input:**T1**<br> *in* dft_length:**T2**<br> *in* axis:**tensor(int64)**<br> *out* output:**T1**<br><br>or<br><br>*in* input:**T1**<br> *in* dft_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
|DFT|*in* input:**T1**<br> *in* dft_length:**T2**<br> *in* axis:**tensor(int64)**<br> *out* output:**T1**<br><br>or<br><br>*in* input:**T1**<br> *in* dft_length:**T2**<br> *out* output:**T1**|20+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
|||[17, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[1, 10]|**T** = tensor(double), tensor(float)|
Expand Down
2 changes: 1 addition & 1 deletion js/react_native/app.plugin.js
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ const withOrt = (config) => {
config = configPlugin.withDangerousMod(config, [
'ios',
(config) => {
const podFilePath = path.join(config.modRequest.platformProjectRoot, 'PodFile');
const podFilePath = path.join(config.modRequest.platformProjectRoot, 'Podfile');
const contents = fs.readFileSync(podFilePath, {encoding: 'utf-8'});
const updatedContents =
generateCode
Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -772,14 +772,14 @@ class ShaderHelperImpl implements ShaderHelper {
const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1;
const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3<u32>,
@builtin(local_invocation_id) local_id : vec3<u32>` :
`@builtin(local_invocation_index) local_index : u32,
`@builtin(local_invocation_index) local_idx : u32,
@builtin(workgroup_id) workgroup_id : vec3<u32>,
@builtin(num_workgroups) num_workgroups : vec3<u32>`;
const globalIdxDefinition = is1DimensionDispatch ?
'let global_idx = global_id.x;' :
'let global_idx = global_id.x; let local_idx = local_id.x;' :
`let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${
workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_index;`;
workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_idx;`;

return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ})
fn main(${paramList}) {
Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let m = global_id.x / N;
let n = global_id.x % N;
let m = global_idx / N;
let n = global_idx % N;
var value = ${dataType}(0);
for (var k: u32 = 0u; k<${K}u; k++) {
Expand All @@ -107,7 +107,7 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
${calculateAlpha}
${calculateC}
output[global_id.x] = value;
output[global_idx] = value;
}`;
return {
Expand Down
Loading

0 comments on commit ae8fb82

Please sign in to comment.