diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ed9043f2adc4a..8453da19ce3a6 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -324,15 +324,27 @@ if (onnxruntime_USE_ROCM) endif() # replicate strategy used by pytorch to get ROCM_VERSION - # https://github.com/pytorch/pytorch/blob/8eb21488fdcdb8b0e6fa2e46179b5fa6c42e75af/cmake/public/LoadHIP.cmake#L153-L173 - file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW) - string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW}) - if (ROCM_VERSION_DEV_MATCH) + # https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake + # with modification + if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version-dev") + file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW) + string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + endif() + + if (ROCM_VERSION_MATCH) set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") + else() + message(FATAL_ERROR "Cannot determine ROCm version string") endif() message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version-dev ****\n") message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") @@ -1400,6 +1412,10 @@ endif() if (onnxruntime_USE_CUDA) set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) set(CMAKE_CUDA_STANDARD 17) + if(onnxruntime_CUDA_HOME) + file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME}) + endif() + find_package(CUDAToolkit REQUIRED) if(onnxruntime_CUDNN_HOME) file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) endif() diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 8161ea574b8cc..d3f9256105127 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -205,7 +205,7 @@ endif() macro(check_nvcc_compiler_flag _FLAG _RESULT) - execute_process(COMMAND ${onnxruntime_CUDA_HOME}/bin/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR) + execute_process(COMMAND ${CUDAToolkit_BIN_DIR}/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR) message("NVCC_ERROR = ${NVCC_ERROR}") message("NVCC_OUT = ${NVCC_OUT}") if ("${NVCC_OUT}" MATCHES "0") diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 22d12b128dc1f..09d57164b4ee1 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -556,16 +556,15 @@ message("Finished fetching external dependencies") set(onnxruntime_LINK_DIRS ) if (onnxruntime_USE_CUDA) #TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same + find_package(CUDAToolkit REQUIRED) if (WIN32) if(onnxruntime_CUDNN_HOME) list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64) endif() - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/x64/lib64) else() if(onnxruntime_CUDNN_HOME) list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64) endif() - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/lib64) endif() endif() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 9887d615c92d7..0f6d48bdb6ec8 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -178,15 +178,16 @@ add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) if(onnxruntime_CUDA_MINIMAL) target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL) - target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart) else() - target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand CUDA::cufft CUDA::cudart + ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) if(onnxruntime_CUDNN_HOME) target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) endif() endif() - + if (onnxruntime_USE_TRITON_KERNEL) # compile triton kernel, generate .a and .h files include(onnxruntime_compile_triton_kernel.cmake) @@ -196,25 +197,24 @@ target_include_directories(${target} PRIVATE ${triton_kernel_header_dir}) target_link_libraries(${target} PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive) # lib cuda needed by cuLaunchKernel - target_link_libraries(${target} PRIVATE cuda) + target_link_libraries(${target} PRIVATE CUDA::cuda_driver) endif() include(cutlass) target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) - target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} + PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA) set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime") if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling - target_include_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/include) - target_link_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/lib64) - target_link_libraries(${target} PRIVATE cupti) + target_link_libraries(${target} PRIVATE CUDA::cupti) endif() - if (onnxruntime_ENABLE_NVTX_PROFILE AND NOT WIN32) - target_link_libraries(${target} PRIVATE nvToolsExt) + if (onnxruntime_ENABLE_NVTX_PROFILE) + target_link_libraries(${target} PRIVATE CUDA::nvtx3) endif() if (onnxruntime_ENABLE_TRAINING_OPS) diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 686a993de3a4a..15ffc29e79ff4 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -8,7 +8,7 @@ set(BUILD_LIBRARY_ONLY 1) add_definitions("-DONNX_ML=1") add_definitions("-DONNX_NAMESPACE=onnx") - set(CUDA_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + set(CUDA_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIRS}) set(TENSORRT_ROOT ${onnxruntime_TENSORRT_HOME}) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) set(PROTOBUF_LIBRARY ${PROTOBUF_LIB}) @@ -58,7 +58,7 @@ URL_HASH SHA1=${DEP_SHA1_onnx_tensorrt} ) if (NOT CUDA_INCLUDE_DIR) - set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # onnx-tensorrt repo needs this variable to build + set(CUDA_INCLUDE_DIR ${CUDAToolkit_INCLUDE_DIRS}) # onnx-tensorrt repo needs this variable to build endif() # The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses # unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose. @@ -102,11 +102,12 @@ onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS}) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart) else() - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS}) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart) endif() - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} + PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) if(onnxruntime_CUDNN_HOME) target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include) endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 3f20787e87425..23c6e5e430875 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -282,10 +282,7 @@ if (WIN32) get_filename_component(CUDNN_DLL_NAME ${CUDNN_DLL_PATH} NAME_WE) string(REPLACE "cudnn64_" "" CUDNN_VERSION "${CUDNN_DLL_NAME}") if(NOT onnxruntime_CUDA_VERSION) - message("Reading json file ${onnxruntime_CUDA_HOME}/version.json") - set(CUDA_SDK_JSON_FILE_PATH "${onnxruntime_CUDA_HOME}/version.json") - file(READ ${CUDA_SDK_JSON_FILE_PATH} CUDA_SDK_JSON_CONTENT) - string(JSON onnxruntime_CUDA_VERSION GET ${CUDA_SDK_JSON_CONTENT} "cuda" "version") + set(onnxruntime_CUDA_VERSION ${CUDAToolkit_VERSION}) message("onnxruntime_CUDA_VERSION=${onnxruntime_CUDA_VERSION}") endif() file(APPEND "${VERSION_INFO_FILE}" diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 3ed695327c183..88f662075e177 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -67,7 +67,7 @@ function(AddTest) if(onnxruntime_USE_CUDA) #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, # otherwise it will impact when CUDA DLLs can be unloaded. - target_link_libraries(${_UT_TARGET} PRIVATE cudart) + target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart) endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -1268,7 +1268,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo) endif() if (onnxruntime_USE_CUDA) - list(APPEND onnxruntime_shared_lib_test_LIBS cudart) + list(APPEND onnxruntime_shared_lib_test_LIBS CUDA::cudart) endif() if (onnxruntime_USE_ROCM) list(APPEND onnxruntime_shared_lib_test_LIBS hip::host) diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 6299c26159400..73a47d1a4f937 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -36,6 +36,7 @@ export declare namespace Env { /** * set or get a boolean value indicating whether to enable trace. * + * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored. * @defaultValue `false` */ trace?: boolean; @@ -167,6 +168,7 @@ export interface Env { * @defaultValue `'warning'` */ logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal'; + /** * Indicate whether run in debug mode. * @@ -174,6 +176,13 @@ export interface Env { */ debug?: boolean; + /** + * set or get a boolean value indicating whether to enable trace. + * + * @defaultValue `false` + */ + trace?: boolean; + /** * Get version of the current package. */ diff --git a/js/common/lib/trace.ts b/js/common/lib/trace.ts index 404f7ef8089af..7e0487b350198 100644 --- a/js/common/lib/trace.ts +++ b/js/common/lib/trace.ts @@ -4,7 +4,7 @@ import {env} from './env-impl.js'; export const TRACE = (deviceType: string, label: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } // eslint-disable-next-line no-console @@ -30,14 +30,14 @@ const TRACE_FUNC = (msg: string, extraMsg?: string) => { }; export const TRACE_FUNC_BEGIN = (extraMsg?: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } TRACE_FUNC('BEGIN', extraMsg); }; export const TRACE_FUNC_END = (extraMsg?: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } TRACE_FUNC('END', extraMsg); diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 3e3a191ec3ead..27c5566ab9fed 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -710,7 +710,8 @@ export class WebGpuBackend { } setQueryType(): void { this.queryType = 'none'; - if (this.env.webgpu.profiling?.mode === 'default' || this.env.wasm.trace) { + if (this.env.webgpu.profiling?.mode === 'default' || + (typeof this.env.trace === 'undefined' ? this.env.wasm.trace : this.env.trace)) { if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) { this.queryType = 'inside-passes'; } else if (this.device.features.has('timestamp-query')) { diff --git a/onnxruntime/core/providers/cuda/nvtx_profile.cc b/onnxruntime/core/providers/cuda/nvtx_profile.cc index 6c7c594066b86..867e7c1f24584 100644 --- a/onnxruntime/core/providers/cuda/nvtx_profile.cc +++ b/onnxruntime/core/providers/cuda/nvtx_profile.cc @@ -4,13 +4,8 @@ #ifdef ENABLE_NVTX_PROFILE #include "nvtx_profile.h" #include "core/common/common.h" -#if defined(_WIN32) || defined(WIN32) || defined(__CYGWIN__) || defined(__MINGW32__) || defined(__BORLANDC__) #include #include -#else -#include -#include -#endif namespace onnxruntime { namespace profile { diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index d94729e60d029..d7892fe02c1ba 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -195,7 +195,7 @@ static const InlinedHashMap op_map = { {"LessOrEqual", {"lesserOrEqual", false}}, {"Log", {"log", false}}, {"LpPool", {"l2Pool2d", false}}, - {"MatMul", {"matmul", false}}, + {"MatMul", {"matmul", true}}, {"MatMulInteger", {"matmulInteger", false}}, {"Max", {"max", true}}, {"MaxPool", {"maxPool2d", true}}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 4bf991a1b0105..d5f84f853f7de 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -29,7 +29,7 @@ class GemmOpBuilder : public BaseOpBuilder { // Add operator related. Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /* logger */) const { + const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C @@ -38,7 +38,17 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name()); emscripten::val output = emscripten::val::object(); if (op_type == "MatMul") { - output = model_builder.GetBuilder().call("matmul", a, b); + std::vector a_shape; + if (!GetShape(*input_defs[a_idx], a_shape, logger)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Can not get shape of A."); + } + // The inputs of MatMul must be at least 3D for WebNN CPU backend. Use GEMM for 2D case. + // TODO: Remove this workaround when it is fixed in Chromium. + if (model_builder.GetWebnnDeviceType() == WebnnDeviceType::CPU && a_shape.size() == 2) { + output = model_builder.GetBuilder().call("gemm", a, b); + } else { + output = model_builder.GetBuilder().call("matmul", a, b); + } } else if (op_type == "MatMulInteger") { emscripten::val a_zero_point = emscripten::val::null(); emscripten::val b_zero_point = emscripten::val::null(); diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py index 9ebf400498e0e..fbf954febdda4 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py @@ -122,6 +122,11 @@ def fuse( self.nodes_to_remove.extend(subgraph_nodes) fused_node = onnx.helper.make_node( - self.fused_op_type, inputs=[subgraph_input], outputs=[subgraph_output], p=2, axis=-1 + self.fused_op_type, + name=self.create_unique_node_name(), + inputs=[subgraph_input], + outputs=[subgraph_output], + p=2, + axis=-1, ) self.nodes_to_add.append(fused_node) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index becbaceab184e..b1c114fe1f9fd 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -44,6 +44,17 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: if fusion_layernorm.apply(): modified = True + # Make sure all nodes have a name. + unnamed_node_prefix = "qnn_preproc_node_" + available_suffix = onnx_model.get_largest_node_name_suffix(unnamed_node_prefix) + 1 + for node in onnx_model.model.graph.node: + if node.op_type != "Constant" and not node.name: + new_node_name = f"{unnamed_node_prefix}{available_suffix!s}" + available_suffix += 1 + node.name = new_node_name + modified = True + logging.warning(f"Node of type {node.op_type} does not have a name. Renamed to {new_node_name}.") + if modified: onnx_model.topological_sort() onnx.save_model(model, model_output) diff --git a/onnxruntime/python/tools/quantization/fusions/fusion.py b/onnxruntime/python/tools/quantization/fusions/fusion.py index b54b421226f1a..4bdc5c26cc946 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion.py @@ -24,6 +24,9 @@ def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str): self.nodes_to_remove: list = [] self.nodes_to_add: list = [] + self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_" + self._new_node_name_suffix = None # int|None used to create unique node names for the fused ops. + def fuse( self, node: onnx.NodeProto, @@ -57,6 +60,18 @@ def apply(self) -> bool: return graph_updated + def create_unique_node_name(self): + prefix = self._new_node_name_prefix + + if self._new_node_name_suffix is None: + largest_suffix: int = self.model.get_largest_node_name_suffix(prefix) + self._new_node_name_suffix = largest_suffix + 1 + + new_name = f"{prefix}{self._new_node_name_suffix!s}" + self._new_node_name_suffix += 1 + + return new_name + @staticmethod def is_safe_to_fuse_nodes( nodes_to_remove: list[onnx.NodeProto], diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py index a20d6dbffd7a7..42c4a11833641 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py @@ -112,7 +112,9 @@ def fuse_1( return False self.nodes_to_remove.extend(subgraph_nodes) - fused_node = onnx.helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output]) + fused_node = onnx.helper.make_node( + "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output] + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) return True @@ -173,11 +175,9 @@ def fuse_2( if not self.has_constant_input(sqrt_node, 2.0): return False - root_node = self.model.get_parent(div, 0, output_name_to_node) - if root_node is None: - return False + subgraph_input = div.input[0] - if root_node.output[0] not in mul.input: + if subgraph_input not in mul.input: return False subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul] @@ -188,7 +188,9 @@ def fuse_2( return False self.nodes_to_remove.extend(subgraph_nodes) - fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]]) + fused_node = onnx.helper.make_node( + "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[mul.output[0]] + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) return True @@ -239,9 +241,8 @@ def fuse_3( if i < 0: return False - root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node) - if root_node is None: - return False + root_input_index = 1 - i + subgraph_input = first_mul.input[root_input_index] if mul_half.output[0] not in input_name_to_nodes: return False @@ -250,7 +251,7 @@ def fuse_3( return False last_mul = children[0] - if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]): + if not (last_mul.input[0] == subgraph_input or last_mul.input[1] == subgraph_input): return False subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul] @@ -263,7 +264,9 @@ def fuse_3( return False self.nodes_to_remove.extend(subgraph_nodes) - fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]]) + fused_node = onnx.helper.make_node( + "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[last_mul.output[0]] + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) return True diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py index d7fb89236d3d2..7d58c1c180822 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py @@ -127,6 +127,7 @@ def fuse( normalize_node = onnx.helper.make_node( "LayerNormalization", + name=self.create_unique_node_name(), inputs=[reduce_mean_node.input[0], weight_input, bias_input], outputs=[last_add_node.output[0]], ) diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index 4591c9c950e6e..46d245d353a07 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -283,6 +283,23 @@ def find_node_by_name(self, node_name, new_nodes_list, graph): node = find_by_name(node_name, graph_nodes_list) return node + def get_largest_node_name_suffix(self, node_name_prefix): + """ + Gets the largest node name (int) suffix for all node names that begin with `node_name_prefix`. + Example: for nodes my_prefix_0 and my_prefix_3, this method returns 3. + """ + suffix = -1 + + for node in self.model.graph.node: + if node.name and node.name.startswith(node_name_prefix): + try: + index = int(node.name[len(node_name_prefix) :]) + suffix = max(index, suffix) + except ValueError: + continue + + return suffix + def find_nodes_by_initializer(self, graph, initializer): """ Find all nodes with given initializer as an input. diff --git a/onnxruntime/test/python/quantization/test_fusions.py b/onnxruntime/test/python/quantization/test_fusions.py new file mode 100644 index 0000000000000..bea110e566fb9 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_fusions.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import unittest + +import numpy as np +import onnx + +import onnxruntime +from onnxruntime.quantization.execution_providers.qnn.fusion_lpnorm import FusionLpNormalization +from onnxruntime.quantization.fusions import FusionGelu, FusionLayerNormalization +from onnxruntime.quantization.onnx_model import ONNXModel + + +class TestFusions(unittest.TestCase): + def check_fused_model_correctness(self, orig_model, fused_model, inputs, rtol=1e-7, atol=0): + """ + Checks that the output of the fused model matches the output of the original model. + """ + orig_session = onnxruntime.InferenceSession(orig_model.SerializeToString(), providers=["CPUExecutionProvider"]) + orig_results = orig_session.run(None, inputs) + + fused_session = onnxruntime.InferenceSession( + fused_model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + fused_results = fused_session.run([], inputs) + + self.assertEqual(len(orig_results), len(fused_results), "Number of outputs for fused model differs") + for idx, expected_output in enumerate(orig_results): + actual_output = fused_results[idx] + np.testing.assert_allclose( + expected_output, + actual_output, + rtol=rtol, + atol=atol, + err_msg=f"Fused model output {idx} differs", + ) + + def build_erf_sequence_1_model(self, shape): + """ + Erf sequence that fuses into Gelu: + +-------Mul(0.5)---------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul --> + (B=1.4142...) (1) + + This method builds 2 of these Erf sequences: + + [root] -> ERF_SEQUENCE1 -> ERF_SEQUENCE2 -> output + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + # First Erf sequence + mul0_node = onnx.helper.make_node("Mul", ["root", "half_const"], ["mul0_out"]) + div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"]) + erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul1_node = onnx.helper.make_node("Mul", ["add_out", "mul0_out"], ["seq1_output"]) + + # Second Erf sequence + mul0_node_dup = onnx.helper.make_node("Mul", ["seq1_output", "half_const"], ["mul0_out_dup"]) + div_node_dup = onnx.helper.make_node("Div", ["seq1_output", "root2_const"], ["div_out_dup"]) + erf_node_dup = onnx.helper.make_node("Erf", ["div_out_dup"], ["erf_out_dup"]) + add_node_dup = onnx.helper.make_node("Add", ["erf_out_dup", "one_const"], ["add_out_dup"]) + mul1_node_dup = onnx.helper.make_node("Mul", ["add_out_dup", "mul0_out_dup"], ["output"]) + + graph = onnx.helper.make_graph( + [ + mul0_node, + div_node, + erf_node, + add_node, + mul1_node, + mul0_node_dup, + div_node_dup, + erf_node_dup, + add_node_dup, + mul1_node_dup, + ], + "two_erf_sequences", + [root_inp], + [output], + initializer=[one_const, half_const, root2_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_erf_sequence_2_model(self, shape): + """ + +------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul --> + (B=1.4142...) (1) (0.5) + + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"]) + erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul0_node = onnx.helper.make_node("Mul", ["add_out", "root"], ["mul0_out"]) + mul1_node = onnx.helper.make_node("Mul", ["mul0_out", "half_const"], ["output"]) + + graph = onnx.helper.make_graph( + [div_node, erf_node, add_node, mul0_node, mul1_node], + "erf_sequence_2", + [root_inp], + [output], + initializer=[one_const, half_const, root2_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_erf_sequence_3_model(self, shape): + """ + +------------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul + (B=1.4142...) (A=1) (A=0.5) + + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"]) + erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul0_node = onnx.helper.make_node("Mul", ["add_out", "half_const"], ["mul0_out"]) + mul1_node = onnx.helper.make_node("Mul", ["mul0_out", "root"], ["output"]) + + graph = onnx.helper.make_graph( + [div_node, erf_node, add_node, mul0_node, mul1_node], + "erf_sequence_3", + [root_inp], + [output], + initializer=[one_const, half_const, root2_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_erf_sequence_4_model(self, shape): + """ + +----------------------------------------------+ + | | + | v + [root] --> Mul -----> Erf --> Add --> Mul -->Mul + (A=0.7071067690849304) (B=1) (B=0.5) + + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + frac_const = onnx.numpy_helper.from_array(np.array(0.7071067690849304, dtype=np.float32), "frac_const") + + mul0_node = onnx.helper.make_node("Mul", ["root", "frac_const"], ["mul0_out"]) + erf_node = onnx.helper.make_node("Erf", ["mul0_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul1_node = onnx.helper.make_node("Mul", ["add_out", "half_const"], ["mul1_out"]) + mul2_node = onnx.helper.make_node("Mul", ["mul1_out", "root"], ["output"]) + + graph = onnx.helper.make_graph( + [mul0_node, erf_node, add_node, mul1_node, mul2_node], + "erf_sequence_4", + [root_inp], + [output], + initializer=[one_const, half_const, frac_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_reduce_mean_sequence_model(self, shape, scale_val, bias_val, axis=-1): + """ + +----------------------+ + | | + | v + [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add + (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^ ^ ^ + | | | | + +-------------------------------------------------+ [Scale] [Bias] + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const") + bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const") + axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const") + two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const") + eps_const = onnx.numpy_helper.from_array(np.array(1.0e-8, dtype=np.float32), "eps_const") + + rm0_node = onnx.helper.make_node("ReduceMean", ["root", "axes_const"], ["rm0_out"]) + sub_node = onnx.helper.make_node("Sub", ["root", "rm0_out"], ["sub_out"]) + pow_node = onnx.helper.make_node("Pow", ["sub_out", "two_const"], ["pow_out"]) + rm1_node = onnx.helper.make_node("ReduceMean", ["pow_out", "axes_const"], ["rm1_out"]) + add0_node = onnx.helper.make_node("Add", ["rm1_out", "eps_const"], ["add0_out"]) + sqrt_node = onnx.helper.make_node("Sqrt", ["add0_out"], ["sqrt_out"]) + div_node = onnx.helper.make_node("Div", ["sub_out", "sqrt_out"], ["div_out"]) + mul_node = onnx.helper.make_node("Mul", ["div_out", "scale_const"], ["mul_out"]) + add1_node = onnx.helper.make_node("Add", ["mul_out", "bias_const"], ["output"]) + + graph = onnx.helper.make_graph( + [rm0_node, sub_node, pow_node, rm1_node, add0_node, sqrt_node, div_node, mul_node, add1_node], + "reduce_mean_sequence", + [root_inp], + [output], + initializer=[scale_const, bias_const, axes_const, two_const, eps_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_reduce_l2_sequence_model(self, shape, epsilon_val, axis=-1): + """ + [root] --> ReduceL2 -----> Clip --> Expand ----> Div --> + | (axis=-1) (min=epsilon) (shape=root) ^ + | (keepdims=True) | + | | + +-----------------------------------------------+ + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const") + eps_const = onnx.numpy_helper.from_array(np.array(epsilon_val, dtype=np.float32), "eps_const") + shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const") + + rl2_node = onnx.helper.make_node("ReduceL2", ["root", "axes_const"], ["rl2_out"], keepdims=1) + clip_node = onnx.helper.make_node("Clip", ["rl2_out", "eps_const"], ["clip_out"]) + expand_node = onnx.helper.make_node("Expand", ["clip_out", "shape_const"], ["expand_out"]) + div_node = onnx.helper.make_node("Div", ["root", "expand_out"], ["output"]) + + graph = onnx.helper.make_graph( + [rl2_node, clip_node, expand_node, div_node], + "reducel2_sequence", + [root_inp], + [output], + initializer=[axes_const, eps_const, shape_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def test_fuse_erf_to_gelu_1(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_1_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 2 Gelu nodes. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 2) + + gelu_node_0 = model.model.graph.node[0] + gelu_node_1 = model.model.graph.node[1] + self.assertEqual(gelu_node_0.op_type, "Gelu") + self.assertEqual(gelu_node_1.op_type, "Gelu") + + self.assertTrue(gelu_node_0.name) + self.assertTrue(gelu_node_1.name) + self.assertNotEqual(gelu_node_0.name, gelu_node_1.name) # Generated names should not be equal + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_erf_to_gelu_2(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_2_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 Gelu node. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + gelu_node = model.model.graph.node[0] + self.assertEqual(gelu_node.op_type, "Gelu") + self.assertTrue(gelu_node.name) + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_erf_to_gelu_3(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_3_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 Gelu node. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + gelu_node = model.model.graph.node[0] + self.assertEqual(gelu_node.op_type, "Gelu") + self.assertTrue(gelu_node.name) + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_erf_to_gelu_4(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_4_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 Gelu node. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + gelu_node = model.model.graph.node[0] + self.assertEqual(gelu_node.op_type, "Gelu") + self.assertTrue(gelu_node.name) + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_reduce_l2_to_lpnorm(self): + shape = (1, 2, 3) + model = self.build_reduce_l2_sequence_model(shape, 1e-12, axis=-1) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 LpNormalization node. + modified = FusionLpNormalization(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + lpnorm_node = model.model.graph.node[0] + self.assertEqual(lpnorm_node.op_type, "LpNormalization") + self.assertTrue(lpnorm_node.name) + + # LpNorm's p attribute should be set to 2 + p_attr = next(attr for attr in lpnorm_node.attribute if attr.name == "p") + self.assertEqual(p_attr.i, 2) + + def test_fuse_reduce_mean_to_layer_norm(self): + shape = (1, 2, 3) + model = self.build_reduce_mean_sequence_model(shape, [2.0, 2.0, 2.0], [1.0, 1.0, 1.0], axis=-1) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 LayerNormalization node. + modified = FusionLayerNormalization(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + layer_norm_node = model.model.graph.node[0] + self.assertEqual(layer_norm_node.op_type, "LayerNormalization") + self.assertTrue(layer_norm_node.name) + + # Check that fused model is equivalent to original model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 74c473d34f548..1056c4ed84510 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1451,6 +1451,13 @@ def generate_build_tree( # tools need to use the symbols. add_default_definition(cmake_extra_defines, "CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "ProgramDatabase") + if number_of_parallel_jobs(args) > 0: + # https://devblogs.microsoft.com/cppblog/improved-parallelism-in-msbuild/ + # NOTE: this disables /MP if set (according to comments on blog post). + # By default, MultiProcMaxCount and CL_MPCount value are equal to the number of CPU logical processors. + # See logic around setting CL_MPCount below + cmake_args += ["-DCMAKE_VS_GLOBALS=UseMultiToolTask=true;EnforceProcessCountAcrossBuilds=true"] + cmake_args += [f"-D{define}" for define in cmake_extra_defines] cmake_args += cmake_extra_args @@ -1662,11 +1669,17 @@ def build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, targe build_tool_args = [] if num_parallel_jobs != 1: if is_windows() and args.cmake_generator != "Ninja" and not args.build_wasm: + # https://github.com/Microsoft/checkedc-clang/wiki/Parallel-builds-of-clang-on-Windows suggests + # not maxing out CL_MPCount + # Start by having one less than num_parallel_jobs (default is num logical cores), + # limited to a range of 1..3 + # that gives maxcpucount projects building using up to 3 cl.exe instances each build_tool_args += [ f"/maxcpucount:{num_parallel_jobs}", + # one less than num_parallel_jobs, at least 1, up to 3 + f"/p:CL_MPCount={min(max(num_parallel_jobs - 1, 1), 3)}", # if nodeReuse is true, msbuild processes will stay around for a bit after the build completes "/nodeReuse:False", - f"/p:CL_MPCount={num_parallel_jobs}", ] elif args.cmake_generator == "Xcode": build_tool_args += [ diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml index a4bd24b4dd18b..02147c321fab3 100644 --- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -115,6 +115,7 @@ stages: searchFolder: '$(Build.BinariesDirectory)' testRunTitle: 'Unit Test Run' condition: succeededOrFailed() + - job: Linux_Release timeoutInMinutes: 180 workspace: @@ -243,7 +244,46 @@ stages: ln -s /data/models $(Build.BinariesDirectory)/models displayName: link model dir - + - bash: | + mkdir -p $HOME/.onnx + docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + onnxruntimecpubuild \ + /bin/bash -c " + set -ex; \ + pushd /onnxruntime_src/csharp; \ + dotnet restore /onnxruntime_src/csharp/OnnxRuntime.DesktopOnly.CSharp.sln; \ + dotnet build /onnxruntime_src/csharp/OnnxRuntime.DesktopOnly.CSharp.sln -c Release; \ + dotnet test /onnxruntime_src/csharp/OnnxRuntime.DesktopOnly.CSharp.sln -c Release -f net6.0 --no-build -l \"console;verbosity=normal\"; \ + popd + " + displayName: 'Dotnet build C# sln and Test' + + - bash: | + mkdir -p $HOME/.onnx + docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + onnxruntimecpubuild \ + /bin/bash -c " + set -ex; \ + /bin/bash /onnxruntime_src/tools/scripts/python_test.sh /onnxruntime_src /build Release && \ + /bin/bash /onnxruntime_src/tools/scripts/symbolic_shape_infer_test.sh /build + " + displayName: 'Run Release tests and symbolic shape infer test' - task: PublishTestResults@2 displayName: 'Publish unit test results' diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml index 8ca3d9148b514..064e2ea91d194 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml @@ -213,13 +213,6 @@ stages: PlatformsSupported: 'linux-x64' VerifyNugetSigning: false - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline NuGet Artifact' - inputs: - artifactName: 'drop-signed-nuget-GPU' - targetPath: '$(Build.ArtifactStagingDirectory)' - - - task: MSBuild@1 displayName: 'Clean C#' inputs: @@ -241,6 +234,12 @@ stages: parameters: condition: 'succeeded' + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline NuGet Artifact' + inputs: + artifactName: 'drop-signed-nuget-GPU' + targetPath: '$(Build.ArtifactStagingDirectory)' + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml index f1418e75bffa2..3d128fdb78eee 100644 --- a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml @@ -6,11 +6,8 @@ parameters: steps: - ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: - - task: DeleteFiles@1 - inputs: - SourceFolder: '$(Build.BinariesDirectory)' - contents: | - **/* + - powershell: | + Remove-Item $(Build.BinariesDirectory)/* -Recurse -Force displayName: 'Clean up build directory' - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 8ed22153fd947..e32956d6eb913 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -162,10 +162,11 @@ stages: platform: ${{ parameters.msbuildPlatform }} configuration: RelWithDebInfo msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true + maximumCpuCount: true # default is num logical cores worth of projects building concurrently logProjectEvents: true workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' createLogFile: true + msbuildArgs: "/p:CL_MPCount=2" # 2x cl.exe per project building. - task: PythonScript@0 displayName: 'test' diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index 94f52f476579b..886f19388d01e 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -10,3 +10,4 @@ protobuf==4.21.12 sympy==1.12 flatbuffers neural-compressor>=2.2.1 +triton