From c8ce83967e5b52062558046d769f3af7d871e893 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 19 Jan 2024 15:30:09 -0800 Subject: [PATCH 01/45] Download protoc for all Apple host builds, remove protoc build from iOS packaging pipeline. (#19209) --- .../external/onnxruntime_external_deps.cmake | 74 ++++++++++--------- .../stages/mac-ios-packaging-build-stage.yml | 7 +- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 78f63227c8392..403b4b2c4107a 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -108,41 +108,14 @@ FetchContent_Declare( ) # Download a protoc binary from Internet if needed -if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) +if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) # This part of code is only for users' convenience. The code couldn't handle all cases. Users always can manually # download protoc from Protobuf's Github release page and pass the local path to the ONNX_CUSTOM_PROTOC_EXECUTABLE # variable. - message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") - if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") - if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64}) - FetchContent_Populate(protoc_binary) - elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32}) - FetchContent_Populate(protoc_binary) - endif() - if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") - set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) - endif() - elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") - if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64}) - FetchContent_Populate(protoc_binary) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86}) - FetchContent_Populate(protoc_binary) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") - FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_aarch64}) - FetchContent_Populate(protoc_binary) - endif() - if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") - set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) - endif() - elseif ((CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin") + if (CMAKE_HOST_APPLE) + # Using CMAKE_CROSSCOMPILING is not recommended for Apple target devices. + # https://cmake.org/cmake/help/v3.26/variable/CMAKE_CROSSCOMPILING.html + # To keep it simple, just download and use the universal protoc binary for all Apple host builds. FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal}) FetchContent_Populate(protoc_binary) if(protoc_binary_SOURCE_DIR) @@ -150,6 +123,38 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() + elseif (CMAKE_CROSSCOMPILING) + message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") + if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64}) + FetchContent_Populate(protoc_binary) + elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32}) + FetchContent_Populate(protoc_binary) + endif() + if(protoc_binary_SOURCE_DIR) + message("Use prebuilt protoc") + set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + endif() + elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") + if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64}) + FetchContent_Populate(protoc_binary) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86}) + FetchContent_Populate(protoc_binary) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_aarch64}) + FetchContent_Populate(protoc_binary) + endif() + if(protoc_binary_SOURCE_DIR) + message("Use prebuilt protoc") + set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + endif() + endif() endif() endif() @@ -184,9 +189,9 @@ FetchContent_Declare( ) set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE) -#TODO: we'd better to turn the following option off. However, it will cause +#TODO: we'd better to turn the following option off. However, it will cause # ".\build.bat --config Debug --parallel --skip_submodule_sync --update" fail with an error message: -# install(EXPORT "ONNXTargets" ...) includes target "onnx_proto" which requires target "libprotobuf-lite" that is +# install(EXPORT "ONNXTargets" ...) includes target "onnx_proto" which requires target "libprotobuf-lite" that is # not in any export set. #set(protobuf_INSTALL OFF CACHE BOOL "Install protobuf binaries and files" FORCE) set(protobuf_USE_EXTERNAL_GTEST ON CACHE BOOL "" FORCE) @@ -562,4 +567,3 @@ endif() FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} ORT_BINARY_DIR) FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} ORT_SOURCE_DIR) - diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index d1dff0769e25f..ed32c5d0e15be 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -78,10 +78,6 @@ stages: pip install -r tools/ci_build/github/apple/ios_packaging.requirements.txt displayName: "Install Python requirements" - - script: | - $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_protobuf.sh -p $(Build.BinariesDirectory)/protobuf_install -d $(Build.SourcesDirectory)/cmake/deps.txt - displayName: "Build Host Protoc" - # create and test mobile pods - script: | python tools/ci_build/github/apple/build_and_assemble_apple_pods.py \ @@ -91,8 +87,7 @@ stages: --test \ --variant ${{ parameters.packageVariant }} \ --build-settings-file "${{ variables.buildSettingsFile }}" \ - ${{ variables.optionalIncludeOpsByConfigOption }} \ - -b="--path_to_protoc_exe=$(Build.BinariesDirectory)/protobuf_install/bin/protoc" + ${{ variables.optionalIncludeOpsByConfigOption }} displayName: "Build macOS/iOS framework and assemble pod package files" - script: | From f3402de01e732283283aaa208022d6c7ae85ca4a Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Sun, 21 Jan 2024 10:51:58 -0800 Subject: [PATCH 02/45] [TensorRT EP] Enhance EP context configs in session options and provider options (#19154) Several changes: 1. To align with other EPs' setting of EP context configs in session options, for example [QNN EP](https://github.com/microsoft/onnxruntime/pull/18877), EP context configs for TRT EP can be configured through: 1. Session Options: `ep.context_enable`, `ep.context_file_path` and `ep.context_embed_mode` 2. Provider Options: `trt_dump_ep_context_model`, `trt_ep_context_file_path` and `trt_dump_ep_context_embed_mode` 3. Above setting has 1:1 mapping and provider options has higher priority over session options. ``` Please note that there are rules for using following context model related provider options: 1. In the case of dumping the context model and loading the context model, for security reason, TRT EP doesn't allow the "ep_cache_context" node attribute of EP context node to be the absolute path or relative path that is outside of context model directory. It means engine cache needs to be in the same directory or sub-directory of context model. 2. In the case of dumping the context model, the engine cache path will be changed to the relative path of context model directory. For example: If "trt_dump_ep_context_model" is enabled and "trt_engine_cache_enable" is enabled, if "trt_ep_context_file_path" is "./context_model_dir", - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir" - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir" ``` 2. User can decide the naming of the dumped "EP context" model by using `trt_ep_context_file_path`, please see GetCtxModelPath() for more details. 3. Added suggested comments from https://github.com/microsoft/onnxruntime/pull/18217 --- .../tensorrt/tensorrt_provider_options.h | 28 ++- .../tensorrt/onnx_ctx_model_helper.cc | 211 +++++++++++++----- .../tensorrt/onnx_ctx_model_helper.h | 34 +-- .../tensorrt/tensorrt_execution_provider.cc | 153 ++++++++----- .../tensorrt/tensorrt_execution_provider.h | 5 +- .../tensorrt_execution_provider_info.cc | 13 +- .../tensorrt_execution_provider_info.h | 2 +- .../tensorrt/tensorrt_provider_factory.cc | 9 +- .../core/session/provider_bridge_ort.cc | 87 +++++++- .../python/onnxruntime_pybind_state.cc | 17 +- .../gen_trt_engine_wrapper_onnx_model.py | 19 +- .../providers/tensorrt/tensorrt_basic_test.cc | 208 ++++++++++++++++- 12 files changed, 624 insertions(+), 162 deletions(-) diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 60196d0c80cbb..32a9f06464ace 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -11,6 +11,8 @@ /// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions. /// struct OrtTensorRTProviderOptionsV2 { + OrtTensorRTProviderOptionsV2& operator=(const OrtTensorRTProviderOptionsV2& other); // copy assignment operator + int device_id{0}; // cuda device id. int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream. void* user_compute_stream{nullptr}; // user specified CUDA compute stream. @@ -46,8 +48,26 @@ struct OrtTensorRTProviderOptionsV2 { const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT - int trt_dump_ep_context_model{0}; // Dump EP context node model - int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data - int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute - const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix + + /* + * Please note that there are rules for using following context model related provider options: + * + * 1. In the case of dumping the context model and loading the context model, + * for security reason, TRT EP doesn't allow the "ep_cache_context" node attribute of EP context node to be + * the absolute path or relative path that is outside of context model directory. + * It means engine cache needs to be in the same directory or sub-directory of context model. + * + * 2. In the case of dumping the context model, the engine cache path will be changed to the relative path of context model directory. + * For example: + * If "trt_dump_ep_context_model" is enabled and "trt_engine_cache_enable" is enabled, + * if "trt_ep_context_file_path" is "./context_model_dir", + * - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir" + * - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir" + * + */ + int trt_dump_ep_context_model{0}; // Dump EP context node model + const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path. + int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data + + const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix }; diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index 4d8ba6a0891e3..1994d1f5ab0b8 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -38,13 +38,6 @@ const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer) { return main_graph.ModelPath(); } -std::filesystem::path LocateEngineRelativeToPath(std::string engine_cache_path, const onnxruntime::Path& path) { - std::filesystem::path base_path(path.ToPathString()); - std::filesystem::path parent_path = base_path.parent_path(); - std::filesystem::path engine_path = parent_path.append(engine_cache_path); - return engine_path; -} - /* * Update ep_cache_context attribute of the EP context node with the given engine binary data */ @@ -69,14 +62,13 @@ void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto, /* * Create "EP context node" model where engine information is embedded */ -ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer, - const std::string engine_cache_path, - char* engine_data, - size_t size, - const int64_t embed_mode, - bool compute_capability_enable, - std::string compute_capability, - const logging::Logger* logger) { +ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, + const std::string engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + std::string compute_capability, + const logging::Logger* logger) { auto model_build = graph_viewer.CreateModel(*logger); auto& graph_build = model_build->MainGraph(); @@ -107,21 +99,20 @@ ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer, engine_data_str.assign(engine_data, size); } attr_1->set_s(engine_data_str); + LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; } else { attr_1->set_s(engine_cache_path); } + attr_2->set_name(COMPUTE_CAPABILITY); + attr_2->set_type(onnx::AttributeProto_AttributeType_STRING); + attr_2->set_s(compute_capability); + auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); - int num_attributes = compute_capability_enable ? 3 : 2; + int num_attributes = 3; node_attributes->reserve(num_attributes); node_attributes->emplace(EMBED_MODE, *attr_0); node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1); - - if (compute_capability_enable) { - attr_2->set_name(COMPUTE_CAPABILITY); - attr_2->set_type(onnx::AttributeProto_AttributeType_STRING); - attr_2->set_s(compute_capability); - node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2); - } + node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2); // Create EP context node graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN); @@ -138,14 +129,111 @@ ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer, } /* - * Dump "EP context node" model + * Return the directory where the ep context model locates + */ +std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path) { + if (ep_context_file_path.empty()) { + return std::filesystem::path(); + } + std::filesystem::path ctx_path(ep_context_file_path); + if (std::filesystem::is_directory(ep_context_file_path)) { + return ctx_path; + } else { + return ctx_path.parent_path(); + } +} + +/* + * Get "EP context" model path. + * + * Function logic: + * If ep_context_file_path is provided, + * - If ep_context_file_path is a file, return "ep_context_file_path". + * - If ep_context_file_path is a directory, return "ep_context_file_path/original_model_name_ctx.onnx". + * If ep_context_file_path is not provided, + * - Return "original_model_name_ctx.onnx". + * + * TRT EP has rules about context model path and engine cache path (see tensorrt_execution_provider.cc): + * - If dump_ep_context_model_ and engine_cache_enabled_ is enabled, TRT EP will dump context model and save engine cache + * to the same directory provided by ep_context_file_path_. (i.e. engine_cache_path_ = ep_context_file_path_) + * + * Example 1: + * ep_context_file_path = "/home/user/ep_context_model_directory" + * original_model_path = "model.onnx" + * => return "/home/user/ep_context_model_folder/model_ctx.onnx" + * + * Example 2: + * ep_context_file_path = "my_ctx_model.onnx" + * original_model_path = "model.onnx" + * => return "my_ctx_model.onnx" + * + * Example 3: + * ep_context_file_path = "/home/user2/ep_context_model_directory/my_ctx_model.onnx" + * original_model_path = "model.onnx" + * => return "/home/user2/ep_context_model_directory/my_ctx_model.onnx" + * + */ +std::string GetCtxModelPath(const std::string& ep_context_file_path, + const std::string& original_model_path) { + std::string ctx_model_path; + + if (!ep_context_file_path.empty() && !std::filesystem::is_directory(ep_context_file_path)) { + ctx_model_path = ep_context_file_path; + } else { + std::filesystem::path model_path = original_model_path; + std::filesystem::path model_name_stem = model_path.stem(); // model_name.onnx -> model_name + std::string ctx_model_name = model_name_stem.string() + "_ctx.onnx"; + + if (std::filesystem::is_directory(ep_context_file_path)) { + std::filesystem::path model_directory = ep_context_file_path; + ctx_model_path = model_directory.append(ctx_model_name).string(); + } else { + ctx_model_path = ctx_model_name; + } + } + return ctx_model_path; +} + +/* + * Dump "EP context" model * */ -void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto, - const std::string engine_cache_path) { - std::fstream dump(engine_cache_path + "_wrapper.onnx", std::ios::out | std::ios::trunc | std::ios::binary); +void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, + const std::string& ctx_model_path) { + std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary); model_proto->SerializeToOstream(dump); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path + "_wrapper.onnx"; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path; +} + +bool IsAbsolutePath(std::string& path_string) { +#ifdef _WIN32 + onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); + auto path = std::filesystem::path(ort_path_string.c_str()); + return path.is_absolute(); +#else + if (!path_string.empty() && path_string[0] == '/') { + return true; + } + return false; +#endif +} + +// Like "../file_path" +bool IsRelativePathToParentPath(std::string& path_string) { +#ifdef _WIN32 + onnxruntime::PathString ort_path_string = onnxruntime::ToPathString(path_string); + auto path = std::filesystem::path(ort_path_string.c_str()); + auto relative_path = path.lexically_normal().make_preferred().wstring(); + if (relative_path.find(L"..", 0) != std::string::npos) { + return true; + } + return false; +#else + if (!path_string.empty() && path_string.find("..", 0) != std::string::npos) { + return true; + } + return false; +#endif } Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) { @@ -157,7 +245,7 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph const int64_t embed_mode = attrs.at(EMBED_MODE).i(); if (embed_mode) { - // Get engine from byte stream + // Get engine from byte stream. const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s(); *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), static_cast(context_binary.length()))); @@ -167,19 +255,41 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph "TensorRT EP could not deserialize engine from binary data"); } } else { - // Get engine from cache file - std::ifstream engine_file(engine_cache_path_.string(), std::ios::binary | std::ios::in); + // Get engine from cache file. + std::string cache_path = attrs.at(EP_CACHE_CONTEXT).s(); + + // For security purpose, in the case of running context model, TRT EP won't allow + // engine cache path to be the relative path like "../file_path" or the absolute path. + // It only allows the engine cache to be in the same directory or sub directory of the context model. + if (IsAbsolutePath(cache_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path); + } + if (IsRelativePathToParentPath(cache_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "The file path in ep_cache_context attribute has '..'. For security purpose, it's not allowed to point outside the directory."); + } + + // The engine cache and context model (current model) should be in the same directory + std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_)); + auto engine_cache_path = ctx_model_dir.append(cache_path); + + if (!std::filesystem::exists(engine_cache_path)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP can't find engine cache: " + engine_cache_path.string() + + ". Please make sure engine cache is in the same directory or sub-directory of context model."); + } + + std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in); engine_file.seekg(0, std::ios::end); size_t engine_size = engine_file.tellg(); engine_file.seekg(0, std::ios::beg); std::unique_ptr engine_buf{new char[engine_size]}; engine_file.read((char*)engine_buf.get(), engine_size); *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path_.string(); if (!(*trt_engine_)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not deserialize engine from cache: " + engine_cache_path_.string()); + "TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string()); } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path.string(); } return Status::OK(); } @@ -193,37 +303,26 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe auto node = graph_viewer.GetNode(0); auto& attrs = node->GetAttributes(); - // Check hardware_architecture(compute_capability) if it's present as an attribute + // Show the warning if compute capability is not matched if (attrs.count(COMPUTE_CAPABILITY) > 0) { std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s(); if (model_compute_capability != compute_capability_) { - LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache doesn't match with the GPU's compute capability"; - LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache: " << model_compute_capability; - LOGS_DEFAULT(ERROR) << "The compute capability of the GPU: " << compute_capability_; - return false; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine was compiled for a different compatibility level and might not work or perform suboptimal"; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the engine: " << model_compute_capability; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the GPU: " << compute_capability_; } } // "embed_mode" attr and "ep_cache_context" attr should be present - if (attrs.count(EMBED_MODE) > 0 && attrs.count(EP_CACHE_CONTEXT) > 0) { - // ep_cache_context: payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0 - const int64_t embed_mode = attrs.at(EMBED_MODE).i(); - - // engine cache path - if (embed_mode == 0) { - // First assume engine cache path is relatvie to model path, - // If not, then assume the engine cache path is an absolute path. - engine_cache_path_ = LocateEngineRelativeToPath(attrs.at(EP_CACHE_CONTEXT).s(), GetModelPath(graph_viewer)); - auto default_engine_cache_path_ = engine_cache_path_; - if (!std::filesystem::exists(engine_cache_path_)) { - engine_cache_path_.assign(attrs.at(EP_CACHE_CONTEXT).s()); - if (!std::filesystem::exists(engine_cache_path_)) { - LOGS_DEFAULT(ERROR) << "Can't find " << default_engine_cache_path_.string() << " or " << engine_cache_path_.string() << " TensorRT engine"; - return false; - } - } - } + assert(attrs.count(EMBED_MODE) > 0); + assert(attrs.count(EP_CACHE_CONTEXT) > 0); + + const int64_t embed_mode = attrs.at(EMBED_MODE).i(); + if (embed_mode == 1) { + // engine binary data + LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING; } + return true; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h index ab6ea733adfa1..bf3bf9e3495d7 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h @@ -16,20 +16,27 @@ static const std::string EMBED_MODE = "embed_mode"; static const std::string EP_CACHE_CONTEXT = "ep_cache_context"; static const std::string COMPUTE_CAPABILITY = "hardware_architecture"; static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft"; +static const std::string EPCONTEXT_WARNING = + "It's suggested to set the ORT graph optimization level to 0 and \ + make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\ + for the best model loading time"; bool GraphHasCtxNode(const GraphViewer& graph_viewer); const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer); -std::filesystem::path LocateEngineRelativeToPath(std::string engine_cache_path, const onnxruntime::Path& path); -ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer, - const std::string engine_cache_path, - char* engine_data, - size_t size, - const int64_t embed_mode, - bool compute_capability_enable, - std::string compute_capability, - const logging::Logger* logger); -void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto, - const std::string engine_cache_path); +std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); +ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, + const std::string engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + std::string compute_capability, + const logging::Logger* logger); +std::string GetCtxModelPath(const std::string& ep_context_file_path, + const std::string& original_model_path); +bool IsAbsolutePath(std::string& path_string); +bool IsRelativePathToParentPath(std::string& path_string); +void DumpCtxModel(ONNX_NAMESPACE::ModelProto* model_proto, + const std::string& ctx_model_path); void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto, char* engine_data, size_t size); @@ -38,7 +45,8 @@ class TensorRTCacheModelHandler { public: TensorRTCacheModelHandler(std::unique_ptr* trt_engine, nvinfer1::IRuntime* trt_runtime, - std::string compute_capability) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), compute_capability_(compute_capability) { + std::string ep_context_model_path, + std::string compute_capability) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), ep_context_model_path_(ep_context_model_path), compute_capability_(compute_capability) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler); @@ -49,7 +57,7 @@ class TensorRTCacheModelHandler { private: std::unique_ptr* trt_engine_; nvinfer1::IRuntime* trt_runtime_; - std::filesystem::path engine_cache_path_; + std::string ep_context_model_path_; // If using context model, it implies context model and engine cache is in the same directory std::string compute_capability_; }; // TRTCacheModelHandler } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index aa02d8384afa6..fe6b959b962de 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1079,8 +1079,6 @@ Status BindKernelOutput(Ort::KernelContext& ctx, char const* output_name, size_t output_index, size_t output_type, - std::vector>& scratch_buffers, - OrtAllocator* alloc, cudaStream_t stream) { auto allocator = allocator_map[output_name].get(); auto& shape = allocator->getOutputShape(); @@ -1350,6 +1348,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv timing_cache_enable_ = info.timing_cache_enable; force_timing_cache_match_ = info.force_timing_cache; detailed_build_log_ = info.detailed_build_log; + dump_ep_context_model_ = info.dump_ep_context_model; + ep_context_file_path_ = info.ep_context_file_path; + ep_context_embed_mode_ = info.ep_context_embed_mode; if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { cache_path_ = info.engine_cache_path; cache_prefix_ = info.engine_cache_prefix; @@ -1380,9 +1381,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv profile_max_shapes = info.profile_max_shapes; profile_opt_shapes = info.profile_opt_shapes; cuda_graph_enable_ = info.cuda_graph_enable; - dump_ep_context_model_ = info.dump_ep_context_model; - ep_context_embed_mode_ = info.ep_context_embed_mode; - ep_context_compute_capability_enable_ = info.ep_context_compute_capability_enable; } else { try { const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations); @@ -1461,6 +1459,21 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv force_timing_cache_match_ = (std::stoi(timing_force_match_env) == 0 ? false : true); } + const std::string dump_ep_context_model_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpEpContextModel); + if (!dump_ep_context_model_env.empty()) { + dump_ep_context_model_ = (std::stoi(dump_ep_context_model_env) == 0 ? false : true); + } + + const std::string ep_context_file_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable); + if (!ep_context_file_path_env.empty()) { + ep_context_file_path_ = ep_context_file_path_env; + } + + const std::string ep_context_embed_mode_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextEmbedMode); + if (!ep_context_embed_mode_env.empty()) { + ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env); + } + if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { const std::string engine_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePath); cache_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCachePath); @@ -1538,21 +1551,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true); } - const std::string dump_ep_context_model_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpEpContextModel); - if (!dump_ep_context_model_env.empty()) { - dump_ep_context_model_ = (std::stoi(dump_ep_context_model_env) == 0 ? false : true); - } - - const std::string ep_context_embed_mode_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextEmbedMode); - if (!ep_context_embed_mode_env.empty()) { - ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env); - } - - const std::string ep_context_compute_capability_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable); - if (!ep_context_compute_capability_env.empty()) { - ep_context_compute_capability_enable_ = (std::stoi(ep_context_compute_capability_env) == 0 ? false : true); - } - } catch (const std::invalid_argument& ex) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what(); } catch (const std::out_of_range& ex) { @@ -1580,7 +1578,36 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv dla_core_ = 0; } - if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_ || !cache_prefix_.empty()) { + // If ep_context_file_path_ is provided as a directory, create it if it's not existed + if (dump_ep_context_model_ && !ep_context_file_path_.empty() && std::filesystem::path(ep_context_file_path_).extension().empty() && !std::filesystem::is_directory(ep_context_file_path_)) { + if (!std::filesystem::create_directory(ep_context_file_path_)) { + throw std::runtime_error("Failed to create directory " + ep_context_file_path_); + } + } + + // If dump_ep_context_model_ is enable, TRT EP forces cache_path_ to be the relative path of ep_context_file_path_. + // For example, + // - original cache path = "engine_cache_dir" -> new cache path = "./context_model_dir/engine_cache_dir" + // - original cache path = "" -> new cache path = "./context_model_dir" + // The new cache path will be saved as the "ep_cache_context" node attritue of the EP context node. + // For security reason, it needs to make sure the engine cache is saved inside context model directory. + if (dump_ep_context_model_ && engine_cache_enable_) { + if (IsAbsolutePath(cache_path_)) { + LOGS_DEFAULT(ERROR) << "In the case of dumping context model and for security purpose, the trt_engine_cache_path should be set with a relative path, but it is an absolute path: " << cache_path_; + } + if (IsRelativePathToParentPath(cache_path_)) { + LOGS_DEFAULT(ERROR) << "In the case of dumping context model and for security purpose, The trt_engine_cache_path has '..', it's not allowed to point outside the directory."; + } + + // Engine cache relative path to context model directory. + // It's used when dumping the "ep_cache_context" node attribute. + engine_cache_relative_path_to_context_model_dir = cache_path_; + + // Make cache_path_ to be the relative path of ep_context_file_path_ + cache_path_ = GetPathOrParentPathOfCtxModel(ep_context_file_path_).append(cache_path_).string(); + } + + if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { if (!cache_path_.empty() && !fs::is_directory(cache_path_)) { if (!fs::create_directory(cache_path_)) { throw std::runtime_error("Failed to create directory " + cache_path_); @@ -1692,6 +1719,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_profile_max_shapes: " << profile_max_shapes << ", trt_profile_opt_shapes: " << profile_opt_shapes << ", trt_cuda_graph_enable: " << cuda_graph_enable_ + << ", trt_dump_ep_context_model: " << dump_ep_context_model_ + << ", trt_ep_context_file_path: " << ep_context_file_path_ + << ", trt_ep_context_embed_mode: " << ep_context_embed_mode_ << ", trt_cache_prefix: " << cache_prefix_; } @@ -2309,6 +2339,14 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, // Construct subgraph capability from node list std::vector> result; + // Get ModelPath + const auto& path_string = graph.ModelPath().ToPathString(); +#ifdef _WIN32 + wcstombs_s(nullptr, model_path_, sizeof(model_path_), path_string.c_str(), sizeof(model_path_)); +#else + strcpy(model_path_, path_string.c_str()); +#endif + // If the model consists of only a single "EPContext" contrib op, it means TRT EP can fetch the precompiled engine info from the node and // load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT parser and engine compilation. // So, simply return the ComputeCapability here. @@ -2319,14 +2357,6 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, return result; } - // Get ModelPath - const auto& path_string = graph.ModelPath().ToPathString(); -#ifdef _WIN32 - wcstombs_s(nullptr, model_path_, sizeof(model_path_), path_string.c_str(), sizeof(model_path_)); -#else - strcpy(model_path_, path_string.c_str()); -#endif - // Generate unique kernel name for TRT graph HashValue model_hash = TRTGenerateId(graph); @@ -2831,10 +2861,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView std::unique_ptr trt_engine; std::unique_ptr trt_context; - // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache - // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity - std::string cache_suffix = ""; std::string cache_path = ""; + std::string cache_suffix = ""; // Customize cache prefix if assigned if (!cache_prefix_.empty()) { // Generate cache suffix in case user would like to customize cache prefix @@ -2843,11 +2871,19 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } else { cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); } + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_; const std::string engine_cache_path = cache_path_prefix + ".engine"; const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; const std::string profile_cache_path = cache_path_prefix + ".profile"; + // Generate file name for dumping ep context model + if (dump_ep_context_model_ && ctx_model_path_.empty()) { + ctx_model_path_ = GetCtxModelPath(ep_context_file_path_, model_path_); + } + if (!has_dynamic_shape) { std::string timing_cache_path = ""; bool engine_update = false; @@ -2984,15 +3020,20 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } // dump EP context node model if (dump_ep_context_model_) { - std::unique_ptr model_proto{CreateCtxNodeModel(graph_body_viewer, - engine_cache_path, - reinterpret_cast(serialized_engine->data()), - serialized_engine->size(), - ep_context_embed_mode_, - ep_context_compute_capability_enable_, - compute_capability_, - GetLogger())}; - DumpCtxNodeModel(model_proto.get(), cache_path_prefix); + // "ep_cache_context" node attribute should be a relative path to context model directory + if (ep_cache_context_attr_.empty()) { + auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); + ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); + } + + std::unique_ptr model_proto{CreateCtxModel(graph_body_viewer, + ep_cache_context_attr_, + reinterpret_cast(serialized_engine->data()), + serialized_engine->size(), + ep_context_embed_mode_, + compute_capability_, + GetLogger())}; + DumpCtxModel(model_proto.get(), ctx_model_path_); } } } @@ -3052,16 +3093,20 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // TRT EP will serialize the model at inference time due to engine can be updated and the updated engine should be included in the model. // However, if the embed_mode is 0 (only includes engine path), TRT EP will serialize it here. if (dump_ep_context_model_ && has_dynamic_shape) { - model_proto_.reset(CreateCtxNodeModel(graph_body_viewer, - engine_cache_path, - nullptr, - 0, - ep_context_embed_mode_, - ep_context_compute_capability_enable_, - compute_capability_, - GetLogger())); + // "ep_cache_context" node attribute should be a relative path to context model directory + if (ep_cache_context_attr_.empty()) { + auto cache_file_name = std::filesystem::path(engine_cache_path).filename(); + ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string(); + } + model_proto_.reset(CreateCtxModel(graph_body_viewer, + ep_cache_context_attr_, + nullptr, + 0, + ep_context_embed_mode_, + compute_capability_, + GetLogger())); if (ep_context_embed_mode_ == 0) { - DumpCtxNodeModel(model_proto_.get(), cache_path_prefix); + DumpCtxModel(model_proto_.get(), ctx_model_path_); } } @@ -3382,7 +3427,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // dump ep context model if (dump_ep_context_model_ && ep_context_embed_mode_) { UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast(serialized_engine->data()), serialized_engine->size()); - DumpCtxNodeModel(model_proto_.get(), cache_path_prefix); + DumpCtxModel(model_proto_.get(), ctx_model_path_); } context_update = true; } @@ -3521,7 +3566,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (index_iter != output_indexes.end()) { output_index = index_iter->second; } - auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, scratch_buffers, alloc, stream); + auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } @@ -3575,7 +3620,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con std::unordered_map output_types; // TRT engine output name -> ORT output tensor type // Get engine binary data and deserialize it - auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, runtime_.get(), compute_capability_); + auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, runtime_.get(), model_path_, compute_capability_); auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -3802,7 +3847,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con if (index_iter != output_indexes.end()) { output_index = index_iter->second; } - auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, scratch_buffers, alloc, stream); + auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 401a8da119ac2..ad2d2c55c67e1 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -301,8 +301,11 @@ class TensorrtExecutionProvider : public IExecutionProvider { // For create/dump EP context node model bool dump_ep_context_model_ = false; + std::string ep_context_file_path_; int ep_context_embed_mode_ = 0; - bool ep_context_compute_capability_enable_ = true; + std::string ctx_model_path_; + std::string ep_cache_context_attr_; + std::string engine_cache_relative_path_to_context_model_dir; std::unique_ptr model_proto_ = ONNX_NAMESPACE::ModelProto::Create(); std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index 28f6e1720f615..ba9251c71bced 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -47,9 +47,9 @@ constexpr const char* kProfilesMinShapes = "trt_profile_min_shapes"; constexpr const char* kProfilesMaxShapes = "trt_profile_max_shapes"; constexpr const char* kProfilesOptShapes = "trt_profile_opt_shapes"; constexpr const char* kCudaGraphEnable = "trt_cuda_graph_enable"; -constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model"; constexpr const char* kEpContextEmbedMode = "trt_ep_context_embed_mode"; -constexpr const char* kEpContextComputeCapabilityEnable = "trt_ep_context_compute_capability_enable"; +constexpr const char* kEpContextFilePath = "trt_ep_context_file_path"; +constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model"; } // namespace provider_option_names } // namespace tensorrt @@ -103,8 +103,8 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes) .AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kDumpEpContextModel, info.dump_ep_context_model) + .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextFilePath, info.ep_context_file_path) .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode) - .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextComputeCapabilityEnable, info.ep_context_compute_capability_enable) .Parse(options)); // add new provider option here. return info; @@ -148,8 +148,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.dump_ep_context_model)}, + {tensorrt::provider_option_names::kEpContextFilePath, MakeStringWithClassicLocale(info.ep_context_file_path)}, {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.ep_context_embed_mode)}, - {tensorrt::provider_option_names::kEpContextComputeCapabilityEnable, MakeStringWithClassicLocale(info.ep_context_compute_capability_enable)}, }; return options; } @@ -166,6 +166,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor const std::string kProfilesMinShapes_ = empty_if_null(info.trt_profile_min_shapes); const std::string kProfilesMaxShapes_ = empty_if_null(info.trt_profile_max_shapes); const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes); + const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path); const ProviderOptions options{ {tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, @@ -202,9 +203,9 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kProfilesMaxShapes, kProfilesMaxShapes_}, {tensorrt::provider_option_names::kProfilesOptShapes, kProfilesOptShapes_}, {tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)}, + {tensorrt::provider_option_names::kEpContextFilePath, kEpContextFilePath_}, {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.trt_dump_ep_context_model)}, {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.trt_ep_context_embed_mode)}, - {tensorrt::provider_option_names::kEpContextComputeCapabilityEnable, MakeStringWithClassicLocale(info.trt_ep_context_compute_capability_enable)}, }; return options; } @@ -299,6 +300,6 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_cuda_graph_enable = internal_options.cuda_graph_enable; trt_provider_options_v2.trt_dump_ep_context_model = internal_options.dump_ep_context_model; trt_provider_options_v2.trt_ep_context_embed_mode = internal_options.ep_context_embed_mode; - trt_provider_options_v2.trt_ep_context_compute_capability_enable = internal_options.ep_context_compute_capability_enable; + trt_provider_options_v2.trt_ep_context_file_path = copy_string_if_needed(internal_options.ep_context_file_path); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index a133ef45affe8..80424b8d6d196 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -52,8 +52,8 @@ struct TensorrtExecutionProviderInfo { std::string profile_opt_shapes{""}; bool cuda_graph_enable{false}; bool dump_ep_context_model{false}; + std::string ep_context_file_path{""}; int ep_context_embed_mode{0}; - bool ep_context_compute_capability_enable{1}; std::string engine_cache_prefix{""}; static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 62f124afbd1e5..568da57a50956 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -61,13 +61,6 @@ std::unique_ptr TensorrtProviderFactory::CreateProvider() { return std::make_unique(info_); } -std::shared_ptr TensorrtProviderFactoryCreator::Create(int device_id) { - TensorrtExecutionProviderInfo info; - info.device_id = device_id; - info.has_trt_options = false; - return std::make_shared(info); -} - struct Tensorrt_Provider : Provider { void* GetInfo() override { return &g_info; } std::shared_ptr CreateExecutionProviderFactory(int device_id) override { @@ -117,8 +110,8 @@ struct Tensorrt_Provider : Provider { info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes; info.cuda_graph_enable = options.trt_cuda_graph_enable != 0; info.dump_ep_context_model = options.trt_dump_ep_context_model != 0; + info.ep_context_file_path = options.trt_ep_context_file_path == nullptr ? "" : options.trt_ep_context_file_path; info.ep_context_embed_mode = options.trt_ep_context_embed_mode; - info.ep_context_compute_capability_enable = options.trt_ep_context_compute_capability_enable != 0; info.engine_cache_prefix = options.trt_engine_cache_prefix == nullptr ? "" : options.trt_engine_cache_prefix; return std::make_shared(info); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 45d8006e6b49e..3269c9f0f4e4b 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -89,6 +89,10 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef; #include "core/providers/cann/cann_provider_options.h" #include "core/providers/dnnl/dnnl_provider_options.h" +#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) +#include "core/session/onnxruntime_session_options_config_keys.h" +#endif + // The filename extension for a shared library is different per platform #ifdef _WIN32 #define LIBRARY_PREFIX @@ -1372,10 +1376,6 @@ std::shared_ptr DnnlProviderFactoryCreator::Create(in return s_library_dnnl.Get().CreateExecutionProviderFactory(use_arena); } -std::shared_ptr TensorrtProviderFactoryCreator::Create(int device_id) { - return s_library_tensorrt.Get().CreateExecutionProviderFactory(device_id); -} - std::shared_ptr MIGraphXProviderFactoryCreator::Create(int device_id) { return s_library_migraphx.Get().CreateExecutionProviderFactory(device_id); } @@ -1419,11 +1419,44 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti trt_options_converted.trt_profile_max_shapes = ""; trt_options_converted.trt_profile_opt_shapes = ""; trt_options_converted.trt_cuda_graph_enable = 0; + trt_options_converted.trt_dump_ep_context_model = 0; + trt_options_converted.trt_ep_context_file_path = ""; + trt_options_converted.trt_ep_context_embed_mode = 0; trt_options_converted.trt_engine_cache_prefix = ""; return trt_options_converted; } +#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) +// Apply configs from session options to TensorRT provider options V2 that are needed for TensorRT EP. +// For example, EP context configs. +void UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(OrtSessionOptions* session_options, OrtTensorRTProviderOptionsV2* tensorrt_options) { + if (session_options) { + auto context_cache_enabled = (session_options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; + tensorrt_options->trt_dump_ep_context_model = context_cache_enabled; + LOGS_DEFAULT(VERBOSE) << "Context cache enable: " << context_cache_enabled; + + auto context_cache_path = (session_options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + tensorrt_options->trt_ep_context_file_path = context_cache_path.c_str(); + LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << tensorrt_options->trt_ep_context_file_path; + + auto embed_mode = (session_options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1"); + if ("1" == embed_mode) { + tensorrt_options->trt_ep_context_embed_mode = 1; + } else if ("0" == embed_mode) { + tensorrt_options->trt_ep_context_embed_mode = 0; + } else { + LOGS_DEFAULT(VERBOSE) << "Invalid ep.context_embed_mode: " << embed_mode << " only 0 or 1 allowed. Set to 1."; + } + LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << tensorrt_options->trt_ep_context_embed_mode; + } +} +#endif + +std::shared_ptr TensorrtProviderFactoryCreator::Create(int device_id) { + return s_library_tensorrt.Get().CreateExecutionProviderFactory(device_id); +} + std::shared_ptr TensorrtProviderFactoryCreator::Create(const OrtTensorRTProviderOptions* provider_options) { OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(provider_options); return s_library_tensorrt.Get().CreateExecutionProviderFactory(&trt_options_converted); @@ -1708,7 +1741,24 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtS ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) { API_IMPL_BEGIN - auto factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); + + std::shared_ptr factory; + +#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) + auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; + // If EP context configs are provided in session options, we need to propagate them to provider options + if (ep_context_cache_enabled_from_sess_options) { + OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); + + onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &trt_options_converted); + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted); + } else { + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); + } +#else + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); +#endif + if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); } @@ -1845,7 +1895,31 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, _In_ Or ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options) { API_IMPL_BEGIN - auto factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); + + std::shared_ptr factory; + +#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) + auto ep_context_cache_enabled_from_provider_options = tensorrt_options->trt_dump_ep_context_model != 0; + auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; + + // If EP context configs are provided in session options, we need to propagate them to provider options. However, + // if provider options already have the EP context configs provided, the configs in session options will be ignored + // since provider options has higher priority than session options. + if (!ep_context_cache_enabled_from_provider_options && ep_context_cache_enabled_from_sess_options) { + // We need to create a new provider options V2 object and copy from provider_options, due to the "const" object pointed by provider_options can't be modified. + // Note: No need to worry about tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will + // create a factory object that copies any provider options from tensorrt_options including "const char*" provider options. + OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options + + onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &new_tensorrt_options); + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options); + } else { + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); + } +#else + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); +#endif + if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_TensorRT: Failed to load shared library"); } @@ -1991,6 +2065,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor delete[] ptr->trt_profile_min_shapes; delete[] ptr->trt_profile_max_shapes; delete[] ptr->trt_profile_opt_shapes; + delete[] ptr->trt_ep_context_file_path; } std::unique_ptr p(ptr); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index d2cd6140b838e..f7ed5520727db 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -475,7 +475,7 @@ std::unique_ptr CreateExecutionProviderInstance( // So we need these std::string variables defined here as they will be kept alive for the lifetime of TRT EP and we can still access them from OrtTensorRTProviderOptionsV2 instance. // (The reason is string copy is involved, for example params.trt_engine_cache_path = cache_path.c_str() and those std::string variable is referenced by OrtTensorRTProviderOptionsV2 instance // and TRT EP instance, so it won't be released.) - std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile; + std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources, trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { OrtTensorRTProviderOptionsV2 params; @@ -728,20 +728,19 @@ std::unique_ptr CreateExecutionProviderInstance( } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dump_ep_context_model' should be 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "trt_ep_context_file_path") { + if (!option.second.empty()) { + ep_context_file_path = option.second; + params.trt_ep_context_file_path = ep_context_file_path.c_str(); + } else { + ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_file_path' should be a string.\n"); + } } else if (option.first == "trt_ep_context_embed_mode") { if (!option.second.empty()) { params.trt_ep_context_embed_mode = std::stoi(option.second); } else { ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_embed_mode' should be a positive integer number i.e. '1'.\n"); } - } else if (option.first == "trt_ep_context_compute_capability_enable") { - if (option.second == "True" || option.second == "true") { - params.trt_ep_context_compute_capability_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.trt_ep_context_compute_capability_enable = false; - } else { - ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_compute_capability_enable' should be 'True' or 'False'. Default value is 'False'.\n"); - } } else { ORT_THROW("Invalid TensorRT EP option: ", option.first); } diff --git a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py index 717a0816247e7..b94c2cb76a635 100644 --- a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py +++ b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py @@ -15,6 +15,7 @@ def __init__(self, args): engine_cache_path = args.trt_engine_cache_path self.model_name = args.model_name self.dynamic_dim_count = 0 + self.plugins = args.plugins # Get serialized engine from engine cache with open(engine_cache_path, "rb") as file: @@ -25,8 +26,16 @@ def __init__(self, args): else: ep_cache_context_content = engine_cache_path - # Deserialize an TRT engine logger = trt.Logger(trt.Logger.WARNING) + + # Enable TRT plugins + trt.init_libnvinfer_plugins(logger, "") + if len(self.plugins): + import ctypes + + ctypes.CDLL(self.plugins) + + # Deserialize an TRT engine runtime = trt.Runtime(logger) engine = runtime.deserialize_cuda_engine(engine_buffer) num_bindings = engine.num_bindings @@ -165,6 +174,14 @@ def main(): default="trt_engine_wrapper.onnx", type=str, ) + parser.add_argument( + "--plugins", + help="List of plugin paths to load", + required=False, + default=[], + nargs="+", + type=str, + ) args = parser.parse_args() ctor = TensorRTEngineWrapperCreator(args) ctor.create_model() diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 508739ae1d235..4d2538c947dcc 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -122,9 +122,15 @@ void CreateBaseModel(std::string model_name, status = onnxruntime::Model::Save(model, model_name); } -bool HasCacheFileWithPrefix(const std::string& prefix) { - const std::filesystem::path current_dir = std::filesystem::current_path(); - for (const auto& entry : std::filesystem::directory_iterator(current_dir)) { +bool HasCacheFileWithPrefix(const std::string& prefix, std::string file_dir = "") { + std::filesystem::path target_dir; + if (file_dir.empty()) { + target_dir = std::filesystem::current_path(); + } else { + target_dir = std::filesystem::path(file_dir); + } + + for (const auto& entry : std::filesystem::directory_iterator(target_dir)) { if (entry.is_regular_file()) { std::string filename = entry.path().filename().string(); if (filename.rfind(prefix, 0) == 0) { @@ -191,6 +197,8 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string OrtTensorRTProviderOptionsV2 params; params.trt_engine_cache_enable = 1; params.trt_engine_cache_prefix = "TRTEP_Cache_Test"; + params.trt_dump_ep_context_model = 1; + params.trt_ep_context_file_path = "EP_Context_model.onnx"; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); auto status = session_object.Load(model_name); @@ -209,6 +217,9 @@ void RunWithOneSessionSingleThreadInference(std::string model_name, std::string // Verify on cache with customized prefix ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_engine_cache_prefix)); + + // Verify EP context model with user provided name + ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_ep_context_file_path)); } void RunWithOneSessionMultiThreadsInference(std::string model_name, std::string sess_log_id, bool has_non_zero_node = false) { @@ -348,6 +359,192 @@ TEST(TensorrtExecutionProviderTest, TRTModelIdGeneratorUsingModelHashing) { ASSERT_EQ(model_hash, model_hash3) << "model 1&3 are same models and they have same hash, no matter where they are loaded"; } +TEST(TensorrtExecutionProviderTest, EPContextNode) { + std::string model_name = "EPContextNode_test.onnx"; + std::string graph_name = "EPContextNode_test"; + std::string sess_log_id = "EPContextNode_test"; + std::vector dims = {1, 3, 2}; + CreateBaseModel(model_name, graph_name, dims); + + SessionOptions so; + so.session_logid = sess_log_id; + RunOptions run_options; + run_options.run_tag = so.session_logid; + InferenceSession session_object{so, GetEnvironment()}; + auto cuda_provider = DefaultCudaExecutionProvider(); + auto cpu_allocator = cuda_provider->CreatePreferredAllocators()[1]; + std::vector dims_mul_x = {1, 3, 2}; + std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + OrtValue ml_value_x; + CreateMLValue(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_x); + OrtValue ml_value_y; + CreateMLValue(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_y); + OrtValue ml_value_z; + CreateMLValue(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_z); + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + feeds.insert(std::make_pair("Y", ml_value_y)); + feeds.insert(std::make_pair("Z", ml_value_z)); + + // prepare outputs + std::vector output_names; + output_names.push_back("M"); + + // prepare expected inputs and outputs + std::vector expected_dims_mul_m = {1, 3, 2}; + std::vector expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}; + + /* + * Test case 1: Dump context model + * + * provider options=> + * trt_ep_context_file_path = "EP_Context_model.onnx" + * + * expected result => + * context model "EP_Context_model.onnx" should be created in current directory + * + */ + OrtTensorRTProviderOptionsV2 params; + params.trt_engine_cache_enable = 1; + params.trt_dump_ep_context_model = 1; + params.trt_ep_context_file_path = "EP_Context_model.onnx"; + std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); + EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + auto status = session_object.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object.Initialize(); + ASSERT_TRUE(status.IsOK()); + ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_ep_context_file_path)); + + /* + * Test case 2: Dump context model + * + * provider options=> + * trt_engine_cache_prefix = "TRT_engine_cache" + * trt_ep_context_file_path = "context_model_folder" + * trt_engine_cache_path = "engine_cache_folder" + * + * expected result => + * engine cache "./context_model_folder/engine_cache_folder/TRT_engine_cache...engine" should be created + * context model "./context_model_folder/EPContextNode_test_ctx.onnx" should be created + */ + InferenceSession session_object2{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params2; + params2.trt_engine_cache_enable = 1; + params2.trt_dump_ep_context_model = 1; + params2.trt_engine_cache_prefix = "TRT_engine_cache"; + params2.trt_engine_cache_path = "engine_cache_folder"; // due to dump_ep_context_model = 1, the new cache path is ./context_model_folder/engine_cache_folder + params2.trt_ep_context_file_path = "context_model_folder"; + execution_provider = TensorrtExecutionProviderWithOptions(¶ms2); + EXPECT_TRUE(session_object2.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object2.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object2.Initialize(); + ASSERT_TRUE(status.IsOK()); + auto new_engine_cache_path = std::filesystem::path(params2.trt_ep_context_file_path).append(params2.trt_engine_cache_path).string(); + // Test engine cache path: + // "./context_model_folder/engine_cache_folder/TRT_engine_cache...engine" should be created + ASSERT_TRUE(HasCacheFileWithPrefix(params2.trt_engine_cache_prefix, new_engine_cache_path)); + // Test context model path: + // "./context_model_folder/EPContextNode_test_ctx.onnx" should be created + ASSERT_TRUE(HasCacheFileWithPrefix("EPContextNode_test_ctx.onnx", params2.trt_ep_context_file_path)); + + /* + * Test case 3: Run the dumped context model + * + * context model path = "./EP_Context_model.onnx" (created from case 1) + * + * expected result=> + * engine cache is also in the same current dirctory as "./xxxxx.engine" + * and the "ep_cache_context" attribute node of the context model should point to that. + * + */ + InferenceSession session_object3{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params3; + model_name = params.trt_ep_context_file_path; + params3.trt_engine_cache_enable = 1; + execution_provider = TensorrtExecutionProviderWithOptions(¶ms3); + EXPECT_TRUE(session_object3.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object3.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object3.Initialize(); + ASSERT_TRUE(status.IsOK()); + // run inference + // TRT engine will be created and cached + // TRT profile will be created and cached only for dynamic input shape + // Data in profile, + // X: 1, 3, 3, 2, 2, 2 + // Y: 1, 3, 3, 2, 2, 2 + // Z: 1, 3, 3, 2, 2, 2 + RunSession(session_object3, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m); + + /* + * Test case 4: Run the dumped context model + * + * context model path = "./context_model_folder/EPContextNode_test_ctx.onnx" (created from case 2) + * + * expected result=> + * engine cache path is "./context_model_folder/engine_cache_folder/xxxxx.engine" + * and the "ep_cache_context" attribute node of the context model should point to "engine_cache_folder/xxxxx.engine". + * + */ + InferenceSession session_object4{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params4; + model_name = "./context_model_folder/EPContextNode_test_ctx.onnx"; + execution_provider = TensorrtExecutionProviderWithOptions(¶ms4); + EXPECT_TRUE(session_object4.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object4.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object4.Initialize(); + ASSERT_TRUE(status.IsOK()); + // run inference + // TRT engine will be created and cached + // TRT profile will be created and cached only for dynamic input shape + // Data in profile, + // X: 1, 3, 3, 2, 2, 2 + // Y: 1, 3, 3, 2, 2, 2 + // Z: 1, 3, 3, 2, 2, 2 + RunSession(session_object4, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m); + + /* + * Test case 5: Dump context model with embed_model = 1 + */ + InferenceSession session_object5{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params5; + params5.trt_dump_ep_context_model = 1; + params5.trt_ep_context_embed_mode = 1; + params5.trt_ep_context_file_path = "EP_Context_model_2.onnx"; + model_name = "EPContextNode_test.onnx"; + execution_provider = TensorrtExecutionProviderWithOptions(¶ms5); + EXPECT_TRUE(session_object5.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object5.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object5.Initialize(); + ASSERT_TRUE(status.IsOK()); + + /* + * Test case 6: Run context model with embed_model = 1 (created from case 5) + */ + InferenceSession session_object6{so, GetEnvironment()}; + OrtTensorRTProviderOptionsV2 params6; + params6.trt_ep_context_embed_mode = 1; + model_name = params5.trt_ep_context_file_path; + execution_provider = TensorrtExecutionProviderWithOptions(¶ms6); + EXPECT_TRUE(session_object6.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); + status = session_object6.Load(model_name); + ASSERT_TRUE(status.IsOK()); + status = session_object6.Initialize(); + ASSERT_TRUE(status.IsOK()); + // run inference + // TRT engine will be created and cached + // TRT profile will be created and cached only for dynamic input shape + // Data in profile, + // X: 1, 3, 3, 2, 2, 2 + // Y: 1, 3, 3, 2, 2, 2 + // Z: 1, 3, 3, 2, 2, 2 + RunSession(session_object6, run_options, feeds, output_names, expected_dims_mul_m, expected_values_mul_m); +} + TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) { std::string model_name = "testdata/trt_plugin_custom_op_test.onnx"; SessionOptions so; @@ -448,6 +645,8 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { params.trt_engine_cache_enable = 1; params.trt_engine_cache_prefix = "TRTEP_Cache_Test"; + params.trt_dump_ep_context_model = 1; + params.trt_ep_context_file_path = "EP_Context_model.onnx"; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); auto status = session_object.Load(model_name); @@ -576,6 +775,9 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) { // Verify on cache with customized prefix ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_engine_cache_prefix)); + // Verify EP context model with user provided name + ASSERT_TRUE(HasCacheFileWithPrefix(params.trt_ep_context_file_path)); + if (input_type.compare("static") == 0) { // Can't run inference since input shape changes but the engine is built with static input ASSERT_FALSE(status.IsOK()); From 21034a2c37d707ee913dcca1b00d0b9e7651f980 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Mon, 22 Jan 2024 18:17:11 +0000 Subject: [PATCH 03/45] phi2 contrib ops changes (#19112) ### Description 1. support causal mask in MHA cpu 2. support custom rotary_dim in rotary_emb 3. add bf16 for rotary_emb 4. fix a bug in attention rotary ### Motivation and Context --- docs/ContribOperators.md | 12 +- docs/OperatorKernels.md | 2 +- onnxruntime/contrib_ops/cpu/bert/attention.cc | 6 + .../cpu/bert/multihead_attention.cc | 4 +- .../cpu/bert/multihead_attention.h | 1 + .../cpu/bert/multihead_attention_helper.h | 8 +- .../contrib_ops/cpu/bert/rotary_embedding.cc | 47 ++++--- .../contrib_ops/cpu/bert/rotary_embedding.h | 2 + .../cpu/bert/rotary_embedding_helper.h | 55 ++++---- .../cuda/bert/multihead_attention.cc | 3 + .../cuda/bert/multihead_attention.h | 1 + .../contrib_ops/cuda/bert/rotary_embedding.cc | 6 + .../contrib_ops/cuda/bert/rotary_embedding.h | 2 + .../cuda/bert/rotary_embedding_impl.cu | 64 ++++++--- .../cuda/bert/rotary_embedding_impl.h | 1 + .../contrib_ops/cuda/cuda_contrib_kernels.cc | 2 + .../core/graph/contrib_ops/bert_defs.cc | 18 ++- .../contrib_ops/rotary_embedding_op_test.cc | 127 ++++++++++++++++-- 18 files changed, 280 insertions(+), 81 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 45c0e6f822ce9..22e82443167f6 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3031,6 +3031,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+
unidirectional : int
+
Whether every token can only attend to previous tokens. Default value is 0.
#### Inputs (1 - 8) @@ -5021,6 +5023,10 @@ This version of the operator has been available since version 1 of the 'com.micr
interleaved : int
Rotate using interleaved pattern. Default value is 0 (False).
+
num_heads : int
+
Number of attention heads. Default value is 0. Must use with rotary_embedding_dim
+
rotary_embedding_dim : int
+
Rotary embedding dimension. Default value is 0.
scale : float
Custom scale will be used if specified. Default value is 1.0
@@ -5033,9 +5039,9 @@ This version of the operator has been available since version 1 of the 'com.micr
position_ids : M
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
cos_cache : T
-
2D tensor with shape (max_sequence_length, head_size / 2).
+
2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
sin_cache : T
-
2D tensor with shape (max_sequence_length, head_size / 2).
+
2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
#### Outputs @@ -5048,7 +5054,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float), tensor(float16)
+
T : tensor(float), tensor(float16), tensor(bfloat16)
Constrain input and output types to float tensors.
M : tensor(int64)
Constrain input and output types to integer tensors
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 394bd7ad2abae..9ecc58bee0725 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -868,7 +868,7 @@ Do not modify directly.* |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 4711ccf487cc8..768676259aa14 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -211,6 +211,12 @@ Status Attention::Compute(OpKernelContext* context) const { relative_position_bias, ¶meters)); + if (parameters.do_rotary) { + ORT_NOT_IMPLEMENTED( + "Rotary embedding is not supported in Attention CPU kernel. \ + Please fuse the model with MHA + RotaryEmbedding."); + } + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int input_hidden_size = parameters.input_hidden_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 694c40bf3eda6..eb25d0fd7cc1e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -40,6 +40,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i num_heads_ = static_cast(num_heads); mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; } // Reshape Q/K/V from BxSxD to BxSxNxH @@ -283,8 +284,9 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { nullptr, ¶meters, num_heads_, - scale, mask_filter_value_, + scale, + is_unidirectional_, past_present_share_buffer, false)); diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h index 4c86b777e9842..fb7da78a5c0a5 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h @@ -18,6 +18,7 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase { protected: int num_heads_; // number of attention heads float mask_filter_value_; + bool is_unidirectional_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 00e82c9844b3d..c91f5b601b4e9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -25,6 +25,7 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, float scale, + bool is_unidirectional, bool past_present_share_buffer, bool dmmha_packing) { // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None @@ -315,7 +316,7 @@ Status CheckInputs(const T* query, output_parameters->head_size = hidden_size / num_heads; output_parameters->v_head_size = v_hidden_size / num_heads; output_parameters->num_heads = num_heads; - output_parameters->is_unidirectional = false; + output_parameters->is_unidirectional = is_unidirectional; output_parameters->past_present_share_buffer = past_present_share_buffer; output_parameters->mask_filter_value = mask_filter_value; output_parameters->mask_type = mask_type; @@ -342,6 +343,7 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, float scale, + bool is_unidirectional, bool past_present_share_buffer, bool dmmha_packing, int max_threads_per_block) { @@ -350,8 +352,8 @@ Status CheckInputs(const T* query, } return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, - past_seq_len, parameters, num_heads, mask_filter_value, scale, past_present_share_buffer, - dmmha_packing); + past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional, + past_present_share_buffer, dmmha_packing); } } // namespace multihead_attention_helper diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 47f462d75fcc4..aa8b5b5f608fa 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -27,7 +27,13 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { scale = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); + + if (rotary_embedding_dim > 0) { + ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); + } } template @@ -42,6 +48,8 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { position_ids, cos_cache, sin_cache, + num_heads, + rotary_embedding_dim, ¶meters)); Tensor* output = context->Output(0, input->Shape()); @@ -59,61 +67,66 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; + const int n_heads = parameters.num_heads; const int head_size = parameters.head_size; const int position_ids_format = parameters.position_ids_format; - const int half_head_size = head_size / 2; + const int rotary_emb_dim = parameters.rotary_embedding_dim; + const int half_rotary_emb_dim = rotary_emb_dim / 2; + // Default input tensor shape is [batch, seq_len, hidden_size] int head_stride = head_size; - int seq_stride = num_heads * head_stride; + int seq_stride = n_heads * head_stride; int batch_stride = sequence_length * seq_stride; if (parameters.transposed) { - // Transposed input tensor shape is [batch, num_heads, seq_len, head_size] + // Transposed input tensor shape is [batch, n_heads, seq_len, head_size] seq_stride = head_size; head_stride = sequence_length * seq_stride; - batch_stride = num_heads * head_stride; + batch_stride = n_heads * head_stride; } AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); auto* tp = context->GetOperatorThreadPool(); - const int loop_len = batch_size * sequence_length * num_heads; - const double cost = static_cast(head_size); + const int loop_len = batch_size * sequence_length * n_heads; + const double cost = static_cast(rotary_emb_dim); ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { - const int b = static_cast((ptr / num_heads) / sequence_length); - const int s = static_cast((ptr / num_heads) % sequence_length); - const int n = static_cast(ptr % num_heads); + const int b = static_cast((ptr / n_heads) / sequence_length); + const int s = static_cast((ptr / n_heads) % sequence_length); + const int n = static_cast(ptr % n_heads); const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; const T* input_data = input_src + block_offset; T* output_data = output_dest + block_offset; - // Cache is (M, H/2) + // Cache is (M, H/2) or (M, rotary_embedding_dim/2) const int position_id = (position_ids_format == 0) ? static_cast(pos_ids_data[0]) + s : static_cast(pos_ids_data[b * sequence_length + s]); - const int cache_offset = position_id * half_head_size; + const int cache_offset = position_id * half_rotary_emb_dim; const T* cos_data = cos_cache_data + cache_offset; const T* sin_data = sin_cache_data + cache_offset; int cache_idx = 0; T sign = 0; int j = 0; - for (int i = 0; i < head_size; i++) { + for (int i = 0; i < rotary_emb_dim; i++) { if (interleaved) { - cache_idx = (i / 2) % half_head_size; + cache_idx = (i / 2) % half_rotary_emb_dim; sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign } else { - cache_idx = i % half_head_size; - sign = (i < half_head_size) ? static_cast(-1) : static_cast(1); - j = (i + half_head_size) % head_size; + cache_idx = i % half_rotary_emb_dim; + sign = (i < half_rotary_emb_dim) ? static_cast(-1) : static_cast(1); + j = (i + half_rotary_emb_dim) % rotary_emb_dim; } output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; } + for (int i = rotary_emb_dim; i < head_size; i++) { + output_data[i] = input_data[i]; + } } }); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h index be834a66cdc69..4e32424a22b6c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h @@ -16,6 +16,8 @@ class RotaryEmbedding final : public OpKernel { protected: float scale; + int num_heads; + int rotary_embedding_dim; bool interleaved; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h index 7b2e8289f7b06..dcbb36d1c4a3c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -11,14 +11,15 @@ namespace rotary_embedding_helper { // Parameters deduced from node attributes and inputs/outputs. struct RotaryParameters { - int batch_size; // Batch size used by input - int sequence_length; // Sequence length used by input - int hidden_size; // Hidden size used by input - int head_size; // Head size used by cos/sin cache * 2 - int num_heads; // num_heads = hidden_size / head_size - int max_sequence_length; // Sequence length used by cos/sin cache - int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) - bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size + int rotary_embedding_dim; // Rotary embedding dimension. + int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) }; template @@ -26,11 +27,13 @@ Status CheckInputs(const T* input, const T* position_ids, const T* cos_cache, const T* sin_cache, + int num_heads, + int rotary_embedding_dim, void* parameters) { // input : (batch_size, sequence_length, hidden_size) // position ids : (1) or (batch_size, sequence_length) - // cos cache : (max_sequence_length, head_size / 2) - // sin cache : (max_sequence_length, head_size / 2) + // cos cache : (max_sequence_length, rotary_embedding_dim / 2) + // sin cache : (max_sequence_length, rotary_embedding_dim / 2) // Check input const auto& input_dims = input->Shape().GetDims(); @@ -60,6 +63,12 @@ Status CheckInputs(const T* input, "the same shape"); } + // Check num_heads and rotary_embedding_dim + if (rotary_embedding_dim > 0 && num_heads == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ", + "specified"); + } + // Get attributes from inputs int batch_size = static_cast(input_dims[0]); int sequence_length = static_cast(input_dims[1]); @@ -73,8 +82,13 @@ Status CheckInputs(const T* input, transposed = true; } int max_sequence_length = static_cast(cos_cache_dims[0]); - int head_size = static_cast(cos_cache_dims[1]) * 2; - int num_heads = hidden_size / head_size; + int head_size = rotary_embedding_dim == 0 ? static_cast(cos_cache_dims[1]) * 2 + : static_cast(hidden_size / num_heads); + if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ", + "head_size"); + } + int position_ids_format = -1; // Check position_ids input shapes @@ -91,23 +105,15 @@ Status CheckInputs(const T* input, } else { position_ids_format = 0; } + // Check cos_cache input shapes if (max_sequence_length != static_cast(cos_cache_dims[0])) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", "max_sequence_length, got ", cos_cache_dims[0]); } - if ((head_size / 2) != static_cast(cos_cache_dims[1])) { + if ((head_size / 2) != static_cast(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast(cos_cache_dims[1]))) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", - "head_size / 2, got ", cos_cache_dims[1]); - } - // Check sin_cache input shapes - if (max_sequence_length != static_cast(sin_cache_dims[0])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ", - "max_sequence_length, got ", sin_cache_dims[0]); - } - if ((head_size / 2) != static_cast(sin_cache_dims[1])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ", - "head_size / 2, got ", sin_cache_dims[1]); + "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]); } // Set rotary parameters @@ -117,10 +123,11 @@ Status CheckInputs(const T* input, output_parameters->sequence_length = sequence_length; output_parameters->hidden_size = hidden_size; output_parameters->head_size = head_size; - output_parameters->num_heads = num_heads; + output_parameters->num_heads = num_heads > 0 ? num_heads : static_cast(hidden_size / head_size); output_parameters->max_sequence_length = max_sequence_length; output_parameters->position_ids_format = position_ids_format; output_parameters->transposed = transposed; + output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size; } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index ebd66d8c6528e..f978f50c6851f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -44,6 +44,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead."); disable_fused_self_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); @@ -105,6 +107,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { num_heads_, mask_filter_value_, scale_, + is_unidirectional_, false, // past_present_share_buffer false, // dmmha_packing device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index c162f7133cc1c..86a32c92ce003 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -25,6 +25,7 @@ class MultiHeadAttention final : public CudaKernel { int num_heads_; // number of attention heads float mask_filter_value_; float scale_; + bool is_unidirectional_; bool disable_fused_self_attention_; bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index 2d12e975d88d7..9de7ba3885c3c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -29,10 +29,13 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { scale = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); } @@ -48,6 +51,8 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { position_ids, cos_cache, sin_cache, + num_heads, + rotary_embedding_dim, ¶meters)); Tensor* output = context->Output(0, input->Shape()); @@ -71,6 +76,7 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length, parameters.num_heads, parameters.head_size, + parameters.rotary_embedding_dim, parameters.max_sequence_length, parameters.position_ids_format, interleaved, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h index 6dab2ad56749e..d52f61d670444 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h @@ -19,6 +19,8 @@ class RotaryEmbedding final : public CudaKernel { protected: float scale; + int num_heads; + int rotary_embedding_dim; bool interleaved; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index e1b83bd8caf54..c6637041f05bd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -26,6 +26,7 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int position_ids_format, const bool interleaved, const int batch_stride, @@ -33,24 +34,33 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int head_stride) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently - + const int b = blockIdx.z; const int s = blockIdx.y; const int n = blockIdx.x; const int i = threadIdx.x; + if (i >= head_size) { + return; + } + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; const T* input_data = input + block_offset; T* output_data = output + block_offset; + if (i >= rotary_embedding_dim) { + output_data[i] = input_data[i]; + return; + } + // Cache is (M, H/2) - const int half_head_size = head_size / 2; + const int half_rotary_embedding_dim = rotary_embedding_dim / 2; const int position_id = (position_ids_format == 0) ? \ static_cast(position_ids[0]) + s \ : static_cast(position_ids[b * sequence_length + s]); - const int cache_offset = position_id * half_head_size; + const int cache_offset = position_id * half_rotary_embedding_dim; const T* cos_data = cos_cache + cache_offset; const T* sin_data = sin_cache + cache_offset; @@ -58,13 +68,13 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH T sign = 0; int j = 0; if (interleaved) { - cache_idx = (i / 2) % half_head_size; + cache_idx = (i / 2) % half_rotary_embedding_dim; sign = (i % 2 == 0) ? -1 : 1; j = (i % 2 == 0) ? i+1 : i-1; // i - sign } else { - cache_idx = i % half_head_size; - sign = (i < half_head_size) ? -1 : 1; - j = (i + half_head_size) % head_size; + cache_idx = i % half_rotary_embedding_dim; + sign = (i < half_rotary_embedding_dim) ? -1 : 1; + j = (i + half_rotary_embedding_dim) % rotary_embedding_dim; } output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; } @@ -82,20 +92,23 @@ Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, const int max_threads_per_block, const bool transposed) { - - constexpr int smem_size = 0; - const dim3 grid(num_heads, sequence_length, batch_size); - const dim3 block(head_size, 1, 1); - // Note: Current implementation assumes head_size <= max_threads_per_block // because head_size is currently large for LLaMA-2. For smaller head_size // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. + ORT_ENFORCE(head_size <= max_threads_per_block, + "Rotary embedding dim must be <= max_threads_per_block"); + + int tpb = (head_size + 31)/32*32; + + const dim3 block(tpb); + const dim3 grid(num_heads, sequence_length, batch_size); // Default input tensor shape is [batch, seq, hidden_size] int head_stride = head_size; @@ -109,10 +122,9 @@ Status LaunchRotaryEmbeddingKernel( } assert(head_size <= max_threads_per_block); - RotaryEmbeddingBSNH<<>>( - output, input, cos_cache, sin_cache, position_ids, - sequence_length, num_heads, head_size, position_ids_format, interleaved, - batch_stride, seq_stride, head_stride + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, sequence_length, num_heads, head_size, + rotary_embedding_dim, position_ids_format, interleaved, batch_stride, seq_stride, head_stride ); return CUDA_CALL(cudaGetLastError()); @@ -129,6 +141,7 @@ template Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, @@ -146,6 +159,25 @@ template Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + const bool transposed); + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + BFloat16* output, + const BFloat16* input, + const int64_t* position_ids, + const BFloat16* cos_cache, + const BFloat16* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index ee1ccc43dcbff..36300fe7a660f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -21,6 +21,7 @@ Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 34b44694a5fcc..fa73950c9c6f5 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -98,6 +98,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); @@ -299,6 +300,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 0317ffcfb0e31..7f34647f1faef 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -927,6 +927,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("unidirectional", + "Whether every token can only attend to previous tokens. Default value is 0.", + AttributeProto::INT, + static_cast(0)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)", @@ -1145,6 +1149,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Rotate using interleaved pattern. Default value is 0 (False).", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("rotary_embedding_dim", + "Rotary embedding dimension. Default value is 0.", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("num_heads", + "Number of attention heads. Default value is 0. Must use with rotary_embedding_dim", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "input", "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)", @@ -1155,17 +1167,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "M") .Input(2, "cos_cache", - "2D tensor with shape (max_sequence_length, head_size / 2).", + "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)", "T") .Input(3, "sin_cache", - "2D tensor with shape (max_sequence_length, head_size / 2).", + "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)", "T") .Output(0, "output", "tensor with same shape as input.", "T") - .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index 55f01bf0d3f1d..e64de0e6da16a 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -11,6 +11,14 @@ namespace onnxruntime { namespace test { +namespace { +enum class TensorType { + kFloat, + kFloat16, + kBFloat16 +}; +} // anonymous namespace + static void RunTest( const std::vector& input_data, const std::vector& position_ids, @@ -20,10 +28,11 @@ static void RunTest( int batch_size, int sequence_length, int head_size, + int rotary_embedding_dim, int num_heads, int max_sequence_length, int64_t interleaved, - bool use_float16, + TensorType tensor_type, bool disable_cpu, bool disable_cuda, bool disable_dml) { @@ -36,7 +45,9 @@ static void RunTest( int hidden_size = num_heads * head_size; std::vector input_dims = {batch_size, sequence_length, hidden_size}; std::vector pos_dims; - std::vector cache_dims = {max_sequence_length, head_size / 2}; + std::vector cache_dims = {max_sequence_length, rotary_embedding_dim > 0 + ? rotary_embedding_dim / 2 + : head_size / 2}; assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0); assert(max_sequence_length >= sequence_length); @@ -49,7 +60,10 @@ static void RunTest( std::string op_type = "RotaryEmbedding"; std::vector> execution_providers; - int min_cuda_architecture = use_float16 ? 530 : 0; + int min_cuda_architecture = (tensor_type == TensorType::kBFloat16) + ? 800 + : (tensor_type == TensorType::kFloat16) ? 530 + : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; @@ -59,7 +73,7 @@ static void RunTest( if (enable_dml && !disable_dml) { execution_providers.push_back(DefaultDmlExecutionProvider()); } - if (!use_float16 && !disable_cpu) { + if (tensor_type == TensorType::kFloat && !disable_cpu) { execution_providers.push_back(DefaultCpuExecutionProvider()); } if (execution_providers.size() == 0) { @@ -70,20 +84,36 @@ static void RunTest( OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); test.AddAttribute("interleaved", interleaved); - if (!use_float16) { + if (rotary_embedding_dim > 0) { + test.AddAttribute("rotary_embedding_dim", rotary_embedding_dim); + test.AddAttribute("num_heads", num_heads); + } + + if (tensor_type == TensorType::kFloat) { test.AddInput("input", input_dims, input_data); test.AddInput("position_ids", pos_dims, position_ids); test.AddInput("cos_cache", cache_dims, cos_cache); test.AddInput("sin_cache", cache_dims, sin_cache); test.AddOutput("output", input_dims, output_data); - } else { + } else if (tensor_type == TensorType::kFloat16) { test.AddInput("input", input_dims, ToFloat16(input_data)); test.AddInput("position_ids", pos_dims, position_ids); test.AddInput("cos_cache", cache_dims, ToFloat16(cos_cache)); test.AddInput("sin_cache", cache_dims, ToFloat16(sin_cache)); test.AddOutput("output", input_dims, ToFloat16(output_data)); + } else { + test.AddInput("input", input_dims, FloatsToBFloat16s(input_data)); + test.AddInput("position_ids", pos_dims, position_ids); + test.AddInput("cos_cache", cache_dims, FloatsToBFloat16s(cos_cache)); + test.AddInput("sin_cache", cache_dims, FloatsToBFloat16s(sin_cache)); + test.AddOutput("output", input_dims, FloatsToBFloat16s(output_data)); + } + if (tensor_type == TensorType::kBFloat16) { + test.SetOutputAbsErr("output", 0.03f); + } else { + test.SetOutputAbsErr("output", 0.002f); } - test.SetOutputAbsErr("output", 0.002f); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -95,10 +125,12 @@ static void RunTests(const std::vector& input_data, int batch_size, int sequence_length, int head_size = 0, + int rotary_embedding_dim = 0, int num_heads = 0, int max_sequence_length = 0, int64_t interleaved = 0, - bool use_float16 = true) { + bool use_float16 = true, + bool disable_dml = false) { // FP32 test for CPU RunTest(input_data, position_ids, @@ -108,10 +140,11 @@ static void RunTests(const std::vector& input_data, batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved, - false, /* use_fp16 */ + TensorType::kFloat, false, /* disable_cpu */ true, /* disable_cuda */ true /* disable_dml */); @@ -125,13 +158,14 @@ static void RunTests(const std::vector& input_data, batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved, - false, /* use_fp16 */ + TensorType::kFloat, false, /* disable_cpu */ false, /* disable_cuda */ - false /* disable_dml */); + disable_dml || false /* disable_dml */); // FP16 test for CUDA and DML if (use_float16) { @@ -143,13 +177,31 @@ static void RunTests(const std::vector& input_data, batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved, - true, /* use_fp16 */ + TensorType::kFloat16, true, /* disable_cpu */ false, /* disable_cuda*/ - false /* disable_dml */); + disable_dml || false /* disable_dml */); + + // RunTest(input_data, + // position_ids, + // cos_cache, + // sin_cache, + // output_data, + // batch_size, + // sequence_length, + // head_size, + // rotary_embedding_dim, + // num_heads, + // max_sequence_length, + // interleaved, + // TensorType::kBFloat16, + // true, /* disable_cpu */ + // false, /* disable_cuda*/ + // false /* disable_dml */); } } @@ -159,6 +211,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { int sequence_length = 3; int num_heads = 2; int head_size = 4; + int rotary_embedding_dim = 0; int max_sequence_length = 8; int64_t interleaved = 1; // true @@ -190,6 +243,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved); @@ -201,6 +255,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { int sequence_length = 8; int num_heads = 4; int head_size = 6; + int rotary_embedding_dim = 0; int max_sequence_length = 16; int64_t interleaved = 1; // true @@ -388,6 +443,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved); @@ -399,6 +455,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { int sequence_length = 8; int num_heads = 4; int head_size = 6; + int rotary_embedding_dim = 0; int max_sequence_length = 16; int64_t interleaved = 0; // false @@ -586,6 +643,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved); @@ -597,6 +655,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { int sequence_length = 2; int num_heads = 3; int head_size = 6; + int rotary_embedding_dim = 0; int max_sequence_length = 4; int64_t interleaved = 0; // false @@ -632,10 +691,52 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { batch_size, sequence_length, head_size, + rotary_embedding_dim, num_heads, max_sequence_length, interleaved); } +TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) { + int batch_size = 1; + int sequence_length = 2; + int num_heads = 1; + int head_size = 6; + int rotary_embedding_dim = 4; + int max_sequence_length = 2; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f}; + + std::vector position_ids = {0, 1}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.0427f, + -0.2250f, -0.8673f, -1.5071f, -0.4586f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved, + true, /*use_fp16*/ + true /*disable_dml*/); +} + } // namespace test } // namespace onnxruntime From 373ebac167a1d7f1dfb7a576a6c92f1c75cb711e Mon Sep 17 00:00:00 2001 From: Zhang Lei Date: Mon, 22 Jan 2024 10:40:48 -0800 Subject: [PATCH 04/45] Zhalei/fix seqoutput type (#18765) After refactoring beamsearch, all scores become fp32. Yet it need support fp16 according to original specs. --- .../cpu/transformers/beam_search_impl_gpt.h | 8 +- .../cpu/transformers/beam_search_impl_t5.h | 8 +- .../transformers/beam_search_impl_whisper.h | 8 +- .../cpu/transformers/beam_search_scorer.cc | 82 +++++++++++++------ .../cpu/transformers/beam_search_scorer.h | 12 +-- .../cpu/transformers/generation_shared.h | 3 + .../cuda/transformers/generation_cuda_impl.cu | 63 +++++++++++++- .../cuda/transformers/generation_cuda_impl.h | 19 +++-- .../transformers/generation_device_helper.cc | 55 +++++++++++-- .../models/whisper/whisper_chain.py | 31 ++++++- 10 files changed, 220 insertions(+), 69 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 56d950ca2f41e..dc72a038c3d58 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -397,12 +397,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch output_sequences_scores); // Output per token scores - if (output_scores) { - gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = beam_state.scores; - assert(target.size() == source.size()); - ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } + gsl::span per_token_scores = beam_state.scores; + this->beam_scorer_->OutputScores(per_token_scores, output_scores); return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 94547887d3a90..cd891a9508019 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -404,12 +404,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches output_sequences_scores); // Output per token scores - if (output_scores) { - gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = beam_state.scores; - assert(target.size() == source.size()); - ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } + gsl::span per_token_scores = beam_state.scores; + this->beam_scorer_->OutputScores(per_token_scores, output_scores); return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 91b93a125ad7a..4d6643c68a98b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -500,12 +500,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe output_sequences_scores); // Output per token scores - if (output_scores) { - gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = beam_state.scores; - assert(target.size() == source.size()); - ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } + gsl::span per_token_scores = beam_state.scores; + this->beam_scorer_->OutputScores(per_token_scores, output_scores); return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index 7e2e5b2129221..0eccbe26605f5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -50,11 +50,12 @@ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) con return beams_.back().score < current_score; } +template void BeamHypotheses::Output( int top_k, int max_length, - gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) - gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty + gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) + gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty { // Copy the top_k beams into the sequences ORT_ENFORCE(top_k <= beams_used_); @@ -67,7 +68,7 @@ void BeamHypotheses::Output( gsl::copy(item.hypothesis, target); if (!sequences_scores.empty()) - sequences_scores[index] = item.score; + sequences_scores[index] = (T)item.score; } } @@ -181,21 +182,21 @@ void BeamSearchScorer::Process(ISequences& sequences, } } -void BeamSearchScorer::Finalize(ISequences& sequences, - gsl::span& final_beam_scores, - Tensor* output_sequences, - Tensor* output_sequence_scores) { - ORT_ENFORCE(output_sequences != nullptr); - +template +void OutputSequenceScores(BeamSearchScorer* scorer, + ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { // Finalize all open beam hypotheses and add to generated hypotheses. - for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) { - BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; + for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) { + BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index]; if (beam_hyp.done_) { continue; } - for (size_t beam_index = 0; beam_index < num_beams_; beam_index++) { - size_t batch_beam_index = batch_index * num_beams_ + beam_index; + for (size_t beam_index = 0; beam_index < scorer->num_beams_; beam_index++) { + size_t batch_beam_index = batch_index * scorer->num_beams_ + beam_index; float final_score = final_beam_scores[batch_beam_index]; auto final_tokens = sequences.GetSequence(narrow(batch_beam_index)); beam_hyp.Add(final_tokens, final_score); @@ -206,26 +207,59 @@ void BeamSearchScorer::Finalize(ISequences& sequences, gsl::span output = output_sequences->MutableDataAsSpan(); // Fill output sequences with pad token ID so that we do not need append it later. - std::fill_n(output.data(), output.size(), pad_token_id_); + std::fill_n(output.data(), output.size(), scorer->pad_token_id_); // Score of each sequence, with shape (batch_size * num_return_sequences). - gsl::span sequence_scores; + gsl::span sequence_scores; if (output_sequence_scores) { - sequence_scores = output_sequence_scores->MutableDataAsSpan(); + sequence_scores = output_sequence_scores->MutableDataAsSpan(); } // Select the best hypotheses according to number of sequences to return. - for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) { - BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; + for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) { + BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index]; - auto batch_output = output.subspan(batch_index * num_return_sequences_ * max_length_, - num_return_sequences_ * max_length_); - gsl::span sequence_scores_buffer; + auto batch_output = output.subspan(batch_index * scorer->num_return_sequences_ * scorer->max_length_, + scorer->num_return_sequences_ * scorer->max_length_); + gsl::span sequence_scores_buffer; if (!sequence_scores.empty()) - sequence_scores_buffer = sequence_scores.subspan(batch_index * num_return_sequences_, num_return_sequences_); + sequence_scores_buffer = sequence_scores.subspan(batch_index * scorer->num_return_sequences_, scorer->num_return_sequences_); + + beam_hyp.template Output(narrow(scorer->num_return_sequences_), narrow(scorer->max_length_), batch_output, + sequence_scores_buffer); + } +} + +void BeamSearchScorer::Finalize(ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { + ORT_ENFORCE(output_sequences != nullptr); - beam_hyp.Output(narrow(num_return_sequences_), narrow(max_length_), batch_output, - sequence_scores_buffer); + if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType()) { + OutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); + } else { + ORT_ENFORCE(output_sequence_scores->IsDataType()); + OutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); + } +} + +void BeamSearchScorer::OutputScores(gsl::span& final_scores, Tensor* output_scores) { + if (output_scores) { + if (output_scores->IsDataType()) { + gsl::span target = output_scores->MutableDataAsSpan(); + ORT_ENFORCE(target.size() == final_scores.size()); + std::copy_n(final_scores.data(), final_scores.size(), target.data()); + } else { + ORT_ENFORCE(output_scores->IsDataType()); + gsl::span target = output_scores->MutableDataAsSpan(); + ORT_ENFORCE(target.size() == final_scores.size()); + const float* src = final_scores.data(); + MLFloat16* dst = target.data(); + for (size_t i = 0; i < target.size(); i++) { + dst[i] = MLFloat16(src[i]); + } + } } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index 94b6d340d9f4a..dc92e8038a68e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -35,10 +35,11 @@ struct BeamHypotheses { bool CanImprove(float best_sum_logprobs, int current_length) const; // Output results - void Output(int top_k, // number of sequences to return - int max_length, // max sequence length - gsl::span& sequences, // buffer with pad token, shape (num_return_sequences, max_length) - gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) + template + void Output(int top_k, // number of sequences to return + int max_length, // max sequence length + gsl::span& sequences, // buffer with pad token, shape (num_return_sequences, max_length) + gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) gsl::span beams_; // Beam width sized array of hypotheses, sorted by highest scoring int beams_used_; // Number of elements used in beams_ @@ -60,13 +61,14 @@ struct BeamSearchScorer : IBeamScorer { Tensor* output_sequences, Tensor* output_sequence_scores) override; + void OutputScores(gsl::span& final_scores, Tensor* output_scores) override; + bool IsDone() const override { return not_done_count_ == 0; } gsl::span GetNextScores() override { return next_beam_scores_; } gsl::span GetNextTokens() override { return next_beam_tokens_; } gsl::span GetNextIndicesCPU() override { return next_beam_indices_; } - private: size_t batch_size_; size_t num_beams_; size_t max_length_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index f6faf2e325f8f..cb62e2f7bf4da 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -120,6 +120,9 @@ struct IBeamScorer { Tensor* output_sequences, Tensor* output_sequence_scores) = 0; + virtual void OutputScores(gsl::span& final_scores, + Tensor* output_scores) = 0; + virtual bool IsDone() const = 0; // GPU version will return false here, as it asynchronously queues up the event virtual bool IsDoneLater() const { return false; } // GPU version waits for the asynchous result to complete here diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index dbd7fb010462d..a39abefed9cd0 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -307,12 +307,13 @@ __device__ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_ return beams_[beams_count_ - 1].score < current_score; } +template __device__ void BeamHypotheses::Output( int top_k, int max_length, int pad_token_id, int32_t* sequences, // buffer of shape (num_return_sequences, max_length) - float* sequences_scores) // buffer of shape (num_return_sequences) or empty + T* sequences_scores) // buffer of shape (num_return_sequences) or empty { // Copy the top_k beams into the sequences for (int index = 0; index < top_k; index++) { @@ -327,7 +328,7 @@ __device__ void BeamHypotheses::Output( target[i] = pad_token_id; if (sequences_scores) - sequences_scores[index] = item.score; + sequences_scores[index] = (T)item.score; } } @@ -501,13 +502,14 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp next_beam_tokens.data()); } +template __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, const int32_t* sequences_buffer, int sequence_length, BeamHypotheses* beam_hyps_, const float* final_beam_scores, int32_t* output, - float* sequence_scores) { + T* sequence_scores) { int batch_index = blockIdx.x * blockDim.x + threadIdx.x; if (batch_index >= state.batch_size_) return; @@ -534,6 +536,7 @@ __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, sequence_scores ? sequence_scores + batch_index * state.num_return_sequences_ : nullptr); } +template void LaunchBeamSearchScorer_Finalize(int batch_size, BeamScorerState& state, gsl::span sequences, @@ -541,7 +544,7 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, gsl::span beam_hyps, gsl::span final_beam_scores, gsl::span output, - gsl::span sequence_scores, + gsl::span sequence_scores, cudaStream_t stream) { BeamSearchScorer_Finalize<<<1, batch_size, 0, stream>>>(state, sequences.data(), @@ -552,6 +555,58 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, sequence_scores.data()); } +template void LaunchBeamSearchScorer_Finalize( + int batch_size, + BeamScorerState& state, + gsl::span sequences, + int sequence_length, + gsl::span beam_hyps, + gsl::span final_beam_scores, + gsl::span output, + gsl::span sequence_scores, + cudaStream_t stream); + +template void LaunchBeamSearchScorer_Finalize<__half>( + int batch_size, + BeamScorerState& state, + gsl::span sequences, + int sequence_length, + gsl::span beam_hyps, + gsl::span final_beam_scores, + gsl::span output, + gsl::span<__half> sequence_scores, + cudaStream_t stream); + +template +__global__ void FloatConvertAndCopyKernel(const float* src, T* dst, size_t total_elements) { + int64_t index = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (index < total_elements) { + dst[index] = (T)src[index]; + } +} + +template +void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream) { + ORT_ENFORCE(final_scores.size() == output_scores.size()); + constexpr unsigned ThreadPerBlock = 256; + unsigned num_blocks = (unsigned)((final_scores.size() + (ThreadPerBlock - 1))/ ThreadPerBlock); + + typedef typename ToCudaType::MappedType CudaT; + + FloatConvertAndCopyKernel<<>>( + final_scores.data(), (CudaT*)output_scores.data(), final_scores.size()); +} + +template void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream); + +template void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream); + __global__ void AddProbsKernel(float* log_probs, float* cum_log_probs, const int vocab_size, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 5ed5949196b29..281cb6c725975 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -65,11 +65,12 @@ struct BeamHypotheses { __device__ bool CanImprove(float best_sum_logprobs, int current_length) const; // Output results - __device__ void Output(int top_k, // number of sequences to return - int max_length, // max sequence length - int pad_token_id, // pad token - int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length) - float* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) + template + __device__ void Output(int top_k, // number of sequences to return + int max_length, // max sequence length + int pad_token_id, // pad token + int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length) + T* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) }; struct BeamScorerState { @@ -110,6 +111,7 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp gsl::span next_beam_indices, cudaStream_t stream); +template void LaunchBeamSearchScorer_Finalize(int batch_size, BeamScorerState& state, gsl::span sequences, @@ -117,9 +119,14 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, gsl::span beam_hyps_, gsl::span final_beam_scores, gsl::span output, - gsl::span sequence_scores, + gsl::span sequence_scores, cudaStream_t stream); +template +void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream); + void LaunchNextTokenKernel(const int64_t* next_token_indices, int32_t* next_indices, int32_t* next_tokens, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 380d561bbb23c..bba30805ae1be 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -620,6 +620,8 @@ struct CudaBeamSearchScorer : transformers::IBeamScorer { Tensor* output_sequences, Tensor* output_sequence_scores) override; + void OutputScores(gsl::span& final_scores, Tensor* output_scores) override; + bool IsDone() const override { return false; } // For CUDA we speculatively run the next step while we wait for the GPU to report status. We use 'IsDoneLater()' for this bool IsDoneLater() const override; @@ -632,7 +634,6 @@ struct CudaBeamSearchScorer : transformers::IBeamScorer { } gsl::span GetNextIndicesGPU() override { return next_beam_indices_; } - private: mutable cuda::AutoDestoryCudaEvent event_process_complete_; IAllocatorUniquePtr state_cpu_; IAllocatorUniquePtr state_gpu_; @@ -743,22 +744,58 @@ bool CudaBeamSearchScorer::IsDoneLater() const { return state_cpu_->not_done_count_ == 0; } +template +void CudaOutputSequenceScores(CudaBeamSearchScorer* scorer, + transformers::ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { + // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length). + gsl::span output{output_sequences->MutableData(), static_cast(output_sequences->Shape().Size())}; + + // Score of each sequence, with shape (batch_size * num_return_sequences). + using CudaT = typename ToCudaType::MappedType; + gsl::span sequence_scores; + if (output_sequence_scores) { + sequence_scores = gsl::span{(CudaT*)output_sequence_scores->MutableData(), static_cast(output_sequence_scores->Shape().Size())}; + } + + cuda::LaunchBeamSearchScorer_Finalize(scorer->state_cpu_->batch_size_, + *scorer->state_gpu_, + sequences.GetCurrentDeviceSequences(), + sequences.GetSequenceLength(), + scorer->beam_hyps_, + final_beam_scores, + output, + sequence_scores, + scorer->stream_); +} + void CudaBeamSearchScorer::Finalize(transformers::ISequences& sequences, gsl::span& final_beam_scores, Tensor* output_sequences, Tensor* output_sequence_scores) { ORT_ENFORCE(output_sequences != nullptr); - // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length). - gsl::span output{output_sequences->MutableData(), static_cast(output_sequences->Shape().Size())}; - - // Score of each sequence, with shape (batch_size * num_return_sequences). - gsl::span sequence_scores; - if (output_sequence_scores) { - sequence_scores = gsl::span{output_sequence_scores->MutableData(), static_cast(output_sequence_scores->Shape().Size())}; + if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType()) { + CudaOutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); + } else { + ORT_ENFORCE(output_sequence_scores->IsDataType()); + CudaOutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); } +} - cuda::LaunchBeamSearchScorer_Finalize(state_cpu_->batch_size_, *state_gpu_, sequences.GetCurrentDeviceSequences(), sequences.GetSequenceLength(), beam_hyps_, final_beam_scores, output, sequence_scores, stream_); +void CudaBeamSearchScorer::OutputScores(gsl::span& final_scores, Tensor* output_scores) { + if (output_scores) { + if (output_scores->IsDataType()) { + gsl::span target(output_scores->MutableData(), output_scores->Shape().Size()); + cuda::LaunchBeamSearchScoreCopy(final_scores, target, stream_); + } else { + ORT_ENFORCE(output_scores->IsDataType()); + gsl::span target(output_scores->MutableData(), output_scores->Shape().Size()); + cuda::LaunchBeamSearchScoreCopy(final_scores, target, stream_); + } + } } std::unique_ptr CreateBeamScorer(const transformers::IGenerationParameters& parameters, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 33958e55f8c38..a74666b7af297 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -53,9 +53,9 @@ def chain_model(args): beam_outputs = ["sequences"] if args.output_sequence_scores: - beam_outputs.append("sequence_scores") + beam_outputs.append("sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores") if args.output_scores: - beam_outputs.append("scores") + beam_outputs.append("scores_fp16" if args.precision == Precision.FLOAT16 else "scores") if args.use_whisper_beamsearch: assert len(beam_inputs) == 12 @@ -75,6 +75,7 @@ def chain_model(args): beam_outputs.extend(["no_speech_probs_beam"]) input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None + output_scores_cast_node = output_sequence_scores_cast_node = None if args.precision == Precision.FLOAT16: input_features_cast_node = helper.make_node( "Cast", @@ -97,6 +98,22 @@ def chain_model(args): name="CastRepetitionPenaltyToFp16", to=TensorProto.FLOAT16, ) + if args.output_sequence_scores: + output_sequence_scores_cast_node = helper.make_node( + "Cast", + inputs=["sequence_scores_fp16"], + outputs=["sequence_scores"], + name="CastOutputSequenceScoresToFp32", + to=TensorProto.FLOAT, + ) + if args.output_scores: + output_scores_cast_node = helper.make_node( + "Cast", + inputs=["scores_fp16"], + outputs=["scores"], + name="CastScoresToFp32", + to=TensorProto.FLOAT, + ) operator_type = "WhisperBeamSearch" if args.use_whisper_beamsearch else "BeamSearch" node = helper.make_node(operator_type, inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode") @@ -214,10 +231,18 @@ def chain_model(args): opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)] graph_nodes = ( - [input_features_cast_node, len_pen_cast_node, rep_pen_cast_node, node] + [ + input_features_cast_node, + len_pen_cast_node, + rep_pen_cast_node, + node, + output_sequence_scores_cast_node, + output_scores_cast_node, + ] if args.precision == Precision.FLOAT16 else [node] ) + graph_nodes = [node for node in graph_nodes if node is not None] if args.output_no_speech_probs: prob_cast_node = helper.make_node( "Cast", From 8d9d7511799e2138c14454bb672caf07dcdc2457 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 22 Jan 2024 12:47:42 -0800 Subject: [PATCH 05/45] [QNN EP] Expose device-level session options (#19212) ### Description - Adds the following session options to configure the device: - `soc_model`: The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown). - `htp_arch`: The minimum HTP architecture the driver will use to select compatible QNN operators. - `device_id`: The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). ### Motivation and Context Allow more configuration. --- .../core/session/onnxruntime_c_api.h | 8 ++ .../qnn/builder/qnn_backend_manager.cc | 31 ++++++- .../qnn/builder/qnn_backend_manager.h | 14 ++- .../qnn/builder/qnn_configs_helper.h | 90 +++++++++++++++++++ .../qnn/builder/qnn_graph_configs_helper.cc | 43 --------- .../qnn/builder/qnn_graph_configs_helper.h | 56 ------------ .../providers/qnn/qnn_execution_provider.cc | 69 ++++++++++++-- .../providers/qnn/qnn_execution_provider.h | 7 +- onnxruntime/test/onnx/main.cc | 18 +++- .../test/perftest/command_args_parser.cc | 4 + onnxruntime/test/perftest/ort_test_session.cc | 14 ++- .../test/providers/qnn/qnn_basic_test.cc | 56 +++++++++++- 12 files changed, 292 insertions(+), 118 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h delete mode 100644 onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.cc delete mode 100644 onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index aca9f4896fbdb..101a578ec3e1d 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3608,6 +3608,14 @@ struct OrtApi { * - "1": Faster preparation time, less optimal graph. * - "2": Longer preparation time, more optimal graph. * - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details. + * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown). + * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options: + * - "0": Default (none). + * - "68" + * - "69" + * - "73" + * - "75" + * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 193e4f5ff2a31..973b81d337c81 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -17,6 +17,7 @@ #include "core/framework/endian_utils.h" #include "core/common/logging/capture.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" +#include "core/providers/qnn/builder/qnn_configs_helper.h" #ifdef _WIN32 #include @@ -329,9 +330,37 @@ Status QnnBackendManager::CreateDevice() { return Status::OK(); } + qnn::QnnConfigsBuilder device_configs_builder(QNN_DEVICE_CONFIG_INIT, + {}); + if (qnn_backend_type_ == QnnBackendType::HTP) { + // Set SoC Model. The *enum* Qnn_SocModel_t is deprecated and will not be updated in the future. Therefore, + // must use the latest SDK documentation to get the SoC model of the latest HW. + if (soc_model_ != QNN_SOC_MODEL_UNKNOWN) { + QnnHtpDevice_CustomConfig_t& custom_config = device_configs_builder.PushCustomConfig(); + custom_config.option = QNN_HTP_DEVICE_CONFIG_OPTION_SOC; + custom_config.socModel = soc_model_; + + QnnDevice_Config_t& device_config = device_configs_builder.PushConfig(); + device_config.option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; + device_config.customConfig = &custom_config; + } + + // Set the minimum HTP architecture. The driver will use ops that are compatible with this minimum architecture. + if (htp_arch_ != QNN_HTP_DEVICE_ARCH_NONE) { + QnnHtpDevice_CustomConfig_t& custom_config = device_configs_builder.PushCustomConfig(); + custom_config.option = QNN_HTP_DEVICE_CONFIG_OPTION_ARCH; + custom_config.arch.arch = htp_arch_; + custom_config.arch.deviceId = device_id_; + + QnnDevice_Config_t& device_config = device_configs_builder.PushConfig(); + device_config.option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; + device_config.customConfig = &custom_config; + } + } + LOGS_DEFAULT(INFO) << "Create device."; if (nullptr != qnn_interface_.deviceCreate) { - auto result = qnn_interface_.deviceCreate(log_handle_, nullptr, &device_handle_); + auto result = qnn_interface_.deviceCreate(log_handle_, device_configs_builder.GetQnnConfigs(), &device_handle_); if (QNN_SUCCESS != result) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create device. Error: ", result); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 58f207efb9e95..f7b8947ab84bb 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -17,6 +17,7 @@ #include #include "HTP/QnnHtpDevice.h" #include "QnnLog.h" +#include "QnnTypes.h" #include "System/QnnSystemInterface.h" #include "core/common/status.h" #include "core/common/logging/logging.h" @@ -35,13 +36,19 @@ class QnnBackendManager { uint32_t rpc_control_latency, HtpPerformanceMode htp_performance_mode, ContextPriority context_priority, - std::string&& qnn_saver_path) + std::string&& qnn_saver_path, + uint32_t device_id, + QnnHtpDevice_Arch_t htp_arch, + uint32_t soc_model) : backend_path_(backend_path), profiling_level_(profiling_level), rpc_control_latency_(rpc_control_latency), htp_performance_mode_(htp_performance_mode), context_priority_(context_priority), - qnn_saver_path_(qnn_saver_path) { + qnn_saver_path_(qnn_saver_path), + device_id_(device_id), + htp_arch_(htp_arch), + soc_model_(soc_model) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager); @@ -233,6 +240,9 @@ class QnnBackendManager { #endif const std::string qnn_saver_path_; uint32_t htp_power_config_client_id_ = 0; + uint32_t device_id_ = 0; + QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE; + uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h b/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h new file mode 100644 index 0000000000000..9dd9bbaa08d64 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +namespace qnn { + +/** + * Helper class for building a null-terminated list of QNN configurations. + * A QNN configuration consists of multiple objects with references to each other. This + * class ensures that all configuration objects have the same lifetime, so that they remain valid + * across calls to qnn_interface.xxxCreate(). + */ +template +class QnnConfigsBuilder { + public: + /** + * Initializes the config build. Provide the initial/default value for each config struct type. + * \param base_config_init The initial/default value for objects of type BaseConfigType. + * \param custom_config_init The initial/default value for objects of type CustomConfigType. + */ + QnnConfigsBuilder(BaseConfigType base_config_init, CustomConfigType custom_config_init) + : base_config_init_(std::move(base_config_init)), custom_config_init_(std::move(custom_config_init)) {} + + /** + * Returns a pointer to the beginning of a null-terminated array of QNN base configurations. + * This result is typically passed to QNN's xxxCreate() APIs. + * + * \return Pointer to null-terminated BaseConfigType* array. + */ + const BaseConfigType** GetQnnConfigs() { + if (config_ptrs_.empty()) { + return nullptr; + } + + if (!IsNullTerminated()) { + config_ptrs_.push_back(nullptr); + } + + return config_ptrs_.data(); + } + + /** + * Creates and returns a reference to a new custom QNN configuration object. The object is initialized to + * the QNN recommended default value. The caller is meant to override fields in this object. + * + * \return A reference to a default CustomConfigType object. + */ + CustomConfigType& PushCustomConfig() { + custom_configs_.push_back(custom_config_init_); + return custom_configs_.back(); + } + + /** + * Creates and returns a reference to a new QNN configuration object. The object is initialized to + * the QNN recommended default value. The caller is meant to override fields in this object. + * + * \return A reference to a default BaseConfigType object. + */ + BaseConfigType& PushConfig() { + configs_.push_back(base_config_init_); + BaseConfigType& config = configs_.back(); + + // Add pointer to this new config to the list of config pointers. + if (IsNullTerminated()) { + config_ptrs_.back() = &config; // Replace last nullptr entry. + } else { + config_ptrs_.push_back(&config); + } + + return config; + } + + private: + bool IsNullTerminated() const { + return !config_ptrs_.empty() && config_ptrs_.back() == nullptr; + } + + BaseConfigType base_config_init_; + CustomConfigType custom_config_init_; + InlinedVector custom_configs_; + InlinedVector configs_; + InlinedVector config_ptrs_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.cc b/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.cc deleted file mode 100644 index 63aa01b48e7e2..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.cc +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/qnn/builder/qnn_graph_configs_helper.h" - -#include "HTP/QnnHtpGraph.h" - -namespace onnxruntime { -namespace qnn { - -const QnnGraph_Config_t** QnnGraphConfigsBuilder::GetQnnGraphConfigs() { - if (graph_config_ptrs_.empty()) { - return nullptr; - } - - if (!IsNullTerminated()) { - graph_config_ptrs_.push_back(nullptr); - } - - return graph_config_ptrs_.data(); -} - -QnnHtpGraph_CustomConfig_t& QnnGraphConfigsBuilder::PushHtpGraphCustomConfig() { - htp_custom_graph_configs_.push_back(QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT); - return htp_custom_graph_configs_.back(); -} - -QnnGraph_Config_t& QnnGraphConfigsBuilder::PushGraphConfig() { - graph_configs_.push_back(QNN_GRAPH_CONFIG_INIT); - QnnGraph_Config_t& config = graph_configs_.back(); - - // Add pointer to this new graph config to the list of graph config pointers. - if (IsNullTerminated()) { - graph_config_ptrs_.back() = &config; // Replace last nullptr entry. - } else { - graph_config_ptrs_.push_back(&config); - } - - return config; -} - -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.h b/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.h deleted file mode 100644 index 8c4928fdacbc4..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "HTP/QnnHtpGraph.h" - -namespace onnxruntime { -namespace qnn { - -/** - * Helper class for building a null-terminated list of QNN Graph configurations. - * A QNN configuration consists of multiple objects with references to each other. This - * class ensures that all configuration objects have the same lifetime, so that they remain valid - * across the call to graphCreate(). - */ -class QnnGraphConfigsBuilder { - public: - /** - * Returns a pointer to the beginning of a null-terminated array of QNN Graph configurations. - * This result is passed QNN's graphCreate() API. - * - * \return Pointer to null-terminated QnnGraph_Config_t* array. - */ - const QnnGraph_Config_t** GetQnnGraphConfigs(); - - /** - * Creates and returns a reference to a new HTP graph configuration object. The object is initialized to - * the QNN recommended default value. The caller is meant to override fields in this object. - * - * \return A reference to a default QnnHtpGraph_CustomConfig_t object. - */ - QnnHtpGraph_CustomConfig_t& PushHtpGraphCustomConfig(); - - /** - * Creates and returns a reference to a new graph configuration object. The object is initialized to - * the QNN recommended default value. The caller is meant to override fields in this object. - * - * \return A reference to a default QnnGraph_Config_t object. - */ - QnnGraph_Config_t& PushGraphConfig(); - - private: - bool IsNullTerminated() const { - return !graph_config_ptrs_.empty() && graph_config_ptrs_.back() == nullptr; - } - - InlinedVector htp_custom_graph_configs_; - InlinedVector graph_configs_; - InlinedVector graph_config_ptrs_; -}; - -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 56eb1f4f59f33..0310cc2bc8f26 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -111,6 +111,22 @@ void QNNExecutionProvider::ParseHtpGraphFinalizationOptimizationMode(const std:: } } +static void ParseHtpArchitecture(const std::string& htp_arch_string, QnnHtpDevice_Arch_t& qnn_htp_arch) { + if (htp_arch_string.empty() || htp_arch_string == "0") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_NONE; + } else if (htp_arch_string == "68") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_V68; + } else if (htp_arch_string == "69") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_V69; + } else if (htp_arch_string == "73") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_V73; + } else if (htp_arch_string == "75") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_V75; + } else { + LOGS_DEFAULT(WARNING) << "Invalid HTP architecture: " << htp_arch_string; + } +} + QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map, const SessionOptions* session_options) : IExecutionProvider{onnxruntime::kQnnExecutionProvider, true} { @@ -223,13 +239,49 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } + static const std::string QNN_DEVICE_ID = "device_id"; + uint32_t device_id = 0; + auto dev_id_pos = provider_options_map.find(QNN_DEVICE_ID); + if (dev_id_pos != provider_options_map.end()) { + int value = std::stoi(dev_id_pos->second); + if (value < 0) { + LOGS_DEFAULT(WARNING) << "Invalid device ID '" << value + << "', only >= 0 allowed. Set to " << device_id << "."; + } else { + device_id = static_cast(value); + } + } + + static const std::string QNN_HTP_ARCH = "htp_arch"; + QnnHtpDevice_Arch_t htp_arch = QNN_HTP_DEVICE_ARCH_NONE; + auto htp_arch_pos = provider_options_map.find(QNN_HTP_ARCH); + if (htp_arch_pos != provider_options_map.end()) { + ParseHtpArchitecture(htp_arch_pos->second, htp_arch); + } + + static const std::string QNN_SOC_MODEL = "soc_model"; + uint32_t soc_model = QNN_SOC_MODEL_UNKNOWN; + auto soc_model_pos = provider_options_map.find(QNN_SOC_MODEL); + if (soc_model_pos != provider_options_map.end()) { + int value = std::stoi(soc_model_pos->second); + if (value < 0) { + LOGS_DEFAULT(WARNING) << "Invalid SoC Model '" << value + << "', only >= 0 allowed. Set to " << soc_model << "."; + } else { + soc_model = static_cast(value); + } + } + qnn_backend_manager_ = std::make_unique( std::move(backend_path), profiling_level, rpc_control_latency, htp_performance_mode, context_priority, - std::move(qnn_saver_path)); + std::move(qnn_saver_path), + device_id, + htp_arch, + soc_model); } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, @@ -512,25 +564,25 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector& nod return Status::OK(); } -void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_builder) const { +void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const { if (qnn_backend_manager_->GetQnnBackendType() == qnn::QnnBackendType::HTP) { if (htp_graph_finalization_opt_mode_ != qnn::HtpGraphFinalizationOptimizationMode::kDefault) { - QnnHtpGraph_CustomConfig_t& htp_graph_opt_config = configs_builder.PushHtpGraphCustomConfig(); + QnnHtpGraph_CustomConfig_t& htp_graph_opt_config = configs_builder.PushCustomConfig(); htp_graph_opt_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; htp_graph_opt_config.optimizationOption.type = QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; htp_graph_opt_config.optimizationOption.floatValue = static_cast(htp_graph_finalization_opt_mode_); - QnnGraph_Config_t& graph_opt_config = configs_builder.PushGraphConfig(); + QnnGraph_Config_t& graph_opt_config = configs_builder.PushConfig(); graph_opt_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; graph_opt_config.customConfig = &htp_graph_opt_config; } if (vtcm_size_in_mb_ > 0) { - QnnHtpGraph_CustomConfig_t& htp_graph_opt_config_vtcm = configs_builder.PushHtpGraphCustomConfig(); + QnnHtpGraph_CustomConfig_t& htp_graph_opt_config_vtcm = configs_builder.PushCustomConfig(); htp_graph_opt_config_vtcm.option = QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE; htp_graph_opt_config_vtcm.vtcmSizeInMB = static_cast(vtcm_size_in_mb_); - QnnGraph_Config_t& graph_opt_config_vtcm = configs_builder.PushGraphConfig(); + QnnGraph_Config_t& graph_opt_config_vtcm = configs_builder.PushConfig(); graph_opt_config_vtcm.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; graph_opt_config_vtcm.customConfig = &htp_graph_opt_config_vtcm; } @@ -547,10 +599,11 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector qnn_model = std::make_unique(logger, qnn_backend_manager_.get()); - qnn::QnnGraphConfigsBuilder graph_configs_builder; + qnn::QnnConfigsBuilder graph_configs_builder(QNN_GRAPH_CONFIG_INIT, + QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT); InitQnnGraphConfigs(graph_configs_builder); - ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, graph_configs_builder.GetQnnGraphConfigs())); + ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, graph_configs_builder.GetQnnConfigs())); ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs()); ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index d4927f3fa505e..3f75be0efebcd 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -5,11 +5,12 @@ #include "core/framework/execution_provider.h" #include "core/framework/session_options.h" +#include "core/graph/model.h" #include #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" -#include "core/providers/qnn/builder/qnn_graph_configs_helper.h" -#include "core/graph/model.h" +#include "core/providers/qnn/builder/qnn_configs_helper.h" +#include "HTP/QnnHtpGraph.h" namespace onnxruntime { @@ -58,7 +59,7 @@ class QNNExecutionProvider : public IExecutionProvider { void ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string); - void InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_holder) const; + void InitQnnGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const; private: qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 7e0a811b7d07c..aca609cf94270 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -60,6 +60,10 @@ void usage() { "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n" "\t '0', '1', '2', '3', default is '0'.\n" + "\t [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" + "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n" + "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" + "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [Usage]: -e -i '| |' \n\n" "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" @@ -483,7 +487,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (supported_profiling_level.find(value) == supported_profiling_level.end()) { ORT_THROW("Supported profiling_level: off, basic, detailed"); } - } else if (key == "rpc_control_latency" || key == "vtcm_mb") { + } else if (key == "rpc_control_latency" || key == "vtcm_mb" || key == "soc_model" || key == "device_id") { // no validation } else if (key == "htp_performance_mode") { std::set supported_htp_perf_mode = {"burst", "balanced", "default", "high_performance", @@ -512,10 +516,20 @@ int real_main(int argc, char* argv[], Ort::Env& env) { std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str); } + } else if (key == "htp_arch") { + std::unordered_set supported_htp_archs = {"0", "68", "69", "73", "75"}; + if (supported_htp_archs.find(value) == supported_htp_archs.end()) { + std::ostringstream str_stream; + std::copy(supported_htp_archs.begin(), supported_htp_archs.end(), + std::ostream_iterator(str_stream, ",")); + std::string str = str_stream.str(); + ORT_THROW("Wrong value for htp_arch. select from: " + str); + } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', -'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"); +'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', +'soc_model', 'htp_arch', 'device_id'])"); } qnn_options[key] = value; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index ef04e2be8fd29..6c1d447c7b3a3 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -78,6 +78,10 @@ namespace perftest { "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n" "\t '0', '1', '2', '3', default is '0'.\n" + "\t [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" + "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n" + "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" + "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [Usage]: -e -i '| |'\n\n" "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index f8a012af5bb13..6854a2649060a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -343,7 +343,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (supported_profiling_level.find(value) == supported_profiling_level.end()) { ORT_THROW("Supported profiling_level: off, basic, detailed"); } - } else if (key == "rpc_control_latency" || key == "vtcm_mb") { + } else if (key == "rpc_control_latency" || key == "vtcm_mb" || key == "soc_model" || key == "device_id") { // no validation } else if (key == "htp_performance_mode") { std::set supported_htp_perf_mode = {"burst", "balanced", "default", "high_performance", @@ -372,10 +372,20 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (supported_qnn_context_priority.find(value) == supported_qnn_context_priority.end()) { ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high"); } + } else if (key == "htp_arch") { + std::unordered_set supported_htp_archs = {"0", "68", "69", "73", "75"}; + if (supported_htp_archs.find(value) == supported_htp_archs.end()) { + std::ostringstream str_stream; + std::copy(supported_htp_archs.begin(), supported_htp_archs.end(), + std::ostream_iterator(str_stream, ",")); + std::string str = str_stream.str(); + ORT_THROW("Wrong value for htp_arch. select from: " + str); + } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', -'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"); +'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model', +'htp_arch', 'device_id'])"); } qnn_options[key] = value; diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index bc40682cf87b7..c50b1002fa8c8 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -176,7 +176,10 @@ TEST(QnnEP, TestDisableCPUFallback_ConflictingConfig) { // types and shapes. static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bool enable_qnn_saver = false, std::string htp_graph_finalization_opt_mode = "", - std::string qnn_context_priority = "") { + std::string qnn_context_priority = "", + std::string soc_model = "", + std::string htp_arch = "", + std::string device_id = "") { Ort::SessionOptions so; // Ensure all type/shape inference warnings result in errors! @@ -205,6 +208,18 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bo options["qnn_context_priority"] = std::move(qnn_context_priority); } + if (!soc_model.empty()) { + options["soc_model"] = std::move(soc_model); + } + + if (!htp_arch.empty()) { + options["htp_arch"] = std::move(htp_arch); + } + + if (!device_id.empty()) { + options["device_id"] = std::move(device_id); + } + so.AppendExecutionProvider("QNN", options); Ort::Session session(*ort_env, ort_model_path, so); @@ -519,6 +534,45 @@ TEST_F(QnnHTPBackendTests, HTPGraphFinalizationOptimizationModes) { } } +// Test that models run with various SoC model values +TEST_F(QnnHTPBackendTests, HTPSocModels) { + constexpr std::array soc_models = { "", // No explicit SoC model specified + "0", // "Unknown" +#if defined(_M_ARM64) + "37" }; // SC8280X +#elif defined(__linux__) + "30" }; // SM8350 +#else + "" }; +#endif + + for (auto soc_model : soc_models) { + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", + true, // use_htp + false, // enable_qnn_saver + "", // htp_graph_finalization_opt_mode + "", // qnn_context_priority + soc_model); + } +} + +// Test that models run with various HTP architecture values (and set device_id) +TEST_F(QnnHTPBackendTests, HTPArchValues) { + constexpr std::array htp_archs = {"", // No explicit arch specified + "0", // "None" + "68"}; // v68 + for (auto htp_arch : htp_archs) { + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", + true, // use_htp + false, // enable_qnn_saver + "", // htp_graph_finalization_opt_mode + "", // qnn_context_priority + "", // soc_model + htp_arch, // htp_arch + "0"); // device_id + } +} + // Test that models run with high QNN context priority. TEST_F(QnnHTPBackendTests, QnnContextPriorityHigh) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", From 780acda7b4f044564e1f222901fd6a676aa05cbf Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Tue, 23 Jan 2024 06:02:56 +0800 Subject: [PATCH 06/45] Add Big models pipeline (#19222) ### Description 2 models are added in CI. Stabe diffusion Model stage is based on https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md LLama2 FP16 is based on https://github.com/microsoft/Llama-2-Onnx. 12G GPU memory is not enough, so I choose T4 to run it. ### Motivation and Context Add regular E2E test for big models. It will be triggered in main build, that is, it'll run after one PR is merged. More models will be added later. ### Test Runs ### https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1275191&view=results --- .../azure-pipelines/bigmodels-ci-pipeline.yml | 259 ++++++++++++++++++ 1 file changed, 259 insertions(+) create mode 100644 tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml new file mode 100644 index 0000000000000..ff2e7c0468a21 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -0,0 +1,259 @@ +# reference: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +parameters: +- name: specificArtifact + displayName: Use Specific Artifact + type: boolean + default: false +- name: BuildId + displayName: Specific Artifact's RunId + type: number + default: 0 + +resources: + repositories: + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + + - repository: LLaMa2Onnx + type: Github + endpoint: Microsoft + name: Microsoft/Llama-2-Onnx + ref: main + +variables: + - template: templates/common-variables.yml + - name: docker_base_image + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + - name: linux_trt_version + value: 8.6.1.6-1.cuda11.8 + +stages: +- stage: Build_Onnxruntime_Cuda + jobs: + - job: Linux_Build + timeoutInMinutes: 120 + variables: + skipComponentGovernanceDetection: true + CCACHE_DIR: $(Pipeline.Workspace)/ccache + workspace: + clean: all + pool: onnxruntime-Ubuntu2204-AMD-CPU + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: none + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=$(docker_base_image) + --build-arg TRT_VERSION=$(linux_trt_version) + --build-arg BUILD_UID=$( id -u ) + " + Repository: onnxruntimecuda11build + + - task: Cache@2 + inputs: + key: '"ccache" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' + path: $(CCACHE_DIR) + restoreKeys: | + "ccache" | "$(Build.SourceBranch)" + "ccache" + cacheHitVar: CACHE_RESTORED + displayName: Cach Task + + - script: | + sudo mkdir -p $(Pipeline.Workspace)/ccache + condition: ne(variables.CACHE_RESTORED, 'true') + displayName: Create Cache Dir + + - task: CmdLine@2 + inputs: + script: | + mkdir -p $HOME/.onnx + docker run -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + --volume $(Pipeline.Workspace)/ccache:/cache \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + -e CCACHE_DIR=/cache \ + onnxruntimecuda11build \ + /bin/bash -c " + set -ex; \ + env; \ + ccache -s; \ + /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ + --build_dir /build --cmake_generator Ninja \ + --config Release --update --build \ + --skip_submodule_sync \ + --build_shared_lib \ + --parallel \ + --build_wheel \ + --enable_onnx_tests --use_cuda --cuda_version=${{variables.common_cuda_version}} --cuda_home=/usr/local/cuda-${{variables.common_cuda_version}} --cudnn_home=/usr/local/cuda-${{variables.common_cuda_version}} \ + --enable_cuda_profiling --enable_cuda_nhwc_ops \ + --enable_pybind --build_java \ + --use_cache \ + --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=75;86' ; \ + ccache -sv; \ + ccache -z" + workingDirectory: $(Build.SourcesDirectory) + + - task: CmdLine@2 + inputs: + script: | + rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 + rm -f $(Build.BinariesDirectory)/Release/models + find $(Build.BinariesDirectory)/Release/_deps -mindepth 1 ! -regex '^$(Build.BinariesDirectory)/Release/_deps/onnx-src\(/.*\)?' -delete + cd $(Build.BinariesDirectory)/Release + find -executable -type f > $(Build.BinariesDirectory)/Release/perms.txt + + - script: | + set -ex + mkdir -p $(Agent.TempDirectory)/ort + cp $(Build.BinariesDirectory)/Release/dist/*.whl $(Agent.TempDirectory)/ort/ + displayName: 'Copy Wheels' + + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline Artifact' + inputs: + artifactName: 'drop-ort-linux-gpu' + targetPath: '$(Agent.TempDirectory)/ort' + + - template: templates/explicitly-defined-final-tasks.yml + +- stage: Stale_Diffusion + dependsOn: + - Build_Onnxruntime_Cuda + jobs: + - job: Stale_Diffusion + variables: + skipComponentGovernanceDetection: true + CCACHE_DIR: $(Pipeline.Workspace)/ccache + workspace: + clean: all + pool: onnxruntime-Linux-GPU-A10-12G + steps: + - checkout: self + clean: true + submodules: none + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Onnxruntime Artifact' + ArtifactName: 'drop-ort-linux-gpu' + TargetPath: '$(Build.BinariesDirectory)/Release' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - script: | + docker run --rm --gpus all -v $PWD:/workspace -v $(Build.BinariesDirectory)/Release:/Release nvcr.io/nvidia/pytorch:22.11-py3 \ + bash -c " + set -ex; \ + python3 --version; \ + python3 -m pip install --upgrade pip; \ + python3 -m pip install /Release/*.whl; \ + pushd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion; \ + python3 -m pip install -r requirements-cuda11.txt; \ + python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com; \ + echo Generate an image guided by a text prompt; \ + python3 demo_txt2img.py "astronaut riding a horse on mars"; \ + echo Generate an image with Stable Diffusion XL guided by a text prompt; \ + python3 demo_txt2img_xl.py 'starry night over Golden Gate Bridge by van gogh'; \ + python3 demo_txt2img_xl.py --enable-refiner 'starry night over Golden Gate Bridge by van gogh'; \ + echo Generate an image guided by a text prompt using LCM LoRA; \ + python3 demo_txt2img_xl.py --scheduler LCM --lora-weights latent-consistency/lcm-lora-sdxl --denoising-steps 4 "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"; \ + popd; \ + " + displayName: 'Run stable diffusion demo' + workingDirectory: $(Build.SourcesDirectory) + +- stage: Llama2_ONNX_FP16 + dependsOn: + - Build_Onnxruntime_Cuda + jobs: + - job: Llama2_ONNX_FP16 + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: onnxruntime-Linux-GPU-T4 + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: none + + - checkout: LLaMa2Onnx + clean: true + submodules: none + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Onnxruntime Artifact' + ArtifactName: 'drop-ort-linux-gpu' + TargetPath: '$(Build.BinariesDirectory)/ort-artifact/' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - task: DownloadPackage@1 + displayName: 'Download Llama2 model' + inputs: + packageType: upack + feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' + version: 1.0.0 + definition: '772ebce3-7e06-46d5-b3cc-82040ec4b2ce' + downloadPath: $(Agent.TempDirectory)/llama2_onnx_ft16 + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: onnxruntime/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 + Context: onnxruntime/tools/ci_build/github/linux/docker/ + ScriptName: onnxruntime/tools/ci_build/get_docker_image.py + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimeubi8packagestest + UpdateDepsTxt: false + + - script: | + docker run --rm --gpus all -v $(Build.SourcesDirectory)/Llama-2-Onnx:/workspace \ + -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ + -v $(Agent.TempDirectory)/llama2_onnx_ft16:/models \ + onnxruntimeubi8packagestest \ + bash -c " + set -ex; \ + python3 -m pip install --upgrade pip ; \ + python3 -m pip install /ort-artifact/*.whl ; \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ + python3 -m pip install sentencepiece ; \ + pushd /workspace ; \ + python3 MinimumExample/Example_ONNX_LlamaV2.py --onnx_file /models/ONNX/LlamaV2_7B_FT_float16.onnx \ + --embedding_file /models/embeddings.pth --tokenizer_path tokenizer.model --prompt 'What is the lightest element?' > /workspace/answer.txt ; \ + popd ; \ + " + displayName: 'Run Llama2 demo' + workingDirectory: $(Build.SourcesDirectory) + + - script: | + set -ex + real=$(cat $(Build.SourcesDirectory)/Llama-2-Onnx/answer.txt) + trim_actual=$(tr -dc '[[:print:]]' <<< "$real") + expected="The lightest element is hydrogen. Hydrogen is the lightest element on the periodic table, with an atomic mass of 1.00794 u (unified atomic mass units)." + [ "$expected" == "$trim_actual" ] && exit 0 || exit 1 + displayName: 'Check result' From 77da2ef278a4e77cca4cef4e5d72ed1ef46fcce3 Mon Sep 17 00:00:00 2001 From: snadampal <87143774+snadampal@users.noreply.github.com> Date: Mon, 22 Jan 2024 16:43:06 -0600 Subject: [PATCH 07/45] [aarch64] Add Sbgemm kernel to accelerate fp32 tensor matmul with bfloat16 (#17031) ### Description This PR adds SbgemmKernel for aarch64. This includes Sbegmm kernel to implement matrix multiplication with bfloat16 SIMD instructions (bfmmla) and MatMul operator changes to invoke the Sbgemm kernel. To enable Sbgemm kernel, set the following session option: "kOrtSessionOptionsGemmFastMathMode" The PR also adds new test cases for mlas and ort. ### Motivation and Context This is to improve MatMul performance on aarch64 platform. I have run the below benchmarking script (bert , roberta and gpt2 model inference) on AWS Graviton3 based c7g.4xl instance and observed 1.2x -1.76x performance improvement compared to sgemm (fp32) kernel performance. ``` cd onnxruntime/python/tools/transformers python3 benchmark.py ``` And the unit test precision results are matching to sgemm kernel results. `./build.sh --config RelWithDebInfo --build_shared_lib --parallel --compile_no_warning_as_error --skip_submodule_sync ` --- cmake/onnxruntime_mlas.cmake | 4 + .../onnxruntime_session_options_config_keys.h | 8 +- onnxruntime/core/common/cpuid_info.cc | 7 + onnxruntime/core/common/cpuid_info.h | 2 + onnxruntime/core/mlas/inc/mlas.h | 113 +++ .../core/mlas/lib/aarch64/SbgemmKernelNeon.S | 907 ++++++++++++++++++ onnxruntime/core/mlas/lib/mlasi.h | 25 + onnxruntime/core/mlas/lib/platform.cpp | 6 + onnxruntime/core/mlas/lib/sbgemm.h | 399 ++++++++ .../core/mlas/lib/sbgemm_kernel_neon.cpp | 362 +++++++ onnxruntime/core/providers/cpu/math/matmul.cc | 106 +- onnxruntime/core/providers/cpu/math/matmul.h | 15 + .../test/mlas/unittest/test_sbgemm.cpp | 141 +++ onnxruntime/test/mlas/unittest/test_sbgemm.h | 281 ++++++ .../qdq_transformer_fastmath_test.cc | 730 ++++++++++++++ .../cpu/math/matmul_fastmath_test.cc | 305 ++++++ onnxruntime/test/util/compare_ortvalue.cc | 80 ++ 17 files changed, 3473 insertions(+), 18 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S create mode 100644 onnxruntime/core/mlas/lib/sbgemm.h create mode 100644 onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_sbgemm.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_sbgemm.h create mode 100644 onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc create mode 100644 onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index f89d2150a6830..17de2aa4aaea6 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -355,19 +355,23 @@ else() ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S + ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 8fd51962bf087..b282438795eb5 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -249,4 +249,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p // Flag to specify whether to dump the EP context into the Onnx model. // "0": dump the EP context into separate file, keep the file name in the Onnx model. // "1": dump the EP context into the Onnx model. (default). -static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; \ No newline at end of file +static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; + +// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. +// Option values: +// - "0": Gemm FastMath mode is not enabled. [DEFAULT] +// - "1": Gemm FastMath mode is enabled. +static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index fcf9c2b03dea5..711fd595e90fd 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -30,6 +30,10 @@ #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #endif // ARM #endif // Linux @@ -148,6 +152,7 @@ void CPUIDInfo::ArmLinuxInit() { has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); @@ -177,6 +182,7 @@ void CPUIDInfo::ArmLinuxInit() { has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); #endif } @@ -278,6 +284,7 @@ void CPUIDInfo::ArmWindowsInit() { /* TODO: implement them when hw+sw is available for testing these features */ has_arm_neon_i8mm_ = false; has_arm_sve_i8mm_ = false; + has_arm_neon_bf16_ = false; } #endif /* (arm or arm64) and windows */ diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index a15c75104b83a..2f8041e39f680 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -30,6 +30,7 @@ class CPUIDInfo { bool HasArmNeonDot() const { return has_arm_neon_dot_; } bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } uint32_t GetCurrentCoreIdx() const; @@ -125,6 +126,7 @@ class CPUIDInfo { bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; #ifdef CPUIDINFO_ARCH_X86 diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index bdd4dba521eba..ce7838556fbf0 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1614,6 +1614,119 @@ MlasHalfGemmConvertPackB( void* PackedB ); +#if defined(__aarch64__) && defined(__linux__) +/** + * @brief Whether current CPU supports Bfloat16(bf16) acceleration. + */ +bool MLASCALL +MlasBf16AccelerationSupported(); + +/** + * @brief Interface for bf16 gemm post processors. + * + * Example implementation of this interface includes activations, + * conversion from single precision to precision, etc. + * + * SBGEMM is computed tile by tile. When a tile of result matrix + * is produced, the method Process() is called to process this tile. + * Parameters of this method describe the location and shape of the + * tile. + */ +class MLAS_SBGEMM_POSTPROCESSOR +{ + public: + virtual void Process(float*, /**< the address of matrix to process */ + size_t, /**< the start row index of matrix */ + size_t, /**< the start col index of matrix */ + size_t, /**< the element count per row to process */ + size_t, /**< the element count per col to process */ + size_t /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_SBGEMM_POSTPROCESSOR() {} +}; + +/** + * @brief bfloat16 precision activation functions, with optional sum tensor. + * Supplied sum tensor must be the same layout as the GEMM output tensor. + * And the supplied sum tensor will be added to the tensor before activation. + */ +class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR +{ + public: + MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr) + : Activation_(Activation), SumBuf_(SumBuf) + { + } + + void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc) + const override; + + private: + const MLAS_ACTIVATION& Activation_; + const float* SumBuf_; +}; + +/** + * @brief Data parameters for bfloat16 precision GEMM routine + * All except C are [in] parameters + */ +struct MLAS_SBGEMM_DATA_PARAMS { + const void* A = nullptr; /**< address of A */ + const void* B = nullptr; /**< address of B */ + const float* Bias = nullptr; /**< address of Bias, vector size N */ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; + bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ + bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ +}; + +/** + * @brief Bfloat16 precision Batched GEMM: C = A * B + Bias + * Either B can be either fp32 or bf16 + * + * Note: We only support uniform batching, so shapes and types of the + * input must be same across all parameter blocks. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return + */ +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr); + +/** + * @brief For bfloat16 precision GEMM, returns size of the + * packing buffer needed for right hand side + * @param[in] N Number of columns + * @param[in] K Number of rows + * @return size of the packing buffer, + * 0 if operation not supported + */ +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K); + +/** + * @brief For bfloat16 precision GEMM, convert the float matrix B + * to blfoat16 precision and pack it into a packing buffer + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix + */ +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); +#endif + /** * @brief Indirect Depthwise convolution for fp16 * @param Input Supplies the indirect buffer for NHWC input diff --git a/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S new file mode 100644 index 0000000000000..e424c30515e9f --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S @@ -0,0 +1,907 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + SbgemmKernelNeon.s + +Abstract: + + This module implements the kernels for the bfloat16 half precision matrix/matrix + multiply operation (SBGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the sbgemm kernel. d8-d15, x19-x30 need save +// + .equ .LMlasSbgemmKernel_backup_x19_x20, 0 + .equ .LMlasSbgemmKernel_backup_x21_x22, 16 + .equ .LMlasSbgemmKernel_backup_x23_x24, 32 + .equ .LMlasSbgemmKernel_backup_x25_x26, 48 + .equ .LMlasSbgemmKernel_backup_x27_x28, 64 + .equ .LMlasSbgemmKernel_backup_d8_d9, 80 + .equ .LMlasSbgemmKernel_backup_d10_d11, 96 + .equ .LMlasSbgemmKernel_backup_d12_d13, 112 + .equ .LMlasSbgemmKernel_backup_d14_d15, 128 + .equ .LMlasSbgemmKernel_SavedRegisters, 144 + .equ .LMlasSbgemmKernel_SavedRegisters_Neg, -144 + + +// +// ClearRowAccumulators +// +// Generates the code to clear the accumulators for a single row of the output +// block. +// + + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + mov v\Vec1Reg\().16b,v0.16b +.if \Columns\() > 2 + mov v\Vec2Reg\().16b,v1.16b +.endif +.if \Columns\() > 4 + mov v\Vec3Reg\().16b,v2.16b +.endif +.if \Columns\() > 6 + mov v\Vec4Reg\().16b,v3.16b +.endif + + .endm + +// +// InitBlockAccumulators +// +// Generates the code to init the accumulators for a single row of the output +// block. +// + + .macro InitBlockAccumulators Mode, Columns, Rows + + //check if the Bias != nullptr + cbz x8,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd + + ld1 {v14.4s},[x8],#16 // load Bias[0] + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast Bias + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + ld1 {v15.4s},[x8],#16 // load Bias[4] + dup v4.4s,v15.s[0] // broadcast Bias + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + b .L\Mode\().PopulateAccumulators\Columns\().x\Rows\() + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd: + eor v0.16b,v0.16b,v0.16b // No bias, reset regs + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b + +.L\Mode\().PopulateAccumulators\Columns\().x\Rows\(): + InitRowAccumulators \Columns\(),16,17,18,19 +.if \Rows\() > 2 + InitRowAccumulators \Columns\(),20,21,22,23 +.endif +.if \Rows\() > 4 + InitRowAccumulators \Columns\(),24,25,26,27 +.endif +.if \Rows\() > 6 + InitRowAccumulators \Columns\(),28,29,30,31 +.endif + + .endm + +// LoadMatrixAElementsBy8 +// +// Generates the code to load 4 or 8 elements from matrix A. +// + .macro LoadMatrixAElementsBy8 Rows + + ldr q8,[x0],#16 + bfcvtn v8.4h, v8.4s +.if \Rows\() > 1 + ldr q1,[x10],#16 + bfcvtn2 v8.8h, v1.4s +.endif + +.if \Rows\() > 2 + ldr q9,[x11],#16 + bfcvtn v9.4h, v9.4s +.endif +.if \Rows\() > 3 + ldr q1,[x12],#16 + bfcvtn2 v9.8h, v1.4s +.endif + +.if \Rows\() > 4 + ldr q10,[x20],#16 + bfcvtn v10.4h, v10.4s +.endif +.if \Rows\() > 5 + ldr q1,[x21],#16 + bfcvtn2 v10.8h, v1.4s +.endif + +.if \Rows\() > 6 + ldr q11,[x22],#16 + bfcvtn v11.4h, v11.4s +.endif +.if \Rows\() > 7 + ldr q1,[x23],#16 + bfcvtn2 v11.8h, v1.4s +.endif + + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + bfmmla v\Vec1Reg\().4s, \MatrixAReg\().8h, v4.8h +.if \Columns\() > 2 + bfmmla v\Vec2Reg\().4s, \MatrixAReg\().8h, v5.8h +.endif +.if \Columns\() > 4 + bfmmla v\Vec3Reg\().4s, \MatrixAReg\().8h, v6.8h +.endif +.if \Columns\() > 6 + bfmmla v\Vec4Reg\().4s, \MatrixAReg\().8h, v7.8h +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(),\Columns\(),\Rows\() + + add x10,x0,x6,lsl #2 // compute matrix A plus 1 row +.if \Rows\() > 2 + add x11,x10,x6,lsl #2 // compute matrix A plus 2 rows + add x12,x11,x6,lsl #2 // compute matrix A plus 3 rows +.endif +.if \Rows\() > 4 + add x20,x12,x6,lsl #2 // compute matrix A plus 4 rows + add x21,x20,x6,lsl #2 // compute matrix A plus 5 rows +.endif +.if \Rows\() > 6 + add x22,x21,x6,lsl #2 // compute matrix A plus 6 rows + add x23,x22,x6,lsl #2 // compute matrix A plus 7 rows +.endif + sub x9,x3,#4 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#4 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#4 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + fadd v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + fadd v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x26 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),8,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),4,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasSbgemmCopyPackB or MlasSbgemmTransposePackB. + + C (x2) - Supplies the address of matrix C. + + CountK (x3) - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda (x6) - Supplies the first dimension of matrix A. + + ldc (x7) - Supplies the first dimension of matrix C. + + Bias - Supplies the address of Bias Vector [1xn] + + +Return Value: + + Returns the number of rows handled. + +--*/ + .macro SbgemmKernelNeonFunction Mode + + FUNCTION_ENTRY MlasSbgemmKernel\Mode\() + + ldr x8, [sp, #0] //Bias vector + + stp x19, x20, [sp, #.LMlasSbgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + + add x13,x2,x7,lsl #2 // compute matrix C plus 1 row + add x14,x13,x7,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x7,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x7,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x7,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x7,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x7,lsl #2 // compute matrix C plus 7 rows + + mov x26,x0 // save matrix A +// +// Process 8 rows of the matrices. +// + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasSbgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + SbgemmKernelNeonFunction Zero + SbgemmKernelNeonFunction Add diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 7bb8b17031a84..624eb913d5c9e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -193,6 +193,8 @@ class MLASCPUIDInfo bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + private: MLASCPUIDInfo(); @@ -200,6 +202,7 @@ class MLASCPUIDInfo bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -357,6 +360,20 @@ size_t #else +#if defined(__aarch64__) && defined(__linux__) +typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( + const float* A, + const bfloat16_t* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + const float* Bias +); +#endif + typedef size_t (MLASCALL MLAS_GEMM_FLOAT_KERNEL)( @@ -727,6 +744,10 @@ extern "C" { #else MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; +#if defined(__aarch64__) && defined(__linux__) + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; +#endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; #endif @@ -856,6 +877,10 @@ extern "C" { #define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 +#if defined(__aarch64__) && defined(__linux__) +#define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#endif + // // Single-threaded single precision matrix/matrix multiply operation. // diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 1310ed3f384b9..de092f7d1d350 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -60,6 +60,10 @@ MLASCPUIDInfo::MLASCPUIDInfo() #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() { @@ -70,6 +74,8 @@ MLASCPUIDInfo::MLASCPUIDInfo() has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); } #endif diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h new file mode 100644 index 0000000000000..de7fd72fad45a --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -0,0 +1,399 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm.h + +Abstract: + + This module defines the set of template functions to implement bfloat16 + precision matrix/matrix multiply operation (SBGEMM). + + To implement a new kernel, template functions below need to be specialized: + MlasSBGemmConvertPackB + MlasSBGemmPackedBOffset + MlasSBGemmPackedBLeadingDim + MlasSBGemmKernel + + MlasSBGemmOperation is the shared kernel driver. + + A kernel type should define the following constants: + bool PackNeeded; Whether B needs to be packed + size_t KernelMaxM; Max # rows the vectorized kernel can process + size_t PackedK; Packed alignment on the K dim (power of 2) + size_t PackedN; Packed alignment on the n dim (power of 2) + MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#pragma once + +#include +#include + +#include "mlasi.h" + +/** + * @brief Define the default striding parameters for + * the bfloat16 precision gemm operation + */ +struct MLAS_SBGEMM_STRIDES { + size_t M; + size_t N; + size_t K; +}; + +/** + * @brief Convert fp32 matrix B to bf16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B in fp32 + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack + */ +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Find the location of PackedB[StartK, StartN] + * + * @tparam KernelType + * @param PackedB + * @param DimN Total columns of the packing buffer + * @param DimK Total rows of the packing buffer + * @param StartN + * @param StartK + * @return Address of PackedB[StartK, StartN] + */ +template +MLAS_FORCEINLINE const bfloat16_t* +MlasSBGemmPackedBOffset( + const bfloat16_t* PackedB, size_t DimN, size_t DimK, size_t StartN, size_t StartK +) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return PackedB + StartK * DimN + StartN; +} + +/** + * @brief leading dimension of the packed B buffer + * Related to how B is packed + * @tparam KernelType + * @param DimN + * @param DimK + * @return leading dimension of the packed B buffer + */ +template +MLAS_FORCEINLINE size_t +MlasSBGemmPackedBLeadingDim(size_t DimN, size_t DimK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return DimN; +} + +template +void +MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK, const float* A, const size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode); + +template +MLAS_FORCEINLINE void +MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t PackedStrideN = Strides.N; + size_t PackedStrideK = Strides.K; + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + const size_t SliceStartN = RangeStartN + n; + CountN = std::min(RangeCountN - n, PackedStrideN); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + bool ZeroMode = (k == 0); + CountK = std::min(K - k, PackedStrideK); + + const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; + float* c = C + n; + const float* pbias = ((nullptr == Bias) ? nullptr : Bias + RangeStartN + n); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, pb, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor) + ->Process(C + n, M, SliceStartN, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + // + // Compute the strides to step through slices of the input matrices. + // + // Expand the N stride if K is small or expand the K stride if N is small + // for better utilization of the B panel. Avoid changing the K stride if + // the A panel needs to be used for transposing. + // + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t StrideN = Strides.N; + size_t StrideK = Strides.K; + + if (N >= K) { + while (StrideK / 2 >= K) { + StrideN *= 2; + StrideK /= 2; + } + } else { + while (StrideN > 16 && StrideN / 2 >= N) { + StrideK *= 2; + StrideN /= 2; + } + } + + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(bfloat16_t)); + MlasThreadedBufAlloc(packBSize); + uint8_t* p = ThreadedBufHolder.get(); + auto* PanelB = reinterpret_cast(p); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < N; n += CountN) { + CountN = std::min(N - n, StrideN); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, StrideK); + + // + // Copy a panel of matrix B to a local packed buffer. + // + MlasSBGemmConvertPackB(PanelB, B + n + k * ldb, ldb, CountN, CountK); + + auto* c = C + n; + const float* pbias = + ((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart + + bool ZeroMode = (k == 0); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor)->Process(C + n, M, N, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId) +{ + const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; + const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; + + // + // Partition the operation along the M dimension. + // + size_t RangeStartM; + size_t RangeCountM; + + MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); + + // + // Partition the operation along the N dimension. + // + size_t RangeStartN; + size_t RangeCountN; + + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); + + RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + RangeCountN = std::min(N - RangeStartN, RangeCountN); + + // + // Dispatch the partitioned operation. + // + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const float* A = (const float*)DataParams->A + RangeStartM * lda; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const float* bias = DataParams->Bias; + + if (!DataParams->BIsfp32) { + MlasSBGemmPackedOperation( + RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A, + lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor + ); + } else { + const size_t ldb = DataParams->ldb; + const float* B = (const float*)DataParams->B + RangeStartN; + MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor); + } +} + +// +// dispatch structure. +// +typedef void(MLAS_SBGEMM_OPERATION)(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId); + +typedef void(MLAS_SBGEMM_CONVERTPACKB_ROUTINE)( + bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Hardware dependent dispatch for half precision GEMM + */ +struct MLAS_SBGEMM_DISPATCH { + MLAS_SBGEMM_OPERATION* Operation; /**< HalfGemm driver */ + MLAS_SBGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ + size_t PackedK; + size_t PackedN; + size_t StrideM; + size_t BufOverRead; +}; + +extern const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon; + +MLAS_FORCEINLINE +const MLAS_SBGEMM_DISPATCH* +MlasSBGemmGetDispatch() +{ +#if defined(MLAS_TARGET_ARM64) + return &MlasSBGemmDispatchNeon; +#else + std::cerr << "SBGemm Kernel is supported only on ARM64 platform."; + exit(1); +#endif +} + +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K) +{ + // + // Compute the number of bytes required to hold the packed buffer. + // + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return 0; + + const auto padding = dispatch->BufOverRead; + const auto PackedK = dispatch->PackedK; + const auto PackedN = dispatch->PackedN; + + const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); + const size_t AlignedN = (N + PackedN - 1) & ~(PackedN - 1); + const size_t BytesRequired = AlignedN * AlignedK * sizeof(bfloat16_t) + padding; + const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); + const size_t AlignedBytesRequired = + (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + + return AlignedBytesRequired; +} + +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + dispatch->ConvertPackBRoutine((bfloat16_t*)PackedB, B, ldb, N, K); +} + +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool) +{ + const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + MLAS_SBGEMM_OPERATION* operation = dispatch->Operation; + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. + // + + const double Complexity = double(M) * double(N) * double(K); + + ptrdiff_t TargetThreadCount; + + if (Complexity < double(MLAS_SBGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { + TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; + } else { + TargetThreadCount = GetMlasPlatform().MaximumThreadCount; + } + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + // + // Segment the operation across multiple threads. + // + // N.B. Currently, the operation is segmented as a 1D partition, which + // works okay for operations involving skinny matrices. + // + ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchN - 1) / BatchN; + ptrdiff_t ThreadCountM; + ptrdiff_t ThreadCountN; + + if (N > M) { + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); + } + + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; + + } else { + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); + } + + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; + } + + MlasTrySimpleParallel( + ThreadPool, ThreadsPerGemm * static_cast(BatchN), [=](ptrdiff_t tid) { + ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + operation(ThreadCountM, ThreadCountN, M, N, K, &(Data[GemmIdx]), ThreadIdx); + } + ); +} +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..a6a73996c548b --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp @@ -0,0 +1,362 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm_kernel_neon.cpp + +Abstract: + + This module implements bfloat16 precision GEMM kernel for neon. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#include "arm_neon.h" +#include "mlasi.h" +#include "sbgemm.h" + +struct MLAS_SBGEMM_KERNEL_NEON { + static constexpr bool PackNeeded = true; + static constexpr size_t KernelMaxM = 8; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 4; + static constexpr size_t PackedN = MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + static constexpr MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; // M:N:K +}; + +bool MLASCALL +MlasBf16AccelerationSupported() +{ +#if defined(MLAS_TARGET_ARM64) + return MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_BF16(); +#else + return false; +#endif +} + +/* + This routine converts fp32 to bf16 and copies elements from the source + matrix to the destination packed buffer. + + 4x2 elements from the source matrix are unrolled to be physically + contiguous for better locality inside the SBGEMM kernels. The remaining + rows and columns are padded to 4 and 2 alignment. +*/ +MLAS_FORCEINLINE +void +MlasSBGemmConvertCopyPackB(bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK) +{ + // + // Copy data from matrix B into the destination buffer 4x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const float* b = B; + int y = static_cast(CountK); + + while (y > 0) { + MLAS_FLOAT32X4 t0_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t0_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_h = MlasZeroFloat32x4(); + + if (y >= 4) { + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + t3_l = MlasLoadFloat32x4(&b[ldb * 3]); + t3_h = MlasLoadFloat32x4(&b[ldb * 3 + 4]); + } else { + switch (y) { + case 3: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + break; + case 2: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + break; + case 1: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + break; + } + } + + float32x4x2_t z0_l = vzipq_f32(t0_l, t2_l); + float32x4x2_t z1_l = vzipq_f32(t1_l, t3_l); + float32x4x2_t o0_l = vzipq_f32(z0_l.val[0], z1_l.val[0]); + float32x4x2_t o1_l = vzipq_f32(z0_l.val[1], z1_l.val[1]); + t0_l = o0_l.val[0]; + t1_l = o0_l.val[1]; + t2_l = o1_l.val[0]; + t3_l = o1_l.val[1]; + + bfloat16x8_t t0t1_l_4h = vcvtq_low_bf16_f32(t0_l); + bfloat16x8_t t0t1_l_8h = vcvtq_high_bf16_f32(t0t1_l_4h, t1_l); + + bfloat16x8_t t2t3_l_4h = vcvtq_low_bf16_f32(t2_l); + bfloat16x8_t t2t3_l_8h = vcvtq_high_bf16_f32(t2t3_l_4h, t3_l); + + vst1q_bf16(&D[0], t0t1_l_8h); + vst1q_bf16(&D[8], t2t3_l_8h); + + float32x4x2_t z0_h = vzipq_f32(t0_h, t2_h); + float32x4x2_t z1_h = vzipq_f32(t1_h, t3_h); + float32x4x2_t o0_h = vzipq_f32(z0_h.val[0], z1_h.val[0]); + float32x4x2_t o1_h = vzipq_f32(z0_h.val[1], z1_h.val[1]); + t0_h = o0_h.val[0]; + t1_h = o0_h.val[1]; + t2_h = o1_h.val[0]; + t3_h = o1_h.val[1]; + + bfloat16x8_t t0t1_h_4h = vcvtq_low_bf16_f32(t0_h); + bfloat16x8_t t0t1_h_8h = vcvtq_high_bf16_f32(t0t1_h_4h, t1_h); + + bfloat16x8_t t2t3_h_4h = vcvtq_low_bf16_f32(t2_h); + bfloat16x8_t t2t3_h_8h = vcvtq_high_bf16_f32(t2t3_h_4h, t3_h); + + vst1q_bf16(&D[16], t0t1_h_8h); + vst1q_bf16(&D[24], t2t3_h_8h); + + D += 32; + b += ldb * 4; + y -= 4; + }; + B += 8; + CountN -= 8; + } + + // + // Special case the handling of the remaining columns less than 8 elements + // wide. + // + if (CountN > 0) { + int y = static_cast(CountK); + while (y > 0) { + const float* b = B; + size_t b_inc = 0; + if ((CountN & 4) != 0) { + MLAS_FLOAT32X4 t0 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3 = MlasZeroFloat32x4(); + if (y >= 4) { + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + t3 = MlasLoadFloat32x4(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + break; + case 2: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + break; + case 1: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + break; + } + } + + float32x4x2_t z0 = vzipq_f32(t0, t2); + float32x4x2_t z1 = vzipq_f32(t1, t3); + float32x4x2_t o0 = vzipq_f32(z0.val[0], z1.val[0]); + float32x4x2_t o1 = vzipq_f32(z0.val[1], z1.val[1]); + + t0 = o0.val[0]; + t1 = o0.val[1]; + t2 = o1.val[0]; + t3 = o1.val[1]; + + bfloat16x8_t t0t1_4h = vcvtq_low_bf16_f32(t0); + bfloat16x8_t t0t1_8h = vcvtq_high_bf16_f32(t0t1_4h, t1); + + bfloat16x8_t t2t3_4h = vcvtq_low_bf16_f32(t2); + bfloat16x8_t t2t3_8h = vcvtq_high_bf16_f32(t2t3_4h, t3); + + vst1q_bf16(&D[0], t0t1_8h); + vst1q_bf16(&D[8], t2t3_8h); + + D += 16; + b += 4; + b_inc += 4; + } + + if ((CountN & 2) != 0) { + float32x2_t t0 = {0x0, 0x0}; + float32x2_t t1 = {0x0, 0x0}; + float32x2_t t2 = {0x0, 0x0}; + float32x2_t t3 = {0x0, 0x0}; + + if (y >= 4) { + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + t3 = vld1_f32(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + break; + case 2: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + break; + case 1: + t0 = vld1_f32(&b[ldb * 0]); + break; + } + } + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 2; + b_inc += 2; + } + if ((CountN & 1) != 0) { + float a = 0.0f; + float b = 0.0f; + float c = 0.0f; + float d = 0.0f; + + if (y >= 4) { + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + d = *(float*)(&B[ldb * 3 + b_inc]); + } else { + switch (y) { + case 3: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + break; + case 2: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + break; + case 1: + a = *(float*)(&B[ldb * 0 + b_inc]); + break; + } + } + + float32x2_t t0 = {a, 0x0}; + float32x2_t t1 = {b, 0x0}; + float32x2_t t2 = {c, 0x0}; + float32x2_t t3 = {d, 0x0}; + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 1; + b_inc += 1; + } + B += 4 * ldb; + y -= 4; + } + } +} + +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + const auto PackedN = dispatch->PackedN; + + const size_t AlignedN = (CountN + PackedN - 1) & ~(PackedN - 1); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t K_block_size; + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + + for (size_t k = 0; k < CountK; k += K_block_size) { + K_block_size = std::min(CountK - k, Strides.K); + + MlasSBGemmConvertCopyPackB((bfloat16_t*)PackedB, B + k * ldb, ldb, CountN, K_block_size); + PackedB = (bfloat16_t*)PackedB + AlignedN * K_block_size; + } +} + +template <> +MLAS_FORCEINLINE void +MlasSBGemmKernel(size_t CountM, size_t CountN, size_t CountK, const float* A, size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode) +{ + while (CountM > 0) { + size_t RowsHandled; + if (ZeroMode) { + RowsHandled = MlasSbgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } else { + RowsHandled = MlasSbgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } + C += ldc * RowsHandled; + A += lda * RowsHandled; + CountM -= RowsHandled; + } +} + +const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon = { + MlasSBGemmOperation, + MlasSBGemmConvertPackB, + MLAS_SBGEMM_KERNEL_NEON::PackedK, + MLAS_SBGEMM_KERNEL_NEON::PackedN, + MLAS_SBGEMM_KERNEL_NEON::KernelMaxM, + 32 // kernel may read beyond buffer end by 32 bytes +}; +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index ec395cf018f5e..583ee759cc2e6 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -6,7 +6,6 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" -#include "core/mlas/inc/mlas.h" namespace onnxruntime { @@ -125,6 +124,44 @@ Status MatMul::Compute(OpKernelContext* ctx) const { return Status::OK(); } +#if defined(__aarch64__) && defined(__linux__) +bool GemmPackBBfloat16(AllocatorPtr& alloc, + const Tensor& tensor_b, + bool trans_b, + IAllocatorUniquePtr& packed_b, + size_t& packed_b_size, + TensorShape& b_shape) { + // Only handle the common case of a 2D weight matrix. Additional matrices + // could be handled by stacking the packed buffers. + if (tensor_b.Shape().NumDimensions() != 2) { + return false; + } + + b_shape = tensor_b.Shape(); + + const size_t K = trans_b ? static_cast(b_shape[1]) : static_cast(b_shape[0]); + const size_t N = trans_b ? static_cast(b_shape[0]) : static_cast(b_shape[1]); + + packed_b_size = MlasSBGemmPackBSize(N, K); + if (packed_b_size == 0) { + return false; + } + + packed_b = IAllocator::MakeUniquePtr(alloc, packed_b_size, true); + auto* packed_b_data = packed_b.get(); + + // Initialize memory to 0 as there could be some padding associated with pre-packed + // buffer memory and we don not want it uninitialized and generate different hashes + // if and when we try to cache this pre-packed buffer for sharing between sessions. + memset(packed_b_data, 0, packed_b_size); + MlasSBGemmConvertPackB(N, + K, + tensor_b.Data(), + trans_b ? K : N, + packed_b_data); + return true; +} +#endif Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, @@ -134,7 +171,24 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc // only pack Matrix B if (input_idx == 1) { size_t packed_b_size; - is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); +#if defined(__aarch64__) && defined(__linux__) + size_t dim1 = 0; + size_t dim2 = 0; + TensorShape b_shape = tensor.Shape(); + + if (b_shape.NumDimensions() == 2) { + dim1 = static_cast(b_shape[0]); + dim2 = static_cast(b_shape[1]); + } + + if (use_fastmath_mode_ && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) { + is_packed = GemmPackBBfloat16(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + } else +#endif + { + is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + } + bool share_prepacked_weights = (prepacked_weights != nullptr); if (is_packed && share_prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); @@ -186,22 +240,40 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(trans_a); const size_t ldb = helper.Ldb(trans_b); - - std::vector data(max_len); - for (size_t i = 0; i < max_len; i++) { - data[i].BIsPacked = bool(packed_b_); - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i]; - data[i].ldb = ldb; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; - data[i].alpha = alpha_attr_; - data[i].beta = 0.0f; +#if defined(__aarch64__) && defined(__linux__) + if (use_fastmath_mode_ && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) { + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsfp32 = !(bool(packed_b_)); + data[i].AIsfp32 = true; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = data[i].BIsfp32 ? b_data + helper.RightOffsets()[i] : (float*)packed_b_.get(); + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].Bias = nullptr; + data[i].OutputProcessor = nullptr; + } + MlasSBGemmBatch(M, N, K, max_len, data.data(), thread_pool); + } else +#endif + { + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = bool(packed_b_); + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = alpha_attr_; + data[i].beta = 0.0f; + } + MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, + M, N, K, data.data(), max_len, thread_pool); } - MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, - M, N, K, data.data(), max_len, thread_pool); - return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index b960fa4fb0587..b9bbe36583879 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -4,6 +4,8 @@ #pragma once #include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas.h" +#include "core/session/onnxruntime_session_options_config_keys.h" namespace onnxruntime { @@ -27,6 +29,11 @@ class MatMul final : public OpKernel { info.GetAttrOrDefault("transBatchB", &trans_batch_b_attr, 0); trans_batch_a_ = trans_batch_a_attr != 0; trans_batch_b_ = trans_batch_b_attr != 0; + +#if defined(__aarch64__) && defined(__linux__) + auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16); + use_fastmath_mode_ = (config_ops == "1") && MlasBf16AccelerationSupported(); +#endif } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, @@ -48,6 +55,14 @@ class MatMul final : public OpKernel { int64_t trans_b_attr_; bool trans_batch_a_; bool trans_batch_b_; + +#if defined(__aarch64__) && defined(__linux__) + // fastmath mode state + bool use_fastmath_mode_; + // sbgemm kernel is implemented as 8x8 blocks with weights pre-packed to 4 blocks of 4x2 + // so a minimum of 32 elements is defined to outweigh the additional prepacking overhead + const size_t kFastMathModeKernelsizeThreshold = 32; +#endif }; } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp new file mode 100644 index 0000000000000..941de8f05061f --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp @@ -0,0 +1,141 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + test_sbgemm.cpp + +Abstract: + + Tests for MLAS bf16 precision GEMM. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#include "test_sbgemm.h" + +// +// Short Execute() test helper to register each test seperately by all parameters. +// +template +class SBGemmShortExecuteTest : public MlasTestFixture> { + public: + explicit SBGemmShortExecuteTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) + : M_(M), N_(N), K_(K), Batch_(Batch), hasBias_(hasBias) {} + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test(M_, N_, K_, Batch_, hasBias_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) { + std::stringstream ss; + ss << "Batch" << Batch << "/M" << M << "xN" << N << "xK" << K << "/" + << "hasBias" << hasBias; + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSBGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture>* { + return new SBGemmShortExecuteTest( + M, N, K, Batch, hasBias); + }); + + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t test_registered = 0; + for (size_t b = 1; b < 16; b++) { + test_registered += RegisterSingleTest(b, b, b, 1, false); + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 16; b <= 256; b <<= 1) { + test_registered += RegisterSingleTest(b, b, b, 1, false); + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 256; b < 320; b += 32) { + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 1; b < 96; b++) { + test_registered += RegisterSingleTest(1, b, 32, 1, false); + test_registered += RegisterSingleTest(1, 32, b, 1, true); + test_registered += RegisterSingleTest(1, b, b, 1, false); + if (!Packed) { + test_registered += RegisterSingleTest(1, b, 32, 3, true); + test_registered += RegisterSingleTest(1, 32, b, 5, false); + } + } + // TODO: check why the cosine similary is < 0.99 for this shape alone + // test_registered += RegisterSingleTest(43, 500, 401, 1, true); + test_registered += RegisterSingleTest(1001, 1027, 1031, 1, false); + if (!Packed) { + test_registered += RegisterSingleTest(43, 500, 401, 5, true); + test_registered += RegisterSingleTest(1000, 1029, 1030, 3, false); + } + + return test_registered; + } + + private: + size_t M_, N_, K_, Batch_; + bool hasBias_; +}; + +static size_t SBGemmRegistLongExecute() { + size_t count = 0; + + count += MlasLongExecuteTests>::RegisterLongExecute(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + + if (GetMlasThreadPool() != nullptr) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + } + + return count; +} + +static size_t SBGemmRegistShortExecute() { + size_t count = 0; + + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + } + + if (GetMlasThreadPool() != nullptr) { + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + } + } + + return count; +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + if (!MlasBf16AccelerationSupported()) { + return false; + } + + if (is_short_execute) { + return SBGemmRegistShortExecute() > 0; + } + return SBGemmRegistLongExecute() > 0; +}); +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.h b/onnxruntime/test/mlas/unittest/test_sbgemm.h new file mode 100644 index 0000000000000..13701e2e3de46 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.h @@ -0,0 +1,281 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + test_sbgemm.h + +Abstract: + + Tests for MLAS bf16 precision GEMM. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#pragma once + +#include "test_util.h" + +template +void SmallFloatFill(T* start, size_t size) { + constexpr float MinimumFillValue = -11.0f; + auto FillAddress = start; + size_t offset = size % 23; + + for (size_t i = 0; i < size; i++) { + offset = (offset + 21) % 23; + *FillAddress++ = T((MinimumFillValue + offset) / 16.0f); + } +} + +float cosine_similarity(const float* A, const float* B, size_t Vector_Length) { + float dot = 0.0, denom_a = 0.0, denom_b = 0.0; + for (size_t i = 0u; i < Vector_Length; ++i) { + dot += A[i] * B[i]; + denom_a += A[i] * A[i]; + denom_b += B[i] * B[i]; + } + return dot / (sqrt(denom_a) * sqrt(denom_b)); +} + +/** + * @brief Test class for bf16 precision GEMM + * @tparam AType Data type of A matrix, need to be float + * @tparam BType Data type of b matrix, can be either float or prepacked bf16 + */ +template +class MlasSBGemmTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferBPacked; + MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferBias; + MatrixGuardBuffer BufferC; + MatrixGuardBuffer BufferCReference; + MatrixGuardBuffer BufferFloatC; + MLAS_THREADPOOL* threadpool_; + + void* PackB(size_t N, size_t K, const BType* B, size_t ldb) { + size_t PackedBSize = MlasSBGemmPackBSize(N, K); + if (PackedBSize == 0) { + return nullptr; + } + void* PackedB = BufferBPacked.GetBuffer(PackedBSize); + if (std::is_same::value) { + MlasSBGemmConvertPackB(N, K, (const float*)B, ldb, PackedB); + } else { + } + return PackedB; + } + + void CallSBGemm(size_t M, + size_t N, + size_t K, + size_t BatchSize, + const float* A, + size_t lda, + const BType* B, + size_t ldb, + const float* Bias, + float* C, + size_t ldc) { + std::vector GemmParameters(BatchSize); + + for (size_t i = 0; i < GemmParameters.size(); i++) { + auto& params = GemmParameters[i]; + params.A = A + (M * lda * i); + params.lda = lda; + if (nullptr != Bias) { + params.Bias = reinterpret_cast(Bias + N * i); + } else { + params.Bias = nullptr; + } + params.C = reinterpret_cast(C + (M * ldc * i)); + params.ldc = ldc; + params.AIsfp32 = true; + params.BIsfp32 = true; + + if (Packed) { + ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; + params.B = PackB(N, K, B, ldb); + params.ldb = 0; + params.BIsfp32 = false; + } else { + params.B = B + (K * N * i); + params.ldb = ldb; + } + } + + MlasSBGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_); + } + + void ReferenceSgemm(size_t M, + size_t N, + size_t K, + size_t BatchSize, + const AType* A, + const BType* B, + const float* Bias, + float* C) { + constexpr size_t KStride = 256; + + for (size_t batch = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const AType* a = A + M * K * batch + m * K; + const BType* b = B + K * N * batch + n; + float* c = C + (M * N * batch) + (m * N) + n; + + for (size_t k = 0; k < K; k += KStride) { + float sum = 0.0f; + if (k == 0 && Bias != nullptr) { + sum = float(Bias[n]); + } + for (size_t kk = 0; kk < std::min(KStride, K - k); kk++) { + float down(float(*b) * float(*a) + sum); + sum = float(down); + b += N; + a += 1; + } + if (k == 0) { + *c = sum; + } else { + float d(sum + *c); + *c = float(d); + } + } + } + } + if (Bias) { + Bias += N; + } + } + } + + public: + MlasSBGemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} + + void Test(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { + AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + AType Atail[16]; + std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); + + BType* B = BufferB.GetFilledBuffer(N * K * BatchSize + 16, SmallFloatFill); + BType Btail[16]; + std::memcpy(Btail, B + N * K * BatchSize, 16 * sizeof(BType)); + + float BiasTail[16]; + const float* Bias = nullptr; + if (withBias) { + Bias = BufferBias.GetFilledBuffer(N * BatchSize + 16, SmallFloatFill); + std::memcpy(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)); + } + + float* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill); + float* CReference = BufferCReference.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + this->CallSBGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N); + ReferenceSgemm(M, N, K, BatchSize, A, B, Bias, CReference); + const float cosine_similarity_threshold = 0.98; + + for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + if (!(CloseEnough(float(C[f]), CReference[f]))) { + float cos_sim = cosine_similarity(C, CReference, (BatchSize * M * N)); + if (abs(cos_sim) < cosine_similarity_threshold) { + ASSERT_TRUE(false) << "cosine similarity check failed" << cos_sim; + } else { + break; + } + } + } + } + } + + ASSERT_EQ(std::memcmp(Atail, A + K * M * BatchSize, 16 * sizeof(AType)), 0) << "Matrix A buffer overwritten!"; + ASSERT_EQ(std::memcmp(Btail, B + N * K * BatchSize, 16 * sizeof(BType)), 0) << "Matrix B buffer overwritten!"; + if (withBias) { + ASSERT_EQ(std::memcmp(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)), 0) << "Bias buffer overwritten!"; + } + } + + private: + public: + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("SBGemmFP") + + (std::is_same::value ? "32" : "16") + + (std::is_same::value ? "32" : "16") + + (Packed ? "_Packed" : "_NoPack") + + (Threaded ? "_Threaded" : "_SingleThread"); + return suite_name.c_str(); + } + + void ExecuteLong(void) override { + for (size_t M = 16; M < 160; M += 32) { + for (size_t N = 16; N < 160; N += 32) { + static const size_t ks[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320}; + for (size_t k = 0; k < _countof(ks); k++) { + size_t K = ks[k]; + + Test(M, N, K, 1, false); + Test(M, N, K, 1, true); + Test(M + 1, N, K, 1, false); + Test(M, N + 1, K, 1, true); + Test(M + 1, N + 1, K, 1, false); + Test(M + 3, N + 2, K, 1, true); + Test(M + 4, N, K, 1, false); + Test(M, N + 4, K, 1, true); + Test(M + 4, N + 4, K, 1, false); + Test(M + 3, N + 7, K, 1, true); + Test(M + 8, N, K, 1, false); + Test(M, N + 8, K, 1, true); + Test(M + 12, N + 12, K, 1, false); + Test(M + 13, N, K, 1, true); + Test(M, N + 15, K, 1, false); + Test(M + 15, N + 15, K, 1, false); + if (!Packed) { + Test(M, N, K, 7, false); + Test(M + 3, N, K, 8, true); + Test(M, N + 1, K, 9, false); + Test(M + 12, N, K, 10, true); + Test(M, N + 15, K, 11, false); + Test(M + 15, N + 15, K, 12, true); + } + } + } + printf("M %zd\n", M); + } + + for (size_t M = 1; M < 160; M++) { + for (size_t N = 1; N < 160; N++) { + for (size_t K = 1; K < 160; K++) { + Test(M, N, K, 1, true); + } + } + printf("M %zd\n", M); + } + + for (size_t M = 160; M < 320; M += 24) { + for (size_t N = 112; N < 320; N += 24) { + for (size_t K = 1; K < 16; K++) { + Test(M, N, K, 1, true); + } + for (size_t K = 16; K < 160; K += 32) { + Test(M, N, K, 1, false); + } + } + printf("M %zd\n", M); + } + } +}; + +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc new file mode 100644 index 0000000000000..ec9f78da14a75 --- /dev/null +++ b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc @@ -0,0 +1,730 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// Licensed under the MIT License. + +#include "core/framework/compute_capability.h" +#include "core/graph/model.h" +#include "core/graph/onnx_protobuf.h" +#include "core/mlas/inc/mlas.h" +#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/optimizer/utils.h" +#include "core/providers/partitioning_utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/environment.h" +#include "core/session/inference_session.h" + +#include "test/compare_ortvalue.h" +#include "test/test_environment.h" +#include "test/framework/test_utils.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +#include "gtest/gtest.h" +#include "graph_transform_test_builder.h" + +#include "qdq_test_utils.h" + +#if defined(__aarch64__) && defined(__linux__) && !defined(DISABLE_CONTRIB_OPS) + +struct QDQOpKeys { + const char* quantize_linear; + const char* dequantize_linear; +}; + +constexpr QDQOpKeys GetQDQOpKeys(bool use_contrib_qdq) { + if (use_contrib_qdq) { + return {"com.microsoft.QuantizeLinear", "com.microsoft.DequantizeLinear"}; + } + return {"QuantizeLinear", "DequantizeLinear"}; +} + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +TEST(QDQTransformerTests, DQ_S8_to_U8_FastMath) { + auto test_case = [](bool use_contrib_qdq) { + const std::vector& input_shape = {19, 37}; + const std::vector& weights_shape = {37, 23}; + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input_shape, -1.f, 1.f); + + // Use full range weight values to expose u8s8 overflow problems + auto* weight = builder.MakeInitializer(weights_shape, -128, 127); + auto* output_arg = builder.MakeOutput(); + + // add QDQ activation + typedef std::numeric_limits Input1Limits; + auto* dq1_output = AddQDQNodePair(builder, input1_arg, .039f, + (int8_t)((Input1Limits::max() + Input1Limits::min()) / 2 + 1), + use_contrib_qdq); + + // add DQ weight + auto* dq_w_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(weight, .003f, -10, dq_w_output, use_contrib_qdq); + + builder.AddNode("MatMul", {dq1_output, dq_w_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, /*using NAN as a magic number to trigger cosine similarity*/ + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options_disable_fm); + }; + + test_case(false); // Use ONNX QDQ ops + test_case(true); // Use com.microsoft QDQ ops +} + +template +void QDQTransformerMatMulTests(bool has_output_q, bool disable_fastmath = false) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq = false) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + typedef std::numeric_limits Input1Limits; + typedef std::numeric_limits Input2Limits; + typedef std::numeric_limits OutputTypeLimits; + + // add QDQ 1 + auto* q1_output = builder.MakeIntermediate(); + auto* dq1_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input1_arg, + .039f, + (Input1Limits::max() + Input1Limits::min()) / 2 + 1, + q1_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q1_output, + .039f, + (Input2Limits::max() + Input1Limits::min()) / 2 + 1, + dq1_output, use_contrib_qdq); + + // add QDQ 2 + auto* q2_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input2_arg, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + q2_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q2_output, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + dq2_output, use_contrib_qdq); + + if (has_output_q) { + // add binary operator + auto* matmul_op_output = builder.MakeIntermediate(); + builder.AddNode("MatMul", {dq1_output, dq2_output}, {matmul_op_output}); + + // add QDQ output + auto* q3_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(matmul_op_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + q3_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q3_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + output_arg, use_contrib_qdq); + } else { + builder.AddNode("MatMul", {dq1_output, dq2_output}, {output_arg}); + } + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + if (has_output_q) { + if constexpr (std::is_same::value && + (std::is_same::value || + QDQIsInt8Allowed() && std::is_same::value)) { + EXPECT_EQ(op_to_count["QLinearMatMul"], 1); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + } else { + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 3); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 3); + } + } else { + if constexpr (std::is_same::value || + (QDQIsInt8Allowed() && std::is_same::value)) { + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + } else { + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 0); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 2); + } + } + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + + if (disable_fastmath) { + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + } + }; + + test_case({1, 2, 2}, {1, 2, 4}); + test_case({1, 23, 13, 13}, {13, 13}); + test_case({1, 22, 11, 13, 15}, {1, 22, 11, 15, 15}); + test_case({1, 2, 2}, {1, 2, 4}, true); // Use com.microsoft QDQ ops +} + +TEST(QDQTransformerTests, MatMul_U8U8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_U8S8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_U8U8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_U8S8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8S8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8U8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8U8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8S8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +// dummy test to disable the fastmath session op +TEST(QDQTransformerTests, MatMul_S8S8U8_DisableFastMath) { + QDQTransformerMatMulTests(false, true); + QDQTransformerMatMulTests(true, true); +} + +template +void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false, bool disable_fastmath = false) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq = false) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + typedef std::numeric_limits Input1Limits; + typedef std::numeric_limits Input2Limits; + typedef std::numeric_limits OutputTypeLimits; + + std::vector input_args; + + // add QDQ A + auto* q1_output = builder.MakeIntermediate(); + auto* dq1_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input1_arg, + .039f, + (Input1Limits::max() + Input1Limits::min()) / 2 + 1, + q1_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q1_output, + .039f, + (Input2Limits::max() + Input1Limits::min()) / 2 + 1, + dq1_output, use_contrib_qdq); + + input_args.push_back(dq1_output); + + // add QDQ B + auto* q2_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input2_arg, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + q2_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q2_output, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + dq2_output, use_contrib_qdq); + input_args.push_back(dq2_output); + + if (has_bias) { + auto* dq_bias_output = builder.MakeIntermediate(); + auto* bias = builder.MakeInitializer({input2_shape[1]}, static_cast(0), static_cast(127)); + builder.AddDequantizeLinearNode(bias, 0.00156f, + 0, + dq_bias_output, use_contrib_qdq); + input_args.push_back(dq_bias_output); + } + + Node* gemm_node = nullptr; + + if (has_output_q) { + auto* gemm_op_output = builder.MakeIntermediate(); + gemm_node = &builder.AddNode("Gemm", input_args, {gemm_op_output}); + + // add QDQ output + auto* q3_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(gemm_op_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + q3_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q3_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + output_arg, use_contrib_qdq); + } else { + gemm_node = &builder.AddNode("Gemm", input_args, {output_arg}); + } + + if (beta_not_one) { + gemm_node->AddAttribute("beta", 2.0f); + } + }; + + auto check_binary_op_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + if ((!has_output_q || std::is_same_v)&&(!has_bias || (std::is_same_v && !beta_not_one)) && + (std::is_same_v || std::is_same_v)) { + EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1); + EXPECT_EQ(op_to_count["Gemm"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], has_output_q ? 1 : 0); + } else { + int q_count = 2; // Q for A and B + int dq_count = 2; // DQ for A and B + if (has_bias) { + dq_count++; + } + if (has_output_q) { + q_count++; + dq_count++; + } + EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 0); + EXPECT_EQ(op_to_count["Gemm"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], q_count); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], dq_count); + } + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + + if (disable_fastmath) { + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + } + }; + + test_case({2, 2}, {2, 4}); + test_case({13, 15}, {15, 15}); + test_case({2, 2}, {2, 4}, true); // Use com.microsoft QDQ ops +} + +template +void QDQTransformerGemmTests() { + QDQTransformerGemmTests(false, false); + QDQTransformerGemmTests(false, true); + QDQTransformerGemmTests(true, false); + QDQTransformerGemmTests(true, true); + QDQTransformerGemmTests(false, false, true); + QDQTransformerGemmTests(false, true, true); + QDQTransformerGemmTests(true, false, true); + QDQTransformerGemmTests(true, true, true); + // dummy test to disable the fastmath session + QDQTransformerGemmTests(true, true, true, true); +} + +TEST(QDQTransformerTests, Gemm_U8U8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_U8S8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_U8U8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_U8S8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8S8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8U8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8U8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8S8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, MatMul_No_Fusion_FastMath) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + // add QDQ + MatMul + auto* matmul_output = builder.MakeIntermediate(); + auto* dq_matmul_output1 = AddQDQNodePair(builder, input1_arg, .004f, 129, use_contrib_qdq); + builder.AddNode("MatMul", {dq_matmul_output1, input2_arg}, {matmul_output}); + + // add Q + builder.AddQuantizeLinearNode(matmul_output, .0039f, 135, output_arg, use_contrib_qdq); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options_disable_fm); + }; + + test_case({12, 37}, {37, 12}, false /*use_contrib_qdq*/); + test_case({12, 37}, {37, 12}, true /*use_contrib_qdq*/); +} + +TEST(QDQTransformerTests, MatMul_1st_Input_Int8_FastMath) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -128, 127); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + // add DQ with type int8 + auto* dq_output_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input1_arg, .004f, 1, dq_output_1, use_contrib_qdq); + + // add QDQ + MatMul + auto* matmul_output = builder.MakeIntermediate(); + auto* dq_matmul_output2 = AddQDQNodePair(builder, input2_arg, .004f, 129, use_contrib_qdq); + builder.AddNode("MatMul", {dq_output_1, dq_matmul_output2}, {matmul_output}); + + // add Q + builder.AddQuantizeLinearNode(matmul_output, .0039f, 135, output_arg, use_contrib_qdq); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 2); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options_disable_fm); + }; + + test_case({12, 37}, {37, 12}, false /*use_contrib_qdq*/); + test_case({12, 37}, {37, 12}, true /*use_contrib_qdq*/); + test_case({23, 13, 13}, {13, 13}, false /*use_contrib_qdq*/); + test_case({22, 11, 13, 15}, {15, 13}, false /*use_contrib_qdq*/); +} + +TEST(QDQTransformerTests, MatMulIntegerToFloat_FastMath) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* input2_arg = builder.MakeInput(input2_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input1_arg, .0035f, 135, dq_output_1, use_contrib_qdq); + + auto* dq_output_2 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input2_arg, .0035f, 135, dq_output_2, use_contrib_qdq); + + builder.AddNode("MatMul", {dq_output_1, dq_output_2}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, + add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_disable_fm); + }; + + test_case({12, 37}, {37, 12}, false /*use_contrib_qdq*/); + test_case({12, 37}, {37, 12}, true /*use_contrib_qdq*/); + test_case({23, 13, 13}, {13, 13}, false /*use_contrib_qdq*/); + test_case({22, 11, 13, 15}, {15, 13}, false /*use_contrib_qdq*/); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) && defined(__aarch64) + +} // namespace test +} // namespace onnxruntime + +#endif // defined(__aarch64) && defined(__linux__) && !defined(DISABLE_CONTRIB_OPS) diff --git a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc new file mode 100644 index 0000000000000..75e0c06b04f0d --- /dev/null +++ b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc @@ -0,0 +1,305 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// Licensed under the MIT License. + +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/providers/run_options_config_keys.h" +#include "test/common/dnnl_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/common/tensor_op_test_utils.h" +#include "default_providers.h" + +#if defined(__aarch64__) && defined(__linux__) + +namespace onnxruntime { +namespace test { + +namespace { + +const onnxruntime::RunOptions run_options = []() { + onnxruntime::RunOptions options{}; + ORT_THROW_IF_ERROR(options.config_options.AddConfigEntry(kOpTesterRunOptionsConfigTestTunableOp, "true")); + return options; +}(); + +const constexpr auto run_with_tunable_op = &run_options; + +} // namespace + +template +struct MatMulTestData { + std::string name; + std::vector input0_dims; + std::vector input1_dims; + std::vector expected_dims; + std::vector expected_vals; +}; + +template +std::vector> GenerateTestCases() { + std::vector> test_cases; + test_cases.push_back( + {"test padding and broadcast A > B", + {3, 1, 1, 6}, + {2, 6, 7}, + {3, 2, 1, 7}, + {385, 400, 415, 430, 445, 460, 475, 1015, 1030, 1045, 1060, 1075, 1090, 1105, 1015, 1066, 1117, 1168, 1219, 1270, 1321, 3157, 3208, 3259, 3310, 3361, 3412, 3463, 1645, 1732, 1819, 1906, 1993, 2080, 2167, 5299, 5386, 5473, 5560, 5647, 5734, 5821}}); + + test_cases.push_back( + {"test padding and broadcast B > A", + {2, 3, 12}, + {3, 2, 12, 3}, + {3, 2, 3, 3}, + {1518, 1584, 1650, 3894, 4104, 4314, 6270, 6624, 6978, 26574, 27072, 27570, 34134, 34776, 35418, 41694, 42480, 43266, 6270, 6336, 6402, 19014, 19224, 19434, 31758, 32112, 32466, 62430, 62928, 63426, 80358, 81000, 81642, 98286, 99072, 99858, 11022, 11088, 11154, 34134, 34344, 34554, 57246, 57600, 57954, 98286, 98784, 99282, 126582, 127224, 127866, 154878, 155664, 156450}}); + + test_cases.push_back( + {"test 2D", + {8, 6}, + {6, 6}, + {8, 6}, + {330, 345, 360, 375, 390, 405, 870, 921, 972, 1023, 1074, 1125, 1410, 1497, 1584, 1671, 1758, 1845, 1950, 2073, 2196, 2319, 2442, 2565, 2490, 2649, 2808, 2967, 3126, 3285, 3030, 3225, 3420, 3615, 3810, 4005, 3570, 3801, 4032, 4263, 4494, 4725, 4110, 4377, 4644, 4911, 5178, 5445}}); + + test_cases.push_back( + {"test 2D special", + {2, 2, 16}, + {16, 4}, + {2, 2, 4}, + {4960, 5080, 5200, 5320, 12640, 13016, 13392, 13768, 20320, 20952, 21584, 22216, 28000, 28888, 29776, 30664}}); + + test_cases.push_back( + {"test 2D special 2", + {2, 2, 9}, + {1, 9, 4}, + {2, 2, 4}, + {816, 852, 888, 924, 2112, 2229, 2346, 2463, 3408, 3606, 3804, 4002, 4704, 4983, 5262, 5541}}); + + test_cases.push_back( + {"test 2D special 3", + {2, 12}, + {1, 1, 12, 3}, + {1, 1, 2, 3}, + {1518, 1584, 1650, 3894, 4104, 4314}}); + + test_cases.push_back( + {"test 3D batch", + {3, 1, 18}, + {3, 18, 2}, + {3, 1, 2}, + { + // clang-format off + 3570, 3723, + 26250, 26727, + 72258, 73059, + // clang-format on + }}); + + test_cases.push_back( + {"test 4D batch", + {2, 2, 1, 20}, + {2, 2, 20, 2}, + {2, 2, 1, 2}, + { + // clang-format off + 4940, 5130, + 36140, 36730, + 99340, 100330, + 194540, 195930, + // clang-format on + }}); + + return test_cases; +} + +template +void RunMatMulTest(int32_t opset_version, bool is_a_constant, bool is_b_constant, bool disable_fastmath) { + for (auto t : GenerateTestCases()) { + SCOPED_TRACE("test case: " + t.name); + + OpTester test("MatMul", opset_version); + + int64_t size0 = TensorShape::FromExistingBuffer(t.input0_dims).SizeHelper(0, t.input0_dims.size()); + std::vector input0_vals = ValueRange(size0); + + test.AddInput("A", t.input0_dims, input0_vals, is_a_constant); + + int64_t size1 = TensorShape::FromExistingBuffer(t.input1_dims).SizeHelper(0, t.input1_dims.size()); + std::vector input1_vals = ValueRange(size1); + test.AddInput("B", t.input1_dims, input1_vals, is_b_constant); + + test.AddOutput("Y", t.expected_dims, t.expected_vals); + + // OpenVINO EP: Disabled temporarily matmul broadcasting not fully supported + // Disable TensorRT because of unsupported data type + std::unordered_set excluded_providers{kTensorrtExecutionProvider, kOpenVINOExecutionProvider}; + if (t.name == "test 2D empty input") { + // NNAPI: currently fails for the "test 2D empty input" case + excluded_providers.insert(kNnapiExecutionProvider); + } + + if ("test padding and broadcast A > B" == t.name || "test 2D empty input" == t.name) { + // QNN can't handle 0 shap + excluded_providers.insert(kQnnExecutionProvider); + } + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + + test.ConfigExcludeEps(excluded_providers) + .Config(run_with_tunable_op) + .Config(so) + .RunWithConfig(); + + if (disable_fastmath) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + + test.ConfigExcludeEps(excluded_providers) + .Config(run_with_tunable_op) + .Config(so) + .RunWithConfig(); + } + } +} + +template +void RunMatMulTest(int32_t opset_version) { + RunMatMulTest(opset_version, false, false, false); +} + +TEST(MathOpTest, MatMulFloatType_FastMath) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; + } + RunMatMulTest(7, false, false, false); +} + +TEST(MathOpTest, MatMulFloatTypeInitializer_FastMath) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; + } + RunMatMulTest(7, false, true, false); +} + +TEST(MathOpTest, MatMulInt32Type_FastMath) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulUint32Type_FastMath) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulInt64Type_FastMath) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulUint64Type_FastMath) { + RunMatMulTest(9); +} + +#ifndef ENABLE_TRAINING +// Prepacking is disabled in full training build so no need to test the feature in a training build. +TEST(MathOpTest, MatMulSharedPrepackedWeights_FastMath) { + OpTester test("MatMul"); + + std::vector b_init_values(32, 1.0f); + test.AddInput("A", {8, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + // B is to be an initializer for triggering pre-packing + test.AddInput("B", {4, 8}, b_init_values, true); + + test.AddOutput("Y", {8, 8}, + {10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, + 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, + 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, + 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f}); + + OrtValue b; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({4, 8}), + b_init_values.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + SessionOptions so; + // Set up B as a shared initializer to be shared between sessions + ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + + // We want all sessions running using this OpTester to be able to share pre-packed weights if applicable + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + // Pre-packing is limited just to the CPU EP for now and we will only test the CPU EP + // and we want to ensure that it is available in this build + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t number_of_pre_packed_weights_counter_session_1 = 0; + size_t number_of_shared_pre_packed_weights_counter = 0; + + // Session 1 + { + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter); + // Assert that no pre-packed weights have been shared thus far + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + + auto number_of_elements_in_shared_prepacked_buffers_container = + test.GetNumPrePackedWeightsShared(); + // Assert that the number of elements in the shared container + // is the same as the number of weights that have been pre-packed + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_elements_in_shared_prepacked_buffers_container); + + // On some platforms/architectures MLAS may choose to not do any pre-packing and the number of elements + // that have been pre-packed will be zero in which case we do not continue with the testing + // of "sharing" of pre-packed weights as there are no pre-packed weights to be shared at all. + if (number_of_pre_packed_weights_counter_session_1 == 0) + return; + + // Session 2 + { + size_t number_of_pre_packed_weights_counter_session_2 = 0; + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter); + + // Assert that the same number of weights were pre-packed in both sessions + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_pre_packed_weights_counter_session_2); + + // Assert that the number of pre-packed weights that were shared equals + // the number of pre-packed weights in the second session + ASSERT_EQ(number_of_pre_packed_weights_counter_session_2, + static_cast(number_of_shared_pre_packed_weights_counter)); + } +} + +#endif + +// Dummy run to disable the FastMath mode for the current session +TEST(MathOpTest, MatMulUint64Type_DisableFastMath) { + RunMatMulTest(9, false, false, true); +} + +} // namespace test +} // namespace onnxruntime +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/util/compare_ortvalue.cc b/onnxruntime/test/util/compare_ortvalue.cc index 3d53d4a3a0193..64ebe24188762 100644 --- a/onnxruntime/test/util/compare_ortvalue.cc +++ b/onnxruntime/test/util/compare_ortvalue.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. // Licensed under the MIT License. #include "test/compare_ortvalue.h" @@ -65,6 +66,54 @@ const char* ElementTypeToString(MLDataType type) { return DataTypeImpl::ToString(type); } +#if defined(__aarch64__) && defined(__linux__) +template +std::pair CheckCosineSimilarity(const Tensor& outvalue, const Tensor& expected_value) { + const size_t tensor_size = static_cast(expected_value.Shape().Size()); + const T* expected_output = expected_value.Data(); + const T* real_output = outvalue.Data(); + std::pair res = std::make_pair(COMPARE_RESULT::SUCCESS, ""); + const T cosine_similarity_threshold = 0.99f; + + T dot = 0.0f, denom_a = 0.0f, denom_b = 0.0f; + for (size_t i = 0u; i < tensor_size; ++i) { + if (isnan(expected_output[i]) && isnan(real_output[i])) + continue; + if (isinf(expected_output[i]) && isinf(real_output[i])) + continue; + dot += expected_output[i] * real_output[i]; + denom_a += expected_output[i] * expected_output[i]; + denom_b += real_output[i] * real_output[i]; + } + + T cos_factor = abs(dot / (sqrt(denom_a) * sqrt(denom_b))); + if (cos_factor < cosine_similarity_threshold) { + res.first = COMPARE_RESULT::RESULT_DIFFERS; + std::ostringstream oss; + oss << std::hex << "results differed, cosine similarity factor is " << cos_factor << "."; + res.second = oss.str(); + } + return res; +} + +template +std::pair CheckCloseMatch(const Tensor& outvalue, const Tensor& expected_value) { + const size_t size1 = static_cast(expected_value.Shape().Size()); + const T* expected_output = expected_value.Data(); + const T* real_output = outvalue.Data(); + const T close_match_threshold = 1.0; + + for (size_t di = 0; di != size1; ++di) { + const T diff = expected_output[di] - real_output[di]; + if (std::fabs(diff) > close_match_threshold) { + std::ostringstream oss; + oss << "expected " << expected_output[di] << ", got " << real_output[di]; + return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); + } + } + return std::make_pair(COMPARE_RESULT::SUCCESS, ""); +} +#endif /** * @brief Check if two values are closely matched with given tolerance. @@ -207,6 +256,37 @@ std::pair CompareTwoTensors(const Tensor& outvalue, oss << "shape mismatch, expect " << expected_tensor.Shape().ToString() << " got " << outvalue.Shape().ToString(); return std::make_pair(COMPARE_RESULT::SHAPE_MISMATCH, oss.str()); } + +#if defined(__aarch64__) && defined(__linux__) + if (isnan(per_sample_tolerance) || isnan(per_sample_tolerance)) { + if (outvalue.IsDataType()) { + return CheckCosineSimilarity(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCosineSimilarity(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else { + return std::make_pair(COMPARE_RESULT::NOT_SUPPORT, ""); + } + } +#endif + if (outvalue.IsDataType()) { return CompareFloatResult(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, post_processing); From 24b74aebcbd5fbaaa44ca41143b3b6afe3207978 Mon Sep 17 00:00:00 2001 From: Linnea May Date: Mon, 22 Jan 2024 15:37:09 -0800 Subject: [PATCH 08/45] [DML] Register DML operators for opset 19 (#16939) ### Description Register DML operators for opset 19. - Cast19 - Castlike19 - Constant19 - Equal19 - Identity19 - QuantizeLinear19 - DequantizeLinear19 - Reshape19 - Shape19 - Size ### Motivation and Context --------- Co-authored-by: linnealovespie --- docs/OperatorKernels.md | 27 ++++++++++++------ .../src/Operators/DmlOperatorCast.cpp | 3 +- .../src/Operators/DmlOperatorElementWise.cpp | 28 +++++++++++-------- .../src/Operators/OperatorRegistration.cpp | 10 +++++++ .../dml/OperatorAuthorHelper/OperatorHelper.h | 3 ++ .../OperatorAuthorHelper/OperatorVersions.h | 10 +++++++ .../cpu/tensor/quantize_linear_test.cc | 10 ------- 7 files changed, 59 insertions(+), 32 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 9ecc58bee0725..9a2a7ac89bbb3 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -922,10 +922,12 @@ Do not modify directly.* |BitwiseNot|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Cast|*in* input:**T1**
*out* output:**T2**|13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| |Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float), tensor(float16)| @@ -952,7 +954,8 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -961,7 +964,8 @@ Do not modify directly.* |DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(float), tensor(float16)| |Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| -|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||11+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||7+|**T** = tensor(float), tensor(float16)
**T1** = tensor(bool)| |Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| @@ -1004,7 +1008,8 @@ Do not modify directly.* |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||14+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1099,7 +1104,8 @@ Do not modify directly.* |||7+|**T** = tensor(float), tensor(float16)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| |QLinearMatMul|*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| +|||13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |||10+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(float), tensor(float16)| |||7+|**T** = tensor(float), tensor(float16)| @@ -1150,7 +1156,8 @@ Do not modify directly.* |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| |||13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| -|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| @@ -1178,7 +1185,8 @@ Do not modify directly.* |SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Shape|*in* data:**T**
*out* shape:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| @@ -1188,7 +1196,8 @@ Do not modify directly.* |||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| |Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| -|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp index 76b9b308fe98f..45ff25c4fdd90 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp @@ -29,7 +29,7 @@ class DmlOperatorCast : public DmlOperator castDesc.OutputTensor = outputDescs.data(); DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc }; - + SetDmlOperatorDesc(opDesc, kernelInfo); } @@ -49,5 +49,6 @@ class DmlOperatorCast : public DmlOperator DML_OP_DEFINE_CREATION_FUNCTION(Cast, DmlOperatorCast); DML_OP_DEFINE_CREATION_FUNCTION(CastLike15, DmlOperatorCast); +DML_OP_DEFINE_CREATION_FUNCTION(CastLike19, DmlOperatorCast); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index ab8ddbfe91bf0..16bb10f004f91 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -487,7 +487,7 @@ class DmlOperatorElementwisePow : public DmlOperator Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {}; opDesc.InputTensor = &inputDescs[0]; @@ -497,11 +497,11 @@ class DmlOperatorElementwisePow : public DmlOperator SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo); } else - { + { Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {}; opDesc.InputTensor = &inputDescs[0]; @@ -519,13 +519,16 @@ class DmlOperatorElementwiseQLinear : public DmlOperator public: DmlOperatorElementwiseQLinear(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo) { - ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 3); + + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + Initialize(kernelInfo, std::nullopt, std::nullopt); + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); const uint32_t outputShapeDimCount = gsl::narrow_cast(outputShape.size()); - - Initialize(kernelInfo, std::nullopt, std::nullopt); + const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType(); + bool hasZeroPointTensor = kernelInfo.IsInputValid(2); uint32_t axis = 0; @@ -541,9 +544,14 @@ class DmlOperatorElementwiseQLinear : public DmlOperator axis = Dml::HandleNegativeAxis(signedAxis, outputShapeDimCount, /*validateAxis*/ false); } - // Explicitly reshape each of the inputs after the first input (scale and zero point tensors). + // Explicitly reshape each of the inputs after the first input (scale tensor and optional zero point tensor). for (uint32_t index = 1, inputCount = gsl::narrow_cast(m_inputTensorDescs.size()); index < inputCount; ++index) { + if (!kernelInfo.IsInputValid(index)) + { + continue; + } + auto edgeDesc = kernelInfo.GetInputEdgeDescription(index); assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor); @@ -587,12 +595,8 @@ class DmlOperatorElementwiseQLinear : public DmlOperator TOperatorDesc opDesc = {}; opDesc.InputTensor = &inputDescs[0]; opDesc.ScaleTensor = &inputDescs[1]; - opDesc.ZeroPointTensor = &inputDescs[2]; + opDesc.ZeroPointTensor = hasZeroPointTensor ? &inputDescs[2] : nullptr; opDesc.OutputTensor = &outputDescs[0]; - - TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1); - TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2); - SetDmlOperatorDesc({ApiTraits::OperatorDescTraits::Type, &opDesc}, kernelInfo); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 15a8051953c79..18e29c8b99ced 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -436,6 +436,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMul); DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMulActivation); DML_OP_EXTERN_CREATION_FUNCTION(Cast); DML_OP_EXTERN_CREATION_FUNCTION(CastLike15); +DML_OP_EXTERN_CREATION_FUNCTION(CastLike19); DML_OP_EXTERN_CREATION_FUNCTION(MemcpyFromHost); DML_OP_EXTERN_CREATION_FUNCTION(MemcpyToHost); DML_OP_EXTERN_CREATION_FUNCTION(TopK7); @@ -785,6 +786,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_COPY(13, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(14, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(16, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY(19, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, @@ -798,6 +800,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_COPY( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO_COPY(13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO_COPY(14, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO_COPY(19, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, // Elementwise {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -857,8 +860,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 13, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 19, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO( 10, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 13, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 19, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, @@ -943,6 +948,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported)}, {REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, {REG_INFO( 13, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, + {REG_INFO( 19, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, {REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, {REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, {REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, @@ -1004,7 +1010,9 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO( 19, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported)}, @@ -1015,8 +1023,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 13, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, + {REG_INFO( 19, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, + {REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, {REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 0e0e6bb1eaf5c..0d425997e6a6a 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1606,6 +1606,7 @@ using ShapeInferenceHelper_Expand = ExpandHelper; using ShapeInferenceHelper_Reshape7 = ReshapeHelper; using ShapeInferenceHelper_Reshape13 = ReshapeHelper; using ShapeInferenceHelper_Reshape14 = ReshapeHelper; +using ShapeInferenceHelper_Reshape19 = ReshapeHelper; using ShapeInferenceHelper_ConstantOfShape = ConstantOfShapeHelper; using ShapeInferenceHelper_Tile = TileHelper; using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper; @@ -1725,6 +1726,7 @@ using ShapeInferenceHelper_Identity7 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity13 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity14 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MatMul = MatMulHelper; using ShapeInferenceHelper_MatMulInteger = MatMulHelper; using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper; @@ -1750,6 +1752,7 @@ using ShapeInferenceHelper_CumSum14 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Range = RangeHelper; using ShapeInferenceHelper_CastLike15 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_CastLike19 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_DmlFusedConv = ConvHelper; using ShapeInferenceHelper_DmlFusedConvTranspose = ConvTransposeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 8438bc620712c..79efc2d2836fe 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -413,6 +413,16 @@ namespace OperatorHelper namespace OnnxOperatorSet19 { static const int sc_sinceVer_AveragePool = 19; + static const int sc_sinceVer_Cast = 19; + static const int sc_sinceVer_CastLike = 19; + static const int sc_sinceVer_Constant = 19; + static const int sc_sinceVer_Equal = 19; + static const int sc_sinceVer_Identity = 19; + static const int sc_sinceVer_QuantizeLinear = 19; + static const int sc_sinceVer_DequantizeLinear = 19; + static const int sc_sinceVer_Reshape = 19; + static const int sc_sinceVer_Shape = 19; + static const int sc_sinceVer_Size = 19; } namespace MsftOperatorSet1 diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 026bb07edf44c..0c8d6c46d4639 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -34,11 +34,6 @@ TEST(DequantizeLinearOpTest, Int8) { // scalar zero & scale with int8 TEST(DequantizeLinearOpTest, Int32) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect"; - } - OpTester test("DequantizeLinear", 10); std::vector dims{4}; test.AddInput("x", dims, {-30, -3, 100, 127}); @@ -98,11 +93,6 @@ TEST(DequantizeLinearOpMLFloat16Test, Scalar) { // dequantize without zero point TEST(DequantizeLinearOpTest, Without_Zero_Point) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect"; - } - OpTester test("DequantizeLinear", 10); test.AddInput("x", {}, {100}); test.AddInput("x_scale", {}, {2.0f}); From e283cdb21857170848e9b8a8fbca24d0463b4193 Mon Sep 17 00:00:00 2001 From: Yifan Li <109183385+yf711@users.noreply.github.com> Date: Mon, 22 Jan 2024 15:44:57 -0800 Subject: [PATCH 09/45] Fix Fuzz Testing CI (#19228) ### Description Add BuildArch To verify: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=400952&view=logs&j=5b022bb4-70a7-5401-8766-a8a7802c7150&t=291e85c7-5547-590b-50de-4e01fcd4eba3&l=14 ### Motivation and Context --- tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml index b8f9566274acc..db39c2cd2087f 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml @@ -28,7 +28,7 @@ jobs: parameters: EnvSetupScript: $(EnvSetupScript) DownloadCUDA: false - BuildArch: $(buildArch) + BuildArch: x64 BuildConfig: $(BuildConfig) MachinePool: 'onnxruntime-Win-CPU-2022' WithCache: true From 2e0a388c36b92bc412dfa8ad45af23c7f28a4d49 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 23 Jan 2024 07:53:26 +0800 Subject: [PATCH 10/45] [js/webgpu] Add HardSigmoid support (#19215) ### Description This op is required in mobilenetv3-small-100. With this PR, mobilenetv3-small-100 model becomes less than 10 ms from over 100 ms on ADL. --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 1 + js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 20 +++++++++++++++++++ js/web/test/suite-test-list.jsonc | 6 +++--- .../providers/js/js_execution_provider.cc | 2 ++ .../core/providers/js/operators/unary.cc | 3 +++ 6 files changed, 30 insertions(+), 3 deletions(-) diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 2f510308d9306..2557971eb4ded 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -52,6 +52,7 @@ Do not modify directly.* | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | | GreaterOrEqual | ai.onnx(12-15,16+) | | +| HardSigmoid | ai.onnx(6+) | | | If | ai.onnx(1-10,11-12,13-18,19+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(17+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 90e02da986b8f..cc504093ca0d7 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], + ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], ['InstanceNormalization', [instanceNorm]], ['LayerNormalization', [layerNorm]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index a25e7fe4229b4..82311d72e58b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`)); }; +export interface HardSigmoidAttributes extends AttributeWithCacheKey { + readonly alpha: number; + readonly beta: number; +} + +export const parseHardSigmoidAttributes = (attributes: Record): HardSigmoidAttributes => + createAttributeWithCacheKey(attributes as { + alpha: number; + beta: number; + }); + +export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'HardSigmoid', + a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${ + attributes.beta})))`, + undefined, attributes.cacheKey)); +}; + export const sin = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin')); }; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 033b3b3f4b0f5..373b3c645df57 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -597,9 +597,9 @@ // // "test_hardmax_example", // // "test_hardmax_negative_axis", // // "test_hardmax_one_hot", - // // "test_hardsigmoid_default", - // // "test_hardsigmoid_example", - // // "test_hardsigmoid", + "test_hardsigmoid_default", + "test_hardsigmoid_example", + "test_hardsigmoid", // // "test_hardswish_expanded", // // "test_hardswish", "test_if", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index c2ff2ebc39e13..af9658271d210 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -98,6 +98,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Erf); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, HardSigmoid); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Log); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Log); @@ -392,6 +393,7 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO(13, Erf), KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), KERNEL_CREATE_INFO(13, Sigmoid), + KERNEL_CREATE_INFO(6, HardSigmoid), KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), KERNEL_CREATE_INFO(13, Log), diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index 78563d30b0136..9082527e3a8d7 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -77,6 +77,9 @@ JSEP_KERNEL_IMPL(Sigmoid, Sigmoid) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid) JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(HardSigmoid, HardSigmoid, alpha, 0.2, beta, 0.5) +JSEP_ELEMENTWISE_KERNEL(HardSigmoid, 6, HardSigmoid) + JSEP_KERNEL_IMPL(Log, Log) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log) JSEP_ELEMENTWISE_KERNEL(Log, 13, Log) From d226e40856738531cf8b481b07379545f7cfefe2 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 23 Jan 2024 08:08:55 +0800 Subject: [PATCH 11/45] [js/webgpu] set query type in onRunStart (#19202) ### Description `env.webgpu.profiling` is a global flag. It may change before each session.run. So the best place is to update it in `onRunStart` event. After this, we can directly check `this.queryType`'s value. Without this pr, we need to make sure that `getCommandEncoder()` is called before checking `this.queryType`. Otherwise, it may happen that `pendingKernels`'s length is not equal to `pendingDispatchNumber`'s length. See the two ugly workarounds [1)](https://github.com/microsoft/onnxruntime/pull/18989/commits/e630dbf528fc3a955702cceb968930d0abdfc652#diff-006fc84d3997f96a29b8033bd2075d6a0a9509211bd5812a6b934fc74fedfd9dR267-R268) and [2)](https://github.com/microsoft/onnxruntime/pull/18989/commits/e630dbf528fc3a955702cceb968930d0abdfc652#diff-618fe297fbe7a1da586380163b8fd2627311ccc217640a3c5cdc9c17a33472c1R73-R80) if we don't introduce `onRunStart`. Or we need to call `setQueryType` in each kernel run. --- js/web/lib/wasm/binding/ort-wasm.d.ts | 4 ++++ js/web/lib/wasm/jsep/backend-webgpu.ts | 9 +++++---- js/web/lib/wasm/wasm-core-impl.ts | 2 +- onnxruntime/wasm/js_internal_api.js | 3 +++ 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 9d4d5875310b7..68054210e79a7 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -182,6 +182,10 @@ export interface OrtWasmModule extends EmscriptenModule { jsepCreateDownloader: (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; + /** + * [exported from js_internal_api.js] Called when InferenceSession.run started. + */ + jsepOnRunStart: () => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 2956ec1cad4da..afef7042a4280 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -208,7 +208,7 @@ export class WebGpuBackend { Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); - // init queryType, which is necessary for createKernel + // init queryType, which is necessary for InferenceSession.create this.setQueryType(); } @@ -223,8 +223,6 @@ export class WebGpuBackend { if (!this.commandEncoder) { this.commandEncoder = this.device.createCommandEncoder(); - // refresh queryType, as sometimes we only need to enable query for a specific run - this.setQueryType(); if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { this.querySet = this.device.createQuerySet({ type: 'timestamp', @@ -639,6 +637,7 @@ export class WebGpuBackend { return createView(data.buffer, type); }; } + // #endregion writeTimestamp(index: number): void { if (this.queryType !== 'inside-passes') { return; @@ -657,5 +656,7 @@ export class WebGpuBackend { } } } - // #endregion + onRunStart(): void { + this.setQueryType(); + } } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5821fac3c468f..8768643fa7257 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -488,8 +488,8 @@ export const run = async( } } + wasm.jsepOnRunStart?.(); let errorCode: number; - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 25ece9c700d5d..7c70515e73eab 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -186,4 +186,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; + Module['jsepOnRunStart'] = () => { + return backend['onRunStart'](); + }; }; From 37d14d78960fb1ba54c0bb2dc3be740e93d2ca15 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 22 Jan 2024 18:14:41 -0800 Subject: [PATCH 12/45] [QNN EP] Create Windows ARM64 nightly python package (#19128) ### Description Adds a job to create a nightly python package for ORT/QNN on Windows ARM64. Must build onnxruntime-qnn with python 3.11 and numpy 1.25. **Note: pipeline run may take up to 3 hrs** ### Motivation and Context Make it possible to get a nightly python package with the latest updates to QNN EP. Issue #19161 --- .../azure-pipelines/py-packaging-pipeline.yml | 8 +- .../templates/py-packaging-stage.yml | 13 ++ .../templates/py-win-arm64-qnn.yml | 165 ++++++++++++++++++ 3 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 06cca0068523d..5349b1ca67ab1 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -29,6 +29,11 @@ parameters: type: boolean default: true +- name: enable_windows_arm64_qnn + displayName: 'Whether Windows ARM64 package with QNN EP is built.' + type: boolean + default: true + - name: build_py_parameters displayName: 'Specify extra build parameters' type: string @@ -64,5 +69,6 @@ stages: enable_windows_gpu: ${{ parameters.enable_windows_gpu }} enable_mac_cpu: ${{ parameters.enable_mac_cpu }} enable_linux_arm: ${{ parameters.enable_linux_arm }} + enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} build_py_parameters: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} \ No newline at end of file + cmake_build_type: ${{ parameters.cmake_build_type }} diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 8669a883c31f1..297498843c38d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -35,6 +35,11 @@ parameters: type: boolean default: true +- name: enable_windows_arm64_qnn + displayName: 'Whether Windows ARM64 package with QNN EP is built.' + type: boolean + default: true + # TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. - name: cmake_build_type type: string @@ -446,3 +451,11 @@ stages: machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} + + - ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: + - template: py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: 'qnn-v2.18.0.240101_win' + PYTHON_VERSION: '3.11' + NUMPY_VERSION: '1.25.2' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml new file mode 100644 index 0000000000000..adf7aa9c43205 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -0,0 +1,165 @@ +parameters: + +- name: MACHINE_POOL + type: string + default: 'onnxruntime-qnn-windows-vs-2022-arm64' + +- name: QNN_SDK + displayName: QNN Windows SDK path + type: string + default: qnn-v2.18.0.240101_win + +- name: PYTHON_VERSION + type: string + default: '3.11' + +- name: NUMPY_VERSION + type: string + default: '1.25.2' + +- name: ENV_SETUP_SCRIPT + type: string + default: '' + +- name: BUILD_PY_PARAMETERS + displayName: > + Extra parameters to pass to build.py. Don't put newlines in here. + type: string + default: '' + +jobs: +- job: Win_py_arm64_qnn_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} + timeoutInMinutes: 210 + workspace: + clean: all + pool: + name: ${{ parameters.MACHINE_POOL }} + variables: + GRADLE_OPTS: '-Dorg.gradle.daemon=false' + VSGenerator: 'Visual Studio 17 2022' + QNN_SDK_ROOTDIR: 'C:\data\qnnsdk\${{parameters.QNN_SDK}}' + steps: + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - script: | + DIR C:\data\qnnsdk + displayName: Check available QNN SDKs + + - script: | + MKDIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 + XCOPY /s /y /h /e /c /q "C:\Python\Python311\*.*" $(Agent.ToolsDirectory)\Python\3.11.0\arm64\ + COPY NUL $(Agent.ToolsDirectory)\Python\3.11.0\arm64.complete + DIR $(Agent.ToolsDirectory)\Python + DIR $(Agent.ToolsDirectory)\Python\3.11.0 + DIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 + displayName: Copy python 3.11.0 version to agent tools directory + + - task: UsePythonVersion@0 + inputs: + versionSpec: ${{ parameters.PYTHON_VERSION }} + addToPath: true + architecture: 'arm64' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: false + + - task: PythonScript@0 + inputs: + scriptSource: inline + script: | + import subprocess + subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', 'numpy==${{parameters.NUMPY_VERSION}}']) + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Install python modules' + + - template: set-nightly-build-option-variable-step.yml + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --use_qnn + --qnn_home $(QNN_SDK_ROOTDIR) + --enable_pybind + --parallel --update + --numpy_version ${{ parameters.NUMPY_VERSION }} + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'arm64' + configuration: RelWithDebInfo + msbuildArchitecture: 'arm64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd,*.dll' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + ArtifactName: onnxruntime_qnn + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' From b2aec41a8309bc2dced74a991b1f3c311e037e3d Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Mon, 22 Jan 2024 19:17:04 -0800 Subject: [PATCH 13/45] [ROCm] enable hipGraph (#18382) This ports the cudaGraph support from the CUDA EP to the ROCM EP's hipGraph. --- cmake/onnxruntime_unittests.cmake | 7 ++ .../core/session/onnxruntime_c_api.h | 3 + .../providers/rocm/rocm_execution_provider.cc | 77 +++++++++++- .../providers/rocm/rocm_execution_provider.h | 24 ++++ .../rocm/rocm_execution_provider_info.cc | 3 + .../rocm/rocm_execution_provider_info.h | 2 + .../providers/rocm/rocm_provider_factory.cc | 2 + onnxruntime/core/session/inference_session.cc | 52 +++++--- .../core/session/provider_bridge_ort.cc | 1 + onnxruntime/test/shared_lib/test_inference.cc | 112 +++++++++++++++--- 10 files changed, 241 insertions(+), 42 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index fa395802d95ff..0987d6d164dbd 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1277,6 +1277,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if (onnxruntime_USE_CUDA) list(APPEND onnxruntime_shared_lib_test_LIBS cudart) endif() + if (onnxruntime_USE_ROCM) + list(APPEND onnxruntime_shared_lib_test_LIBS hip::host) + endif() if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) endif() @@ -1294,6 +1297,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) endif() + if (onnxruntime_USE_ROCM) + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include) + target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__) + endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 101a578ec3e1d..2ce9d361e8e56 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -496,6 +496,7 @@ typedef struct OrtROCMProviderOptions { has_user_compute_stream{}, user_compute_stream{}, default_memory_arena_cfg{}, + enable_hip_graph{false}, tunable_op_enable{false}, tunable_op_tuning_enable{false}, tunable_op_max_tuning_duration_ms{} {} @@ -548,6 +549,8 @@ typedef struct OrtROCMProviderOptions { */ OrtArenaCfg* default_memory_arena_cfg; + int enable_hip_graph; + /** \brief Enable TunableOp for using. * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. * This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index d7c5098d9dbe4..d7bec337a6be4 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -170,6 +170,8 @@ ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_)); MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream)); + + hip_graph_.SetStream(stream); } ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { @@ -177,6 +179,33 @@ ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_))); } +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_; +} + +void ROCMExecutionProvider::PerThreadContext::CaptureBegin() { + hip_graph_.Reset(); + hip_graph_.CaptureBegin(); +} + +void ROCMExecutionProvider::PerThreadContext::CaptureEnd() { + hip_graph_.CaptureEnd(); + is_graph_captured_ = true; +} + +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured() const { + return is_graph_captured_; +} + +Status ROCMExecutionProvider::PerThreadContext::ReplayGraph() { + ORT_ENFORCE(IsGraphCaptured()); + return hip_graph_.Replay(); +} + +void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { + ++regular_run_count_before_graph_capture_; +} + void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable( "ORT_ROCM_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead."); @@ -219,6 +248,11 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in if (info.external_allocator_info.UseExternalAllocator()) { use_ep_level_unified_stream_ = true; stream_ = nullptr; + } else if (info.enable_hip_graph) { + // current hip graph implementation only works with single stream + // use EP level unified stream for all the reqeust + HIP_CALL_THROW(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking)); + use_ep_level_unified_stream_ = true; } else { stream_ = nullptr; } @@ -322,25 +356,58 @@ Status ROCMExecutionProvider::Sync() const { Status ROCMExecutionProvider::OnRunStart() { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); + if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { + LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model"; + GetPerThreadContext().CaptureBegin(); + } return Status::OK(); } Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) { + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { + if (GetPerThreadContext().IsGraphCaptureAllowed()) { + GetPerThreadContext().CaptureEnd(); + // HIP work issued to a capturing stream doesn’t actually run on the GPU, + // so run the captured graph here to actually execute the work. + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph()); + } else { + GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(); + } + } + if (sync_stream) { HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream_))); } - // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), - // PerThreadContext won't be created and there is nothing to - // release. This didn't happen before because we always call - // GetPerThreadContext in OnRunStart. - if (PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) { + // The reason of !IsGraphCaptureEnabled(): + // If hip graph is enabled, the per thread context will not be released + // because the per thread hip graph needs to be maintained and replayed for + // the next run. + // The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end(): + // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), + // PerThreadContext won't be created and there is nothing to + // release. This didn't happen before because we always call + // GetPerThreadContext in OnRunStart. + if (!IsGraphCaptureEnabled() && + PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) { ReleasePerThreadContext(); } return Status::OK(); } +bool ROCMExecutionProvider::IsGraphCaptureEnabled() const { + return info_.enable_hip_graph; +} + +bool ROCMExecutionProvider::IsGraphCaptured() const { + return GetPerThreadContext().IsGraphCaptured(); +} + +Status ROCMExecutionProvider::ReplayGraph() { + return GetPerThreadContext().ReplayGraph(); +} + namespace rocm { // opset 1 to 9 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index c4945b9ac2481..37d5f7b42210f 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -10,6 +10,7 @@ #include "core/framework/execution_provider.h" #include "core/platform/ort_mutex.h" #include "core/providers/rocm/rocm_execution_provider_info.h" +#include "core/providers/rocm/rocm_graph.h" #include "core/providers/rocm/rocm_pch.h" #include "core/providers/rocm/shared_inc/rocm_utils.h" #include "core/providers/rocm/shared_inc/rocm_call.h" @@ -73,6 +74,9 @@ class ROCMExecutionProvider : public IExecutionProvider { std::unique_ptr GetProfiler() override; + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured() const override; + Status ReplayGraph() override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; @@ -81,6 +85,7 @@ class ROCMExecutionProvider : public IExecutionProvider { ROCMExecutionProviderInfo info_; hipDeviceProp_t device_prop_; bool external_stream_ = false; + // only used when set user external stream or hip graph hipStream_t stream_ = nullptr; bool use_ep_level_unified_stream_ = false; @@ -133,6 +138,13 @@ class ROCMExecutionProvider : public IExecutionProvider { } } + bool IsGraphCaptureAllowed() const; + void CaptureBegin(); + void CaptureEnd(); + bool IsGraphCaptured() const; + Status ReplayGraph(); + void IncrementRegularRunCountBeforeGraphCapture(); + private: rocblas_handle rocblas_handle_ = nullptr; miopenHandle_t miopen_handle_ = nullptr; @@ -141,6 +153,18 @@ class ROCMExecutionProvider : public IExecutionProvider { std::unique_ptr> constant_ones_double_; std::unique_ptr> constant_ones_half_; std::unique_ptr> constant_ones_bfloat16_; + + // Hip graph with multi threads will be supported in the future, so hip_graph_ + // is put under PerThreadContext. + ROCMGraph hip_graph_; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + + // There is chance that the second regular run allocates GPU memory for causes like: + // (1) memory pattern is enabled. (2) arena allocation for stream. + // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs + // to allocate enough memory in Arena before graph capturing. + const int min_num_runs_before_hip_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations. }; using PerThreadContextMap = std::unordered_map>; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index 650635c153640..b557f92287f2b 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -21,6 +21,7 @@ constexpr const char* kGpuExternalAlloc = "gpu_external_alloc"; constexpr const char* kGpuExternalFree = "gpu_external_free"; constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache"; constexpr const char* kMiopenConvUseMaxWorkspace = "miopen_conv_use_max_workspace"; +constexpr const char* kEnableHipGraph = "enable_hip_graph"; constexpr const char* kTunableOpEnable = "tunable_op_enable"; constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; @@ -84,6 +85,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P info.miopen_conv_exhaustive_search) .AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream) .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace) + .AddAssignmentToReference(rocm::provider_option_names::kEnableHipGraph, info.enable_hip_graph) .AddValueParser( rocm::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -121,6 +123,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, {rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)}, {rocm::provider_option_names::kMiopenConvUseMaxWorkspace, MakeStringWithClassicLocale(info.miopen_conv_use_max_workspace)}, + {rocm::provider_option_names::kEnableHipGraph, MakeStringWithClassicLocale(info.enable_hip_graph)}, {rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)}, {rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, {rocm::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)}, diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h index e35c0cc0afecc..2f549cc1ac143 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h @@ -63,6 +63,8 @@ struct ROCMExecutionProviderInfo { // If set to false, use fix workspace size (32M) for Conv algo search, the final algo might not be the best. bool miopen_conv_use_max_workspace{true}; + bool enable_hip_graph{false}; + rocm::TunableOpInfo tunable_op{}; static ROCMExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index 4d88c25469372..88ef666678b3e 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -185,6 +185,7 @@ struct ROCM_Provider : Provider { info.has_user_compute_stream = params->has_user_compute_stream != 0; info.user_compute_stream = params->user_compute_stream; info.default_memory_arena_cfg = params->default_memory_arena_cfg; + info.enable_hip_graph = params->enable_hip_graph; info.tunable_op.enable = params->tunable_op_enable; info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; @@ -215,6 +216,7 @@ struct ROCM_Provider : Provider { rocm_options.user_compute_stream = internal_options.user_compute_stream; } rocm_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg; + rocm_options.enable_hip_graph = internal_options.enable_hip_graph; rocm_options.tunable_op_enable = internal_options.tunable_op.enable; rocm_options.tunable_op_tuning_enable = internal_options.tunable_op.tuning_enable; rocm_options.tunable_op_max_tuning_duration_ms = internal_options.tunable_op.max_tuning_duration_ms; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e8853c8824738..39f47c09f2402 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -153,7 +153,7 @@ static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { // Empty node provider means CPU EP if (!node_provider.empty() && - node_provider != kCudaExecutionProvider && + !(node_provider == kCudaExecutionProvider || node_provider == kRocmExecutionProvider) && node_provider != kCpuExecutionProvider) { nodes_on_cpu_and_cuda_eps_only = false; break; @@ -1715,7 +1715,8 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); - // Currently CUDA graph is only considered by CUDA EP and TRT EP. + // Currently CUDA graph is only considered by CUDA EP and TRT EP, and + // HIP graph is only considered by ROCM EP. // // Check for CUDA EP: // If the CUDA EP is part of the providers list for this session AND @@ -1728,47 +1729,58 @@ common::Status InferenceSession::Initialize() { // The TRT EP is configured to do a graph capture AND // All the graph nodes have been assigned to the TRT EP, // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). - std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider}; + // + // Check for ROCM EP: + // If the ROCM EP is part of the providers list for this session AND + // The ROCM EP is configured to do a graph capture AND + // All the "compute" graph nodes have been assigned to the ROCM EP, + // Then the ROCM EP is cached for triggering a ReplayGraph() in Run(). + // + std::vector graph_support_ep_list = { + onnxruntime::kTensorrtExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider}; - for (auto& it : cuda_graph_support_ep_list) { + for (auto& it : graph_support_ep_list) { auto* target_ep = execution_providers_.Get(it); if (target_ep && target_ep->IsGraphCaptureEnabled()) { - // CUDA Graphs can't work with control flow nodes + // CUDA/HIP Graphs can't work with control flow nodes if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << "as the model has control flow nodes which can't be supported by CUDA Graphs."; + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user " + << "as the model has control flow nodes which can't be supported by CUDA/HIP Graphs."; ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - "as the model has control flow nodes which can't be supported by CUDA Graphs.")); + "This session cannot use the CUDA/HIP Graph feature as requested by the user " + "as the model has control flow nodes which can't be supported by CUDA/HIP Graphs.")); } - if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0) { + if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || + strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0) { // Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes // The reasoning behind this logic is that certain shape nodes will be forced onto CPU // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP // which is all we care about. if (!AreAllComputeNodesAssignedToCudaEp(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as all compute graph nodes have not been partitioned to the CUDA EP."; + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user " + << " as all compute graph nodes have not been partitioned to the CUDA/HIP EP."; ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as all compute graph nodes have not been partitioned to the CUDA EP.")); + "This session cannot use the CUDA/HIP Graph feature as requested by the user " + " as all compute graph nodes have not been partitioned to the CUDA/HIP EP.")); } // Log a warning for the user to know that there are shape subgraphs that will execute on CPU if (HasShapeSubgraphNodes(graph)) { LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " - << "Use the CUDA Graph feature with caution. " + << "Use the CUDA/HIP Graph feature with caution. " << "As long as the intermediate shapes produced in the model " - << "using the representative input used to capture the CUDA graph, " + << "using the representative input used to capture the CUDA/HIP graph, " << "will match the shapes produced in the model for other inputs " << "of the same shape as the representative input (common case), " - << "it is safe to use the CUDA Graph feature."; + << "it is safe to use the CUDA/HIP Graph feature."; } } else { // Following code path is for TRT EP currently. @@ -1787,7 +1799,7 @@ common::Status InferenceSession::Initialize() { } } - LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; + LOGS(*session_logger_, INFO) << "This session will use the CUDA/HIP Graph feature as requested by the user."; cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep); break; // Make sure only one ep can run CUDA graph. } @@ -2477,7 +2489,9 @@ Status InferenceSession::Run(const RunOptions& run_options, // As N+1 inference runs (N for memory allocation and 1 for graph capturing) // are needed before replaying the captured graph, here run N inference runs recursively until graph captured, // so that users just need one session run to capture the graph. - // N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, and the value could be different for other EP. + // N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, + // N is defined in min_num_runs_before_hip_graph_capture_ for ROCM EP, + // and the value could be different for other EP. if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() && !cached_execution_provider_for_graph_replay_.IsGraphCaptured()) { LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture."; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 3269c9f0f4e4b..3178c13d30eec 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2380,6 +2380,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateROCMProviderOptions, _Outptr_ OrtROCMProvider options->has_user_compute_stream = 0; options->user_compute_stream = nullptr; options->default_memory_arena_cfg = nullptr; + options->enable_hip_graph = false; options->tunable_op_enable = 0; options->tunable_op_tuning_enable = 0; options->tunable_op_max_tuning_duration_ms = 0; diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 6ffe72f81bd24..8dad2c8e2d10d 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -43,6 +43,10 @@ #include #endif +#ifdef USE_ROCM +#include +#endif + // Once we use C++17 this could be replaced with std::size template constexpr size_t countof(T (&)[N]) { return N; } @@ -1762,6 +1766,27 @@ TEST(CApiTest, get_allocator_cuda) { } #endif +#ifdef USE_ROCM +TEST(CApiTest, get_allocator_rocm) { + Ort::SessionOptions session_options; + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); + Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); + + Ort::MemoryInfo info_rocm("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); + Ort::Allocator rocm_allocator(session, info_rocm); + + auto allocator_info = rocm_allocator.GetInfo(); + ASSERT_TRUE(info_rocm == allocator_info); + void* p = rocm_allocator.Alloc(1024); + ASSERT_NE(p, nullptr); + rocm_allocator.Free(p); + + auto mem_allocation = rocm_allocator.GetAllocation(1024); + ASSERT_NE(nullptr, mem_allocation.get()); + ASSERT_EQ(1024U, mem_allocation.size()); +} +#endif + TEST(CApiTest, io_binding) { Ort::SessionOptions session_options; Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); @@ -1937,7 +1962,7 @@ TEST(CApiTest, io_binding_cuda) { } #endif -#if defined(USE_CUDA) || defined(USE_TENSORRT) +#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) TEST(CApiTest, basic_cuda_graph) { const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; @@ -1955,7 +1980,7 @@ TEST(CApiTest, basic_cuda_graph) { ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( static_cast(session_options), rel_trt_options.get()) == nullptr); -#else +#elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. OrtCUDAProviderOptionsV2* cuda_options = nullptr; ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); @@ -1968,34 +1993,55 @@ TEST(CApiTest, basic_cuda_graph) { ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( static_cast(session_options), rel_cuda_options.get()) == nullptr); +#elif defined(USE_ROCM) + // Enable hip graph in rocm provider option. + OrtROCMProviderOptions* rocm_options = nullptr; + ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); + std::unique_ptr + rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); + std::vector keys{"enable_hip_graph"}; + std::vector values{"1"}; + ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); + + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( + static_cast(session_options), + rel_rocm_options.get()) == nullptr); #endif Ort::Session session(*ort_env, MODEL_URI, session_options); - Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#if defined(USE_ROCM) +// local hipify +#define cudaMemcpy hipMemcpy +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost + Ort::MemoryInfo info_mem("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#else + Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#endif - Ort::Allocator cuda_allocator(session, info_cuda); - auto allocator_info = cuda_allocator.GetInfo(); - ASSERT_TRUE(info_cuda == allocator_info); + Ort::Allocator allocator(session, info_mem); + auto allocator_info = allocator.GetInfo(); + ASSERT_TRUE(info_mem == allocator_info); const std::array x_shape = {3, 2}; std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - auto input_data = cuda_allocator.GetAllocation(x_values.size() * sizeof(float)); + auto input_data = allocator.GetAllocation(x_values.size() * sizeof(float)); ASSERT_NE(input_data.get(), nullptr); - cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); + (void)cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); // Create an OrtValue tensor backed by data on CUDA memory - Ort::Value bound_x = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(input_data.get()), x_values.size(), + Ort::Value bound_x = Ort::Value::CreateTensor(info_mem, reinterpret_cast(input_data.get()), x_values.size(), x_shape.data(), x_shape.size()); const std::array expected_y_shape = {3, 2}; std::array expected_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; - auto output_data = cuda_allocator.GetAllocation(expected_y.size() * sizeof(float)); + auto output_data = allocator.GetAllocation(expected_y.size() * sizeof(float)); ASSERT_NE(output_data.get(), nullptr); // Create an OrtValue tensor backed by data on CUDA memory - Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(output_data.get()), + Ort::Value bound_y = Ort::Value::CreateTensor(info_mem, reinterpret_cast(output_data.get()), expected_y.size(), expected_y_shape.data(), expected_y_shape.size()); // Create IoBinding for inputs and outputs. @@ -2008,31 +2054,37 @@ TEST(CApiTest, basic_cuda_graph) { // Check the values against the bound raw memory (needs copying from device to host first) std::array y_values; - cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); // Replay the captured CUDA graph session.Run(Ort::RunOptions(), binding); - cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); // Change the input and replay the CUDA graph again. x_values = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f}; - cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); + (void)cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); binding.SynchronizeInputs(); session.Run(Ort::RunOptions(), binding); - cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); expected_y = {10.0f, 40.0f, 90.0f, 160.0f, 250.0f, 360.0f}; ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); // Clean up binding.ClearBoundInputs(); binding.ClearBoundOutputs(); +#if defined(USE_ROCM) +#undef cudaMemcpy +#undef cudaMemcpyHostToDevice +#undef cudaMemcpyDeviceToHost +#endif } -#ifndef REDUCED_OPS_BUILD // The following test uses some ops not supported in the reduced ops build +#ifndef REDUCED_OPS_BUILD +#if defined(USE_CUDA) || defined(USE_TENSORRT) TEST(CApiTest, cuda_graph_with_shape_nodes) { const auto& api = Ort::GetApi(); @@ -2053,10 +2105,34 @@ TEST(CApiTest, cuda_graph_with_shape_nodes) { // Successful loading of the ONNX model with shape nodes with cuda graph feature enabled Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); } +#endif // defined(USE_CUDA) || defined(USE_TENSORRT) -#endif +#if defined(USE_ROCM) +TEST(CApiTest, hip_graph_with_shape_nodes) { + const auto& api = Ort::GetApi(); -#endif + // Enable hip graph in rocm provider option. + OrtROCMProviderOptions* rocm_options = nullptr; + ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); + std::unique_ptr + rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); + std::vector keys{"enable_hip_graph"}; + std::vector values{"1"}; + ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); + + Ort::SessionOptions session_options; + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( + static_cast(session_options), + rel_rocm_options.get()) == nullptr); + + // Successful loading of the ONNX model with shape nodes with hip graph feature enabled + Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); +} +#endif // defined(USE_ROCM) + +#endif // REDUCED_OPS_BUILD + +#endif // defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) TEST(CApiTest, create_tensor) { const char* s[] = {"abc", "kmp"}; From 6ca7c1a933e57e0078d8d01eff3a1520098cfed1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 22 Jan 2024 20:42:30 -0800 Subject: [PATCH 14/45] unet fusion for stable diffusion webui (#19227) ### Description Update unet fusion for [stable diffusion webui extension](https://github.com/tianleiwu/Stable-Diffusion-WebUI-OnnxRuntime): (1) Update fusion pattern to support fp16 unet model. (2) Add progress bar (3) Use a cached map to speed up dtype or shape lookup in shape inference result. ### Motivation and Context --- .../tools/transformers/fusion_attention.py | 14 +- .../transformers/fusion_attention_unet.py | 166 ++++++++++++++++-- .../tools/transformers/fusion_embedlayer.py | 18 +- .../tools/transformers/fusion_gemmfastgelu.py | 2 +- .../tools/transformers/fusion_nhwc_conv.py | 15 +- .../python/tools/transformers/fusion_shape.py | 8 +- .../python/tools/transformers/fusion_utils.py | 47 +++-- .../python/tools/transformers/import_utils.py | 20 +++ .../models/stable_diffusion/README.md | 2 +- .../python/tools/transformers/onnx_model.py | 98 ++++++++--- .../tools/transformers/onnx_model_bert.py | 16 +- .../tools/transformers/onnx_model_unet.py | 71 +++++++- 12 files changed, 395 insertions(+), 82 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/import_utils.py diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index d11cb91d98b0c..f48cabd25fc5c 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -129,6 +129,9 @@ def __init__( self.num_heads_warning = True self.hidden_size_warning = True + self.shape_infer = None + self.shape_infer_done = True + def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[int, int]: """ Detect num_heads and hidden_size from Concat node in the following subgraph: @@ -202,12 +205,15 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int] return num_heads, hidden_size def get_add_qk_str(self, add_qk: NodeProto): - shape_infer = self.model.infer_runtime_shape(update=True) - if shape_infer is None: + if not self.shape_infer_done: + self.shape_infer = self.model.infer_runtime_shape(update=True) + self.shape_infer_done = True + + if self.shape_infer is None: return None - input_0_shape = shape_infer.get_edge_shape(add_qk.input[0]) - input_1_shape = shape_infer.get_edge_shape(add_qk.input[1]) + input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0]) + input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1]) if input_0_shape is None or input_1_shape is None: logger.debug(f"one of the inputs of {add_qk} is None") diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py index 250ec5f3eb159..9a353e7e2d675 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py @@ -28,10 +28,19 @@ def __init__( enable_packed_qkv: bool, enable_packed_kv: bool, ): - super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"]) + super().__init__( + model, + "Attention" if is_cross_attention and enable_packed_qkv else "MultiHeadAttention", + ["LayerNormalization"], + ) self.hidden_size = hidden_size self.num_heads = num_heads self.is_cross_attention = is_cross_attention + + # Note: pack Q/K/V or K/V weights into one tensor make it harder for updating initializers for LoRA. + # To support LoRA, it is better to use separated Q, K and V inputs in offline optimization, + # and CUDA operator pre-packs those tensors to preferred format based on available kernels. + # In this way, we can support LoRA and get optimal performance at same time. self.enable_packed_qkv = enable_packed_qkv self.enable_packed_kv = enable_packed_kv @@ -170,9 +179,7 @@ def create_attention_node( return None # Sometimes weights are stored in fp16 - if q_weight.data_type == 10: - logger.debug("weights are in fp16. Please run fp16 conversion after optimization") - return None + float_type = q_weight.data_type qw = NumpyHelper.to_array(q_weight) kw = NumpyHelper.to_array(k_weight) @@ -212,7 +219,7 @@ def create_attention_node( matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV") self.add_initializer( name=matmul_node_name + "_weight", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qkv_weight.shape[0], qkv_weight.shape[1]], vals=qkv_weight, ) @@ -235,8 +242,11 @@ def create_attention_node( reshape_node = helper.make_node( "Reshape", - inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"], - outputs=[attention_node_name + "_input"], + inputs=[ + matmul_node_name + "_out", + matmul_node_name + "_reshape_shape", + ], + outputs=[attention_node_name + "_qkv_input"], name=matmul_node_name + "_reshape", ) self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name @@ -251,7 +261,7 @@ def create_attention_node( self.add_initializer( name=attention_node_name + "_qkv_weight", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qw_in_size, qkv_weight_dim], vals=qkv_weight, ) @@ -280,7 +290,7 @@ def create_attention_node( matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV") self.add_initializer( name=matmul_node_name + "_weight", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[kv_weight.shape[0], kv_weight.shape[1]], vals=kv_weight, ) @@ -303,8 +313,11 @@ def create_attention_node( reshape_node = helper.make_node( "Reshape", - inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"], - outputs=[k_matmul.output[0]], + inputs=[ + matmul_node_name + "_out", + matmul_node_name + "_reshape_shape", + ], + outputs=[attention_node_name + "_kv_input"], name=matmul_node_name + "_reshape", ) self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name @@ -317,7 +330,7 @@ def create_attention_node( self.add_initializer( name=attention_node_name + "_qkv_bias", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qkv_bias_dim], vals=qkv_bias, ) @@ -330,7 +343,7 @@ def create_attention_node( attention_node_name + "_qkv_bias", ] else: - attention_inputs = [attention_node_name + "_input"] + attention_inputs = [attention_node_name + "_qkv_input"] else: if not self.enable_packed_kv: attention_inputs = [ @@ -342,7 +355,7 @@ def create_attention_node( else: attention_inputs = [ q_matmul.output[0], - k_matmul.output[0], + attention_node_name + "_kv_input", ] attention_node = helper.make_node( @@ -839,6 +852,9 @@ def create_attention_node_lora( return attention_node def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + if self.fuse_a1111_fp16(normalize_node, input_name_to_nodes, output_name_to_node): + return + node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0) # In SD 1.5, for self attention, LayerNorm has parent Reshape @@ -1168,3 +1184,125 @@ def match_lora_path( return (lora_mul_node, lora_matmul_1_node) return None + + def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_node): + """Fuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extension""" + entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Add"], [0, 0]) + if entry_path is None: + entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Reshape"], [0, 0]) + if entry_path is None: + return False + _cast, node_before_layernorm = entry_path + + root_input = node_before_layernorm.output[0] + + children_nodes = input_name_to_nodes[root_input] + skip_add = None + for node in children_nodes: + if node.op_type == "Add": # SkipLayerNormalization fusion is not applied yet + skip_add = node + break + if skip_add is None: + return False + + match_qkv = self.match_qkv_a1111(root_input, skip_add) + if match_qkv is None: + return False + + ( + reshape_qkv, + transpose_qkv, + reshape_q, + matmul_q, + matmul_k, + matmul_v, + ) = match_qkv + + cast_q = self.model.match_parent(matmul_q, "Cast", 0) + cast_k = self.model.match_parent(matmul_k, "Cast", 0) + cast_v = self.model.match_parent(matmul_v, "Cast", 0) + if not ( + cast_q is not None + and cast_k is not None + and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k) + and cast_k == cast_v + ): + return False + + if cast_q.input[0] != normalize_node.output[0]: + return False + + attention_last_node = reshape_qkv + + q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return False + + q_hidden_size = self.get_hidden_size(normalize_node) + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_attention_node( + matmul_q, + matmul_k, + matmul_v, + q_num_heads, + q_hidden_size, + input=matmul_q.input[0], + output=attention_last_node.output[0], + ) + if new_node is None: + return False + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. + self.prune_graph = True + return True + + def match_qkv_a1111(self, root_input, skip_add): + """Match Q, K and V paths exported by A1111 (stable diffusion webui) extension""" + another_input = 1 if skip_add.input[0] == root_input else 0 + qkv_nodes = self.model.match_parent_path( + skip_add, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "Einsum"], + [another_input, None, None, 0, 0, 0], + ) + + if qkv_nodes is None: + return None + + (_, _, reshape_qkv, transpose_qkv, reshape_einsum, einsum_qkv) = qkv_nodes + + v_nodes = self.model.match_parent_path(einsum_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]) + if v_nodes is None: + logger.debug("fuse_attention: failed to match v path") + return None + (_, _, _, matmul_v) = v_nodes + + qk_nodes = self.model.match_parent_path( + einsum_qkv, ["Cast", "Cast", "Softmax", "Mul", "Einsum"], [0, 0, 0, 0, None] + ) + if qk_nodes is not None: + (_, _, _softmax_qk, _, einsum_qk) = qk_nodes + else: + logger.debug("fuse_attention: failed to match qk path") + return None + + q_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0]) + if q_nodes is None: + logger.debug("fuse_attention: failed to match q path") + return None + (_, _transpose_q, reshape_q, matmul_q) = q_nodes + + k_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]) + if k_nodes is None: + logger.debug("fuse_attention: failed to match k path") + return None + + (_, _, _, matmul_k) = k_nodes + + return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v diff --git a/onnxruntime/python/tools/transformers/fusion_embedlayer.py b/onnxruntime/python/tools/transformers/fusion_embedlayer.py index bc38399e3cce5..42156d9123383 100644 --- a/onnxruntime/python/tools/transformers/fusion_embedlayer.py +++ b/onnxruntime/python/tools/transformers/fusion_embedlayer.py @@ -28,7 +28,9 @@ def __init__(self, model: OnnxModel, description: str = "no mask"): description, ) self.utils = FusionUtils(model) - self.shape_infer_helper = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = None + self.shape_infer_done = False + # The following will be reset in each fuse call of FusionEmbedLayerNormalization self.attention = None self.embed_node = None @@ -329,9 +331,13 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None position_ids = position_embedding_gather.input[1] - if self.shape_infer_helper is not None: - input_ids_shape = self.shape_infer_helper.get_edge_shape(input_ids) - position_ids_shape = self.shape_infer_helper.get_edge_shape(position_ids) + if not self.shape_infer_done: + self.shape_infer = self.model.infer_runtime_shape(update=True) + self.shape_infer_done = True + + if self.shape_infer is not None: + input_ids_shape = self.shape_infer.get_edge_shape(input_ids) + position_ids_shape = self.shape_infer.get_edge_shape(position_ids) assert input_ids_shape and position_ids_shape if not ( len(input_ids_shape) == 2 @@ -345,11 +351,11 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit ) return False - if segment_ids and not self.shape_infer_helper.compare_shape(input_ids, segment_ids): + if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids): logger.info( "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {} != {}".format( input_ids_shape, - self.shape_infer_helper.get_edge_shape(segment_ids), + self.shape_infer.get_edge_shape(segment_ids), ) ) return False diff --git a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py index f1d803a3cc082..4d9913f427b37 100644 --- a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py @@ -32,7 +32,7 @@ def get_dimensions(self, input_name: str) -> Union[int, None]: return self.get_dimensions_from_tensor_proto(graph_input) if not self.shape_infer_done: - self.shape_infer = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = self.model.infer_runtime_shape(update=True) self.shape_infer_done = True if self.shape_infer is not None: diff --git a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py index 141ebb1f95a11..5233fdf272fbd 100644 --- a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py +++ b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py @@ -7,7 +7,8 @@ from typing import List from fusion_base import Fusion -from onnx import TensorProto, helper, numpy_helper +from fusion_utils import FusionUtils +from onnx import helper, numpy_helper from onnx_model import OnnxModel logger = getLogger(__name__) @@ -19,6 +20,7 @@ class FusionNhwcConv(Fusion): def __init__(self, model: OnnxModel, update_weight=False): super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv") self.update_weight = update_weight + self.fusion_utils = FusionUtils(model) def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): """Append a Transpose node after an input""" @@ -49,6 +51,15 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node): if len(weight.shape) != 4: return + dtype = self.model.get_dtype(nhwc_conv_input) + if not (dtype is not None and weight_tensor.data_type == dtype): + cast_node = self.fusion_utils.add_cast_node( + input_name=nhwc_conv_input, + to_type=weight_tensor.data_type, + output_name_to_node=output_name_to_node, + ) + nhwc_conv_input = cast_node.output[0] + if self.update_weight: # Transpose weights from NCHW to NHWC weight = weight.transpose(0, 2, 3, 1) @@ -56,7 +67,7 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node): weight_name = node_name + "_weight_NHWC" self.add_initializer( name=weight_name, - data_type=TensorProto.FLOAT, + data_type=weight_tensor.data_type, dims=list(weight.shape), vals=weight, ) diff --git a/onnxruntime/python/tools/transformers/fusion_shape.py b/onnxruntime/python/tools/transformers/fusion_shape.py index bc32d78eda66c..dfa77fc7d0221 100644 --- a/onnxruntime/python/tools/transformers/fusion_shape.py +++ b/onnxruntime/python/tools/transformers/fusion_shape.py @@ -29,12 +29,12 @@ def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[i return None def get_dimensions(self, input_name: str) -> Union[int, None]: - graph_input = self.model.find_graph_input(input_name) - if graph_input: - return self.get_dimensions_from_tensor_proto(graph_input) + shape = self.model.get_shape(input_name) + if shape is not None: + return len(shape) if not self.shape_infer_done: - self.shape_infer = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = self.model.infer_runtime_shape(update=True) self.shape_infer_done = True if self.shape_infer is not None: diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index afc968fab46c1..726c587ff7043 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Tuple +from typing import Optional, Tuple import numpy from numpy import array_equal, ndarray @@ -29,17 +29,7 @@ def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]: return False, input_name def cast_input(self, input_name: str, target_type="int32"): - cast_output = input_name + "_" + target_type - - # Avoid consequent Cast nodes. - inputs = [input_name] - output_name_to_node = self.model.output_name_to_node() - if input_name in output_name_to_node: - parent_node = output_name_to_node[input_name] - if parent_node and parent_node.op_type == "Cast": - inputs = [parent_node.input[0]] - - cast_node = helper.make_node("Cast", inputs=inputs, outputs=[cast_output]) + output_name = input_name + "_" + target_type if target_type == "int32": to_type = int(TensorProto.INT32) @@ -50,10 +40,36 @@ def cast_input(self, input_name: str, target_type="int32"): else: raise ValueError("Invalid target_type: {target_type}") + cast_node = self.add_cast_node(input_name, to_type, output_name) + + return output_name, cast_node + + def add_cast_node( + self, + input_name: str, + to_type: int, + output_name: Optional[str] = None, + output_name_to_node=None, + graph_name: Optional[str] = None, + ): + if output_name is None: + output_name = input_name + f"_cast_to_{to_type}" + + # Avoid consequent Cast nodes. + inputs = [input_name] + if output_name_to_node is None: + output_name_to_node = self.model.output_name_to_node() + if input_name in output_name_to_node: + parent_node = output_name_to_node[input_name] + if parent_node and parent_node.op_type == "Cast": + inputs = [parent_node.input[0]] + + cast_node = helper.make_node("Cast", inputs=inputs, outputs=[output_name]) + cast_node.attribute.extend([helper.make_attribute("to", to_type)]) - self.model.add_node(cast_node) + self.model.add_node(cast_node, graph_name=graph_name) - return cast_output, cast_node + return cast_node def cast_input_to_int32(self, input_name: str): return self.cast_input(input_name, "int32") @@ -224,9 +240,10 @@ def check_node_input_value(self, node, input_index: int, expected_value): def remove_identity_nodes(self): """Remove Identity nodes, except those right before graph output.""" nodes_to_remove = [] + graph_output_names = self.model.get_graphs_output_names() for node in self.model.nodes(): if node.op_type == "Identity": - if node.output[0] not in self.model.get_graphs_output_names(): + if node.output[0] not in graph_output_names: self.model.replace_input_of_all_nodes(node.output[0], node.input[0]) nodes_to_remove.append(node) diff --git a/onnxruntime/python/tools/transformers/import_utils.py b/onnxruntime/python/tools/transformers/import_utils.py new file mode 100644 index 0000000000000..9755a26b7b004 --- /dev/null +++ b/onnxruntime/python/tools/transformers/import_utils.py @@ -0,0 +1,20 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import importlib.metadata +import importlib.util + + +def is_installed(package): + try: + dist = importlib.metadata.distribution(package) + except importlib.metadata.PackageNotFoundError: + try: + spec = importlib.util.find_spec(package) + except ModuleNotFoundError: + return False + + return spec is not None + + return dist is not None diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index b10c10c87ee57..8607485bc265b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -51,7 +51,7 @@ sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_ve --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \ --allow_running_as_root python3 -m pip install --upgrade pip -python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl --force-reinstall +python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-*.whl --force-reinstall ``` If the GPU is not A100, change `CMAKE_CUDA_ARCHITECTURES=80` in the command line according to the GPU compute capacity (like 89 for RTX 4090, or 86 for RTX 3090). diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 37b39c91b5c15..9d1066b6e372b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -40,6 +40,12 @@ def initialize(self, model): self.enable_shape_infer: bool = True self.all_graphs: Optional[List[GraphProto]] = None + # Cache of shape and data type from onnx graph to speed up optimization. + # Be careful that fusion shall not reuse node output name for different shape/type (in adding/removing nodes) + # Note that these do not cache the symbolic shape inference result. + self._dtype_dict: Optional[Dict[str, int]] = None + self._shape_dict: Optional[Dict[str, List]] = None + def disable_shape_inference(self): self.enable_shape_infer = False @@ -519,20 +525,60 @@ def tensor_shape_to_list(self, tensor_type): shape_list.append("?") # shall not happen return shape_list - def get_dtype(self, input_or_output: str): - """Try get data type given a name (could be initializer, graph input or output).""" - tensor_type_map = {obj.name: obj.type for obj in self.model.graph.value_info} + def get_dtype(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + """Try get data type given a name (could be initializer, input or output of graph or node).""" + + if self._dtype_dict is None: + self._dtype_dict = {} + for value_info in itertools.chain( + self.model.graph.value_info, + self.model.graph.input, + self.model.graph.output, + ): + self._dtype_dict[value_info.name] = value_info.type.tensor_type.elem_type + + for initializer in self.model.graph.initializer: + if initializer.name not in self._dtype_dict: + self._dtype_dict[initializer.name] = initializer.data_type - if input_or_output in tensor_type_map: - return tensor_type_map[input_or_output].tensor_type.elem_type + if name in self._dtype_dict: + return self._dtype_dict[name] - graph_input = self.find_graph_input(input_or_output) - if graph_input: - return graph_input.type.tensor_type.elem_type + if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_: + value_info = symbolic_shape_helper.known_vi_[name] + return value_info.type.tensor_type.elem_type + + return None - graph_output = self.find_graph_output(input_or_output) - if graph_output: - return graph_output.type.tensor_type.elem_type + def get_shape(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + """Try get shape given a name (could be initializer, input or output of graph or node).""" + + if self._shape_dict is None: + self._shape_dict = {} + for value_info in itertools.chain( + self.model.graph.value_info, + self.model.graph.input, + self.model.graph.output, + ): + if value_info.type.tensor_type.HasField("shape"): + shape = [] + for dim in value_info.type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + else: + shape.append(dim.dim_value) + self._shape_dict[value_info.name] = shape + + for initializer in self.model.graph.initializer: + if initializer.name not in self._shape_dict: + self._shape_dict[initializer.name] = initializer.dims + + if name in self._shape_dict: + return self._shape_dict[name] + + if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_: + value_info = symbolic_shape_helper.known_vi_[name] + return value_info.type.tensor_type.elem_type return None @@ -566,23 +612,14 @@ def remove_cascaded_cast_nodes(self): def remove_useless_cast_nodes(self): """Remove cast nodes that are not needed: input and output has same data type.""" shape_infer = self.infer_runtime_shape(update=True) - if shape_infer is None: - logger.info("Skip removing useless cast nodes since shape inference failed.") - return - - def get_data_type(input_or_output_name): - dtype = self.get_dtype(input_or_output_name) - if dtype: - return dtype - if shape_infer.known_vi_[input_or_output_name].type.tensor_type.HasField("elem_type"): - return shape_infer.known_vi_[input_or_output_name].type.tensor_type.elem_type - return None + if self.enable_shape_infer and shape_infer is None: + logger.warning("shape inference failed which might impact useless cast node detection.") nodes_to_remove = [] for node in self.nodes(): if node.op_type == "Cast": - input_dtype = get_data_type(node.input[0]) - output_dtype = get_data_type(node.output[0]) + input_dtype = self.get_dtype(node.input[0], shape_infer) + output_dtype = self.get_dtype(node.output[0], shape_infer) if input_dtype and input_dtype == output_dtype: nodes_to_remove.append(node) @@ -601,7 +638,10 @@ def get_data_type(input_or_output_name): self.replace_input_of_all_nodes(node.output[0], node.input[0]) self.remove_node(node) - logger.info("Removed %d Cast nodes with output type same as input", len(nodes_to_remove)) + logger.info( + "Removed %d Cast nodes with output type same as input", + len(nodes_to_remove), + ) def convert_model_float32_to_float16(self, cast_input_output=True): logger.warning( @@ -1214,7 +1254,10 @@ def remove_duplicated_initializer(self, cache: Optional[dict] = None): continue for j in range(i + 1, initializer_count): if OnnxModel.has_same_value( - self.model.graph.initializer[i], self.model.graph.initializer[j], cache, cache + self.model.graph.initializer[i], + self.model.graph.initializer[j], + cache, + cache, ): same[j] = i @@ -1223,7 +1266,8 @@ def remove_duplicated_initializer(self, cache: Optional[dict] = None): if same[i] >= 0: count += 1 self.replace_input_of_all_nodes( - self.model.graph.initializer[i].name, self.model.graph.initializer[same[i]].name + self.model.graph.initializer[i].name, + self.model.graph.initializer[same[i]].name, ) if count > 0: diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 51deb67ce5bf3..431e64509e3cc 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -126,7 +126,8 @@ def fuse_rotary_embeddings(self): # Remove non-MS domain functions rot_emb_nodes = list( filter( - lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", self.model.graph.node + lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", + self.model.graph.node, ) ) non_ms_domains_to_keep = set(map(lambda node: node.domain, rot_emb_nodes)) @@ -350,7 +351,11 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo self.attention_mask.set_mask_format(options.attention_mask_format) if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention): self.attention_fusion = FusionAttention( - self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention + self, + self.hidden_size, + self.num_heads, + self.attention_mask, + options.use_multi_head_attention, ) if (options is None) or options.enable_attention: @@ -415,7 +420,12 @@ def get_fused_operator_statistics(self): "SkipSimplifiedLayerNormalization", "RotaryEmbedding", ] - q_ops = ["QOrderedAttention", "QOrderedGelu", "QOrderedLayerNormalization", "QOrderedMatMul"] + q_ops = [ + "QOrderedAttention", + "QOrderedGelu", + "QOrderedLayerNormalization", + "QOrderedMatMul", + ] for op in ops + q_ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes) diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 4d15b9288e7b6..01298b3576eb1 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from logging import getLogger +import logging from typing import Optional from fusion_attention_unet import FusionAttentionUnet @@ -14,11 +14,12 @@ from fusion_options import FusionOptions from fusion_skip_group_norm import FusionSkipGroupNorm from fusion_transpose import FusionInsertTranspose, FusionTranspose +from import_utils import is_installed from onnx import ModelProto from onnx_model import OnnxModel from onnx_model_bert import BertOnnxModel -logger = getLogger(__name__) +logger = logging.getLogger(__name__) class UnetOnnxModel(BertOnnxModel): @@ -94,14 +95,24 @@ def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None): # Self Attention enable_packed_qkv = (options is None) or options.enable_packed_qkv self_attention_fusion = FusionAttentionUnet( - self, self.hidden_size, self.num_heads, False, enable_packed_qkv, False + self, + self.hidden_size, + self.num_heads, + is_cross_attention=False, + enable_packed_qkv=enable_packed_qkv, + enable_packed_kv=False, ) self_attention_fusion.apply() # Cross Attention enable_packed_kv = (options is None) or options.enable_packed_kv cross_attention_fusion = FusionAttentionUnet( - self, self.hidden_size, self.num_heads, True, False, enable_packed_kv + self, + self.hidden_size, + self.num_heads, + is_cross_attention=True, + enable_packed_qkv=False, + enable_packed_kv=enable_packed_kv, ) cross_attention_fusion.apply() @@ -110,23 +121,48 @@ def fuse_bias_add(self): fusion.apply() def optimize(self, options: Optional[FusionOptions] = None): + if is_installed("tqdm"): + import tqdm + from tqdm.contrib.logging import logging_redirect_tqdm + + with logging_redirect_tqdm(): + steps = 18 + progress_bar = tqdm.tqdm(range(0, steps), initial=0, desc="fusion") + self._optimize(options, progress_bar) + else: + logger.info("tqdm is not installed. Run optimization without progress bar") + self._optimize(options, None) + + def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): if (options is not None) and not options.enable_shape_inference: self.disable_shape_inference() self.utils.remove_identity_nodes() + if progress_bar: + progress_bar.update(1) # Remove cast nodes that having same data type of input and output based on symbolic shape inference. self.utils.remove_useless_cast_nodes() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_layer_norm: self.fuse_layer_norm() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_gelu: self.fuse_gelu() + if progress_bar: + progress_bar.update(1) self.preprocess() + if progress_bar: + progress_bar.update(1) self.fuse_reshape() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_group_norm: channels_last = (options is None) or options.group_norm_channels_last @@ -135,42 +171,66 @@ def optimize(self, options: Optional[FusionOptions] = None): insert_transpose_fusion = FusionInsertTranspose(self) insert_transpose_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_bias_splitgelu: bias_split_gelu_fusion = FusionBiasSplitGelu(self) bias_split_gelu_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_attention: + # self.save_model_to_file("before_mha.onnx") self.fuse_multi_head_attention(options) + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_skip_layer_norm: self.fuse_skip_layer_norm() + if progress_bar: + progress_bar.update(1) self.fuse_shape() + if progress_bar: + progress_bar.update(1) # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. self.utils.remove_useless_reshape_nodes() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_skip_group_norm: skip_group_norm_fusion = FusionSkipGroupNorm(self) skip_group_norm_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_bias_skip_layer_norm: # Fuse SkipLayerNormalization and Add Bias before it. self.fuse_add_bias_skip_layer_norm() + if progress_bar: + progress_bar.update(1) if options is not None and options.enable_gelu_approximation: self.gelu_approximation() + if progress_bar: + progress_bar.update(1) if options is None or options.enable_nhwc_conv: self.convert_conv_to_nhwc() - self.merge_adjacent_transpose() + if progress_bar: + progress_bar.update(1) if options is not None and options.enable_bias_add: self.fuse_bias_add() + if progress_bar: + progress_bar.update(1) self.postprocess() + if progress_bar: + progress_bar.update(1) logger.info(f"opset version: {self.get_opset_version()}") @@ -190,6 +250,7 @@ def get_fused_operator_statistics(self): "NhwcConv", "BiasAdd", ] + for op in ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes) From 61610ff9862ad834f153ed3e70ba526dac86ae7c Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 24 Jan 2024 00:25:05 +0800 Subject: [PATCH 15/45] [js/webgpu] Add FusedConv clip test case (#18900) Bug: https://github.com/microsoft/onnxruntime/issues/18899 --- js/web/test/data/ops/fused-conv.jsonc | 34 +++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index 812e9d7c2def0..ad1c0a72c11d3 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -108,5 +108,39 @@ ] } ] + }, + { + "name": "fused conv with clip", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Clip", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [400.0, 600.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [400, 470, 600, 600], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] } ] From 0ea48fc73ec6bdbb8af2010483a61823fcf1a613 Mon Sep 17 00:00:00 2001 From: Heflin Stephen Raj Date: Tue, 23 Jan 2024 23:40:54 +0530 Subject: [PATCH 16/45] Modified the condition to load the optimiser model (#18891) --- java/src/main/native/ai_onnxruntime_OrtTrainingSession.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c index 9f7b8d3a3dcfc..464234c34798a 100644 --- a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c @@ -66,7 +66,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_createTrainingSes } } wchar_t* optimizerStr = NULL; - if (optimizerPath == NULL) { + if (optimizerPath != NULL) { optimizerStr = copyAndPad(jniEnv, optimizerPath); if (optimizerStr == NULL) { // exception has been thrown in Java, go to cleanup and return null. From 54871a27736cf54cbda9c4f09bb27e931de7334e Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 24 Jan 2024 02:49:24 +0800 Subject: [PATCH 17/45] Replace T4 to A10 in Linux GPU workflow (#19205) ### Description 1. Update Linux GPU machine from T4 to A10, sm=8.6 2. update the tolerance ### Motivation and Context 1. Free more T4 and test with higher compute capability. 2. ORT enables TF32 in GEMM for A10/100. TF32 will cause precsion loss and fail this test ``` 2024-01-19T13:27:18.8302842Z [ RUN ] ModelTests/ModelTest.Run/cuda__models_zoo_opset12_SSD_ssd12 2024-01-19T13:27:25.8438153Z /onnxruntime_src/onnxruntime/test/providers/cpu/model_tests.cc:347: Failure 2024-01-19T13:27:25.8438641Z Expected equality of these values: 2024-01-19T13:27:25.8438841Z COMPARE_RESULT::SUCCESS 2024-01-19T13:27:25.8439276Z Which is: 4-byte object <00-00 00-00> 2024-01-19T13:27:25.8439464Z ret.first 2024-01-19T13:27:25.8445514Z Which is: 4-byte object <01-00 00-00> 2024-01-19T13:27:25.8445962Z expected 0.145984 (3e157cc1), got 0.975133 (3f79a24b), diff: 0.829149, tol=0.0114598 idx=375. 20 of 388 differ 2024-01-19T13:27:25.8446198Z 2024-01-19T13:27:25.8555736Z [ FAILED ] ModelTests/ModelTest.Run/cuda__models_zoo_opset12_SSD_ssd12, where GetParam() = "cuda_../models/zoo/opset12/SSD/ssd-12.onnx" (7025 ms) 2024-01-19T13:27:25.8556077Z [ RUN ] ModelTests/ModelTest.Run/cuda__models_zoo_opset12_YOLOv312_yolov312 2024-01-19T13:27:29.3174318Z /onnxruntime_src/onnxruntime/test/providers/cpu/model_tests.cc:347: Failure 2024-01-19T13:27:29.3175144Z Expected equality of these values: 2024-01-19T13:27:29.3175389Z COMPARE_RESULT::SUCCESS 2024-01-19T13:27:29.3175812Z Which is: 4-byte object <00-00 00-00> 2024-01-19T13:27:29.3176080Z ret.first 2024-01-19T13:27:29.3176322Z Which is: 4-byte object <01-00 00-00> 2024-01-19T13:27:29.3178431Z expected 4.34958 (408b2fb8), got 4.51324 (40906c80), diff: 0.16367, tol=0.0534958 idx=9929. 22 of 42588 differ ``` 3. some other test like SSD throw other exception, so skip them ''' 2024-01-22T09:07:40.8446910Z [ RUN ] ModelTests/ModelTest.Run/cuda__models_zoo_opset12_SSD_ssd12 2024-01-22T09:07:51.5587571Z /onnxruntime_src/onnxruntime/test/providers/cpu/model_tests.cc:358: Failure 2024-01-22T09:07:51.5588512Z Expected equality of these values: 2024-01-22T09:07:51.5588870Z COMPARE_RESULT::SUCCESS 2024-01-22T09:07:51.5589467Z Which is: 4-byte object <00-00 00-00> 2024-01-22T09:07:51.5589953Z ret.first 2024-01-22T09:07:51.5590462Z Which is: 4-byte object <01-00 00-00> 2024-01-22T09:07:51.5590841Z expected 1, got 63 ''' --- .../test/global_thread_pools/test_inference.cc | 8 +++++++- onnxruntime/test/providers/cpu/model_tests.cc | 17 +++++++++++++++++ .../providers/cuda/nhwc/conv_transpose_test.cc | 6 +++++- .../azure-pipelines/linux-gpu-ci-pipeline.yml | 4 ++-- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index 4772e7de2bdd7..f553682975f11 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -55,9 +55,15 @@ static void RunSession(OrtAllocator& allocator, Ort::Session& session_object, // size_t total_len = type_info.GetElementCount(); ASSERT_EQ(values_y.size(), static_cast(5)); +// test inference is using onnxruntime_shared_lib_test_LIBS, so HasCudaEnvironment(800) isn't available +#ifdef USE_CUDA + const float tolerance = 1e-5f; +#else + const float tolerance = 1e-6f; +#endif OutT* f = output_tensor->GetTensorMutableData(); for (size_t i = 0; i != static_cast(5); ++i) { - ASSERT_NEAR(values_y[i], f[i], 1e-6f); + ASSERT_NEAR(values_y[i], f[i], tolerance); } } diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 859e082716760..8128c170c5211 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -39,6 +39,8 @@ #include "core/providers/armnn/armnn_provider_factory.h" #endif +#include "test/common/cuda_op_test_utils.h" + // test infrastructure #include "test/onnx/testenv.h" #include "test/onnx/TestCase.h" @@ -94,6 +96,21 @@ TEST_P(ModelTest, Run) { std::unique_ptr model_info = std::make_unique(model_path.c_str()); +#if defined(__linux__) + // ORT enables TF32 in GEMM for A100. TF32 will cause precsion loss and fail this test. + if (HasCudaEnvironment(800) && provider_name == "cuda") { + per_sample_tolerance = 1e-1; + if (model_path.find(ORT_TSTR("SSD")) > 0 || + model_path.find(ORT_TSTR("ssd")) > 0 || + model_path.find(ORT_TSTR("yolov3")) > 0 || + model_path.find(ORT_TSTR("mask_rcnn")) > 0 || + model_path.find(ORT_TSTR("FNS")) > 0) { + SkipTest("Skipping SSD test for big tolearance failure or other errors"); + return; + } + } +#endif + if (model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { SkipTest("it has the training domain. No pipeline should need to run these tests."); diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc index 06da2a5304716..6514feadf0ff7 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc @@ -70,7 +70,11 @@ TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { auto op = ConvTransposeOp{.input_dims = {1, 8, 80, 80}, .kernel_shape = {5, 5}, .channels = 16, .bias = true}; - MAKE_PROVIDERS_EPS_TYPE(TypeParam) + if (HasCudaEnvironment(800)) { + MAKE_PROVIDERS_EPS(1e-2) + } else { + MAKE_PROVIDERS_EPS_TYPE(TypeParam) + } } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcPad) { diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 1060a0138e0b7..5779b1da3fd43 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -137,7 +137,7 @@ jobs: --enable_cuda_profiling --enable_cuda_nhwc_ops \ --enable_pybind --build_java \ --use_cache \ - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75; \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86; \ ccache -sv; \ ccache -z" workingDirectory: $(Build.SourcesDirectory) @@ -166,7 +166,7 @@ jobs: skipComponentGovernanceDetection: true workspace: clean: all - pool: Onnxruntime-Linux-GPU-T4 + pool: onnxruntime-Linux-GPU-A10 dependsOn: - Linux_Build steps: From f53068446e7e560012862e1812270bcf908fbda4 Mon Sep 17 00:00:00 2001 From: petermcaughan Date: Tue, 23 Jan 2024 13:44:34 -0800 Subject: [PATCH 18/45] Add Temperature to WhisperBeamSearch input (#19188) ### Description Add `temperature` as an input to WhisperBeamSearch op and initialize correctly in parameter setup. ### Motivation and Context Currently, temperature is included as an attribute to the BeamSearch op, which doesn't let the model act dynamically in a single inference session. By including this variable as an input, the temperature value can be altered in any inference call (important for 1P teams) --------- Co-authored-by: Peter McAughan Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Co-authored-by: Kunal Vaishnavi --- docs/ContribOperators.md | 4 +++- docs/OperatorKernels.md | 4 ++-- .../cpu/transformers/beam_search_parameters.cc | 14 +++++++++++++- .../contrib_ops/cuda/transformers/beam_search.cc | 1 + onnxruntime/core/graph/contrib_ops/contrib_defs.cc | 1 + 5 files changed, 20 insertions(+), 4 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 22e82443167f6..624cda1d37f73 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5761,7 +5761,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
-#### Inputs (5 - 14) +#### Inputs (5 - 15)
input_ids : F
@@ -5792,6 +5792,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
+
temperature (optional) : T
+
Temperature value to apply to logits processing during this execution's decoding. Shape is (1)
#### Outputs (1 - 5) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 9a2a7ac89bbb3..3b695af2839b6 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -499,7 +499,7 @@ Do not modify directly.* |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)| |Unique|*in* x:**T**
*out* y:**T**
*out* idx:**tensor(int64)**
*out* counts:**tensor(int64)**|1+|**T** = tensor(float)| -|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| |WordConvEmbedding|*in* Sequence:**T**
*in* W:**T1**
*in* B:**T1**
*in* C:**T1**
*out* Y:**T1**|1+|**T** = tensor(int32)
**T1** = tensor(float)| | | | | @@ -876,7 +876,7 @@ Do not modify directly.* |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 3962486d5b5eb..bb6885c3216bc 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -123,8 +123,20 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { logits_processor = logits_processor_tensor ? static_cast(*logits_processor_tensor->Data()) : 0; ORT_ENFORCE(logits_processor >= 0, "logits_processor shall be a non-negative integer, got ", logits_processor); -} + if (this->model_type == IGenerationParameters::kModelTypeWhisper) { + auto* temperature_tensor = context->Input(14); + if (temperature_tensor) { + if (temperature_tensor->IsDataType()) { + temperature = *temperature_tensor->Data(); + } else { + temperature = static_cast(*temperature_tensor->Data()); + } + } else { + temperature = 1.0f; + } + } +} void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) { // Override vocab_size using the inferred shape from the decoder subgraph ONLY IF // the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch) diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index 2a90e4911f286..08cbb145a6f65 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -49,6 +49,7 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU .TypeConstraint("T", {DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 982e8fd834b76..27c968a59eb91 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1231,6 +1231,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) " "are treated as stop of the extra_decoding_ids for corresponding batch.", "I", OpSchema::Optional) + .Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", From 532f8c642ce9c1ea2971b7d0f0ff8a4197bcb3a0 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Tue, 23 Jan 2024 14:57:30 -0800 Subject: [PATCH 19/45] Fix a backend test by using local backend (#19230) The decomposition pass (e.g., converting torch.add to aten.add) in DORT no longer exists. Therefore, we have to use `use_aot_autograd=True` to enable Dynamo's built-in operator decomposition. I think we need to add the decomposition pass back to DORT or remove `use_aot_autograd` (remove because it will always be `true`). --- .../orttraining/test/python/orttraining_test_dort.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py index f0b6b9c5fba28..573ec85d76013 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort.py @@ -216,7 +216,12 @@ def elementwise_model(tensor_x: torch.Tensor): tensor_q = tensor_p.relu() return tensor_q - local_backend = make_local_backend(dynamic=True, use_aot_autograd=False) + # TODO: Set use_aot_autograd=False. In order to decompose torch + # function calls to aten ops, we need to set + # user_aot_autograd=True because there is no decomposition in DORT + # anymore. A long-term fix will be brining # decomposition pass back + # into DORT. + local_backend = make_local_backend(dynamic=True, use_aot_autograd=True) optimized_elementwise_model = torch.compile(elementwise_model, backend=local_backend, dynamic=True) def run(fun, list_x): From cbb29d80ff5ec63d3cc2289911c4420f5a9d8a2d Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Tue, 23 Jan 2024 16:34:26 -0800 Subject: [PATCH 20/45] GQA Rotary and Packed QKV with Flash (#18906) ### Description These changes add rotary embedding and packed qkv input to gqa. As of now, the changes are only supported with Flash-Attention (SM >= 80) but should soon be supported with Memory Efficient Attention as well. ### Motivation and Context With the fusion of rotary embedding into this Attention op, we hope to observe some perf gain. The packed QKV should also provide some perf gain in the context of certain models, like Llama2, that would benefit from running ops on the fused QKV matrix, rather than the separate Q, K, and V. --------- Co-authored-by: Yufeng Li --- docs/ContribOperators.md | 16 +- docs/OperatorKernels.md | 2 +- .../contrib_ops/cpu/bert/attention_common.h | 5 + .../cuda/bert/flash_attention/flash_api.cc | 51 +- .../cuda/bert/flash_attention/flash_api.h | 6 +- .../cuda/bert/group_query_attention.cc | 26 +- .../cuda/bert/group_query_attention.h | 5 + .../cuda/bert/group_query_attention_helper.h | 150 ++-- .../cuda/bert/group_query_attention_impl.cu | 125 ++-- .../cuda/bert/group_query_attention_impl.h | 2 + .../core/graph/contrib_ops/bert_defs.cc | 34 +- .../test/python/transformers/rotary_flash.py | 693 ++++++++++++++++++ .../python/transformers/test_flash_attn.py | 668 ++++++++++++++--- tools/ci_build/build.py | 3 +- ...txt => requirements-transformers-test.txt} | 3 +- 15 files changed, 1517 insertions(+), 272 deletions(-) create mode 100644 onnxruntime/test/python/transformers/rotary_flash.py rename tools/ci_build/{requirements.txt => requirements-transformers-test.txt} (94%) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 624cda1d37f73..e7b537d6894c8 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2398,24 +2398,28 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
do_rotary : int
+
Whether to use rotary position embedding. Default value is 0.
kv_num_heads : int (required)
Number of attention heads for k and v
local_window_size : int
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
+
rotary_interleaved : int
+
Rotate using interleaved pattern. Default value is 0 (False).
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
-#### Inputs +#### Inputs (7 - 9)
query : T
-
Query with shape (batch_size, sequence_length, hidden_size)
-
key : T
+
Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).
+
key (optional) : T
Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
-
value : T
+
value (optional) : T
Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
past_key (optional) : T
past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
@@ -2425,6 +2429,10 @@ This version of the operator has been available since version 1 of the 'com.micr
1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
total_sequence_length : M
Scalar tensor of total sequence length (past + new).
+
cos_cache (optional) : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
sin_cache (optional) : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
#### Outputs diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 3b695af2839b6..31cca232fde34 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -843,7 +843,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index da489a6901512..8afeb874750b4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -99,10 +99,15 @@ struct GroupQueryAttentionParameters { bool is_unidirectional; // causal int local_window_size; bool kv_share_buffer; + bool is_packed_qkv; bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool do_rotary; + bool rotary_interleaved; float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; + int zeros_count; + int* zero_ptr; }; namespace attention { diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index d6eb87228bb4a..2c296bf4f8483 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -355,13 +355,15 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in Status mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size - void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size - void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x num_heads x seqlen_q void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) int batch_size, int num_heads, int num_heads_k, @@ -376,16 +378,15 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - int local_window_size) { - // if (seqlen_q == 1) { - // is_causal = false; - // } // causal=true is the same as causal=false in this case - + int local_window_size, + bool is_rotary_interleaved, + bool is_packed_qkv) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + // In kv-cache case, seqlen_k_max as kv sequence length Flash_fwd_params params; set_params_fprop(params, batch_size, @@ -406,15 +407,24 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, is_causal ? 0 : -1); params.dprops = &dprops; - if (k != nullptr && v != nullptr) { + if (k_new != nullptr && v_new != nullptr) { params.seqlen_knew = seqlen_k_new; - params.knew_ptr = k; - params.vnew_ptr = v; + params.knew_ptr = k_new; + params.vnew_ptr = v_new; // All stride are in elements, not bytes. - params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; - params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; - params.knew_row_stride = num_heads_k * head_size; - params.vnew_row_stride = num_heads_k * head_size; + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + } else { + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + } params.knew_head_stride = head_size; params.vnew_head_stride = head_size; } else { @@ -434,6 +444,13 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.cu_seqlens_k = static_cast(seqlens_k_); } + if (rotary_cos != nullptr) { + params.rotary_cos_ptr = rotary_cos; + params.rotary_sin_ptr = rotary_sin; + params.is_rotary_interleaved = is_rotary_interleaved; + params.rotary_dim = (head_size / 16) * 16; + } + params.num_splits = num_splits; if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { params.softmax_lseaccum_ptr = softmax_lse_accum; @@ -444,7 +461,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, } // Only split kernel supports appending to KV cache - run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr); + run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 3d75d6834b8e0..387d1cf9d84fe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -87,6 +87,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x num_heads x seqlen_q void* seqlens_k_, // batch_size + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) int batch_size, int num_heads, int num_heads_k, @@ -101,7 +103,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - int local_window_size = -1); + int local_window_size = -1, + bool is_rotary_interleaved = false, + bool is_packed_qkv = false); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index fd6fb79742cac..fe56f84f0a886 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -47,6 +47,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) kv_num_heads_ = static_cast(kv_num_heads); is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION @@ -62,6 +64,9 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) #else disable_memory_efficient_attention_ = true; #endif + if (!disable_flash_attention_) { + zeros_ = this->GetScratchBuffer(kZerosCount, nullptr); + } } template @@ -73,6 +78,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* past_value = context->Input(4); const Tensor* seqlens_k = context->Input(5); const Tensor* total_seqlen = context->Input(6); + const Tensor* cos_cache = context->Input(7); + const Tensor* sin_cache = context->Input(8); auto& device_prop = GetDeviceProp(); GroupQueryAttentionParameters parameters; @@ -84,6 +91,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { value, past_key, past_value, + cos_cache, + sin_cache, ¶meters, num_heads_, kv_num_heads_, @@ -93,7 +102,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { scale_, device_prop.maxThreadsPerBlock)); parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + parameters.zeros_count = kZerosCount; + parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; int sequence_length = parameters.sequence_length; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size); @@ -139,6 +154,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_memory_efficient_attention_ && local_window_size_ == -1 && + do_rotary_ == false && + key != nullptr && (parameters.head_size & 7) == 0 && parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && @@ -182,8 +199,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present_value = context->Output(2, present_shape); data.query = reinterpret_cast(query->Data()); - data.key = reinterpret_cast(key->Data()); - data.value = reinterpret_cast(value->Data()); + data.key = key == nullptr ? nullptr : reinterpret_cast(key->Data()); + data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data()); data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); data.output = reinterpret_cast(output->MutableData()); @@ -229,6 +246,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (fmha_buffer != nullptr) { data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); } + // Rotary + if (parameters.do_rotary) { + data.cos_cache = reinterpret_cast(cos_cache->Data()); + data.sin_cache = reinterpret_cast(sin_cache->Data()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 54a8127e29e7b..15573ece166fc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -23,10 +23,15 @@ class GroupQueryAttention final : public CudaKernel { int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention int local_window_size_; + bool is_unidirectional_; bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; float scale_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) + IAllocatorUniquePtr zeros_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 2cb9955807f26..853e1a710cb24 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -16,6 +16,8 @@ Status CheckInputs(const Tensor* query, const Tensor* value, const Tensor* past_key, const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, void* parameters, int num_heads, int kv_num_heads, @@ -24,19 +26,18 @@ Status CheckInputs(const Tensor* query, bool is_past_bsnh, float scale) { // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length - // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) - // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) + // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr + // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr // no packing for q/k/v: - // query (Q) : (B, S, D) - // key (K) : (B, S, D_kv) - // value (V) : (B, S, D_kv) + // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) + // key (K) : (B, S, D_kv) or nullptr + // value (V) : (B, S, D_kv) or nullptr ORT_UNUSED_PARAMETER(value); AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; - + const bool is_packed_qkv = key == nullptr; const auto& query_dims = query->Shape().GetDims(); - const auto& key_dims = key->Shape().GetDims(); if (query_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", @@ -46,10 +47,69 @@ Status CheckInputs(const Tensor* query, int batch_size = static_cast(query_dims[0]); int sequence_length = static_cast(query_dims[1]); int q_hidden_size = static_cast(query_dims[2]); - int head_size = static_cast(q_hidden_size) / num_heads; + int head_size = 0; + + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } - int kv_hidden_size = static_cast(key_dims[2]); + int kv_hidden_size = 0; + // Check key and value when not packed + if (!is_packed_qkv) { + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } else if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != key_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 1 (sequence length)"); + } + kv_hidden_size = static_cast(key_dims[2]); + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } else if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 1 (sequence length)"); + } else if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + } else { + // Check packed qkv + head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + q_hidden_size = head_size * num_heads; + kv_hidden_size = head_size * kv_num_heads; + } + // Check past-present KV int32_t past_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { const auto& past_key_dims = past_key->Shape().GetDims(); @@ -130,41 +190,6 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent."); } - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } - - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } - - if (static_cast(sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)"); - } - - if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - // Check seqlens_k tensor (holding past seqlen for token gen) const auto& seqlens_dim = seqlens_k->Shape().GetDims(); if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { @@ -180,6 +205,36 @@ Status CheckInputs(const Tensor* query, int total_sequence_length = *((*total_seqlen).template Data()); int present_sequence_length = std::max(total_sequence_length, past_sequence_length); + if (cos_cache != nullptr && sin_cache != nullptr) { + const auto& cos_dims = cos_cache->Shape().GetDims(); + const auto& sin_dims = sin_cache->Shape().GetDims(); + + if (head_size % 16 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size shall be a multiple of 16. Got head_size % 16 == ", + head_size % 16); + } + if (cos_dims[0] != present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 0 must be of present_sequence_length."); + } + if (sin_dims[0] != present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 0 must be of present_sequence_length."); + } + if (cos_dims[1] != (head_size / 16) * 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (sin_dims[1] != (head_size / 16) * 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + } else if (cos_cache != nullptr || sin_cache != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); + } + bool is_prompt = sequence_length != 1; if (parameters != nullptr) { @@ -190,9 +245,10 @@ Status CheckInputs(const Tensor* query, output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors output_parameters->hidden_size = q_hidden_size; output_parameters->num_heads = num_heads; - output_parameters->head_size = q_hidden_size / num_heads; + output_parameters->head_size = head_size; output_parameters->kv_hidden_size = kv_hidden_size; output_parameters->kv_num_heads = kv_num_heads; + output_parameters->is_packed_qkv = is_packed_qkv; output_parameters->is_unidirectional = true; output_parameters->is_prompt = is_prompt; output_parameters->scale = scale; @@ -208,6 +264,8 @@ Status CheckInputs(const Tensor* query, const Tensor* value, const Tensor* past_key, const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, void* parameters, int num_heads, int kv_num_heads, @@ -220,7 +278,7 @@ Status CheckInputs(const Tensor* query, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); + return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); } } // namespace group_query_attention_helper diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 5b0f5d0cfe601..d88e9a49fb5ee 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -151,9 +151,10 @@ template Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, cudaStream_t stream, - const int max_threads_per_block) { + const int max_threads_per_block, + const bool past_only = false) { const int batch_size = parameters.batch_size; - const int kv_sequence_length = parameters.sequence_length; + const int kv_sequence_length = past_only ? 0 : parameters.sequence_length; const int past_sequence_length = parameters.seqlen_past_kv_cache; const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; @@ -441,7 +442,6 @@ Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, return CUDA_CALL(cudaGetLastError()); } - __global__ void PastToTotalSeqlen(int32_t* seqlens_k, int32_t* seqlens_k_buff, const int add_seqlen) { @@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, // Convert Past to Total sequence length tensor Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int threads_per_block) { + const int threads_per_block) { if (parameters.is_prompt) { return Status::OK(); } @@ -482,91 +482,63 @@ Status FlashAttention( const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int kv_sequence_length = parameters.sequence_length; - const int present_sequence_length = parameters.seqlen_present_kv_cache; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - - void* query = reinterpret_cast(const_cast(data.query)); - void* key = reinterpret_cast(const_cast(data.key)); - void* value = reinterpret_cast(const_cast(data.value)); - bool is_causal = true; - bool is_bf16 = std::is_same::value; - // Note: seqlens_k is past sequence length for flash - if (parameters.is_prompt) { - // Launch kernel to copy seqlen - constexpr int thr_per_blk = 256; - int blk_in_grid = (batch_size + thr_per_blk -1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); - } - - void* seqlens_k = reinterpret_cast(data.seqlens_k); - - if (parameters.kv_share_buffer) { - // Share buffer case - if (data.past_key == nullptr || data.past_key != data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv shall share the same tensor when kv_share_buffer is on."); - } - - if (parameters.is_prompt) { - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); - key = nullptr; - value = nullptr; - seqlens_k = reinterpret_cast(data.seqlens_k_total); - } - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); + void* query = reinterpret_cast(const_cast(data.query)); + void* key; + void* value; - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), - seqlens_k, batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, kv_sequence_length, - scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), parameters.local_window_size)); + if (!parameters.is_packed_qkv) { + key = reinterpret_cast(const_cast(data.key)); + value = reinterpret_cast(const_cast(data.value)); } else { - // Not share buffer case - // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient - if (data.past_key != nullptr && data.past_key == data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv share the same tensor but kv_share_buffer is not on."); - } - - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key = reinterpret_cast(query) + key_offset; + value = reinterpret_cast(key) + value_offset; + } - if (!parameters.is_prompt) { - ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + void* seqlens_k = reinterpret_cast(data.seqlens_k); + if (parameters.is_prompt) { + // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value + // user should use seqlens_k to index into output to get new tokens + if (batch_size <= parameters.zeros_count) { + seqlens_k = parameters.zero_ptr; + } else { + // Launch kernel to create larger seqlen tensor when batch_size > 256 + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, 0, batch_size); + seqlens_k = data.seqlens_k_total; } - - seqlens_k = reinterpret_cast(data.seqlens_k_total); - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); - DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size); - DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size); - DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast(data.softmax_lse), - seqlens_k, batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, 0, - scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), parameters.local_window_size)); + } else if (!parameters.kv_share_buffer) { // copy past kv to present kv + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true)); } + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); + void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, + reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, + batch_size, num_heads, kv_num_heads, head_size, sequence_length, + parameters.seqlen_present_kv_cache, kv_sequence_length, + scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, + parameters.is_packed_qkv)); + + // if (parameters.left_padding && parameters.is_prompt) { + // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); + // } + DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); @@ -672,7 +644,6 @@ Status EfficientAttention( p.has_custom_right_padding = true; run_memory_efficient_attention(p); - DUMP_TENSOR_INIT(); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index de32d7ea93163..1bf91f9c875eb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -21,6 +21,8 @@ struct GroupQueryAttentionData { const T* past_key = nullptr; const T* past_value = nullptr; int* seqlens_k = nullptr; + const T* cos_cache = nullptr; + const T* sin_cache = nullptr; // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 7f34647f1faef..8583474a1e391 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -259,13 +259,13 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2]; updateOutputShape(ctx, 0, output_shape); - } else { - fail_shape_inference("Missing input 2 (value)"); } } if (ctx.getNumOutputs() > 1) { // has present output if (hasInputShape(ctx, past_key_index)) { + // auto& query_shape = getInputShape(ctx, 0); + // auto& query_dims = query_shape.dim(); auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); if (past_dims.size() != 4) { @@ -273,8 +273,7 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& } ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1); ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); + // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not } } } @@ -1015,18 +1014,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", AttributeProto::INT, static_cast(-1)) + .Attr("do_rotary", + "Whether to use rotary position embedding. Default value is 0.", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("rotary_interleaved", + "Rotate using interleaved pattern. Default value is 0 (False).", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "query", - "Query with shape (batch_size, sequence_length, hidden_size)", + "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" + "(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).", "T") .Input(1, "key", "Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ", - "T") + "T", + OpSchema::Optional) .Input(2, "value", "Value with shape (batch_size, kv_sequence_length, kv_hidden_size)", - "T") + "T", + OpSchema::Optional) .Input(3, "past_key", "past state key with support for format BNSH. When past_key uses same tensor as present_key" @@ -1047,6 +1057,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "total_sequence_length", "Scalar tensor of total sequence length (past + new).", "M") + .Input(7, + "cos_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T", + OpSchema::Optional) + .Input(8, + "sin_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", diff --git a/onnxruntime/test/python/transformers/rotary_flash.py b/onnxruntime/test/python/transformers/rotary_flash.py new file mode 100644 index 0000000000000..42bff9c92b41b --- /dev/null +++ b/onnxruntime/test/python/transformers/rotary_flash.py @@ -0,0 +1,693 @@ +# Copyright (c) 2023, Tri Dao. + + +from typing import Optional, Tuple, Union + +import torch +import triton +import triton.language as tl +from einops import rearrange, repeat + +##### TRITON KERNEL FOR ROTARY ##### + + +# @triton.autotune( +# configs=[ +# triton.Config({"block_m": 2}), +# triton.Config({"block_m": 4}), +# triton.Config({"block_m": 8}), +# triton.Config({"block_m": 16}), +# ], +# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"], +# ) +@triton.jit +def rotary_kernel( + out_, # Pointers to matrices + x_, + cos_, + sin_, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + block_k: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + block_m: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + x_ = x_ + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ = out_ + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_ = x_ + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_ = out_ + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * block_m >= seqlen: + return + rm = pid_m * block_m + tl.arange(0, block_m) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, block_k) + rk_half = tl.arange(0, block_k // 2) + + if not INTERLEAVED: + # Load the 1st and 2nd halves of x_, do calculation, then store to 1st and 2nd halves of out_ + x_ = x_ + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) + cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + cos = tl.load(cos_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to( + tl.float32 + ) + sin = tl.load(sin_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to( + tl.float32 + ) + x0 = tl.load(x_, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32) + x1 = tl.load( + x_ + rotary_dim_half * stride_x_headdim, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + # write back result + out_ = out_ + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) + tl.store(out_, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) + tl.store( + out_ + rotary_dim_half * stride_out_headdim, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + # We don't want to load x_[0, 2, 4, ...] and x_[1, 3, 5, ...] separately since both are slow. + # Instead, we load x0 = x_[0, 1, 2, 3, ...] and x1 = x_[1, 0, 3, 2, ...]. + # Loading x0 will be fast but x1 will be slow. + # Then we load cos = cos_[0, 0, 1, 1, ...] and sin = sin_[0, 0, 1, 1, ...]. + # Then we do the calculation and use tl.where to pick put the right outputs for the even + # and for the odd indices. + rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... + rk_repeat = tl.arange(0, block_k) // 2 + x0_ = x_ + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) + x1_ = x_ + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim) + cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + cos = tl.load( + cos_, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=1.0, + ).to(tl.float32) + sin = tl.load( + sin_, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x0 = tl.load(x0_, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32) + x1 = tl.load(x1_, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + out_ = out_ + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) + tl.store(out_, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). + cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + + assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + block_k = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) + grid = lambda META: (triton.cdiv(seqlen, META["block_m"]), batch, nheads) # noqa + block_m = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, # key for triton cache (limit number of compilations) + output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + block_k, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + block_m, + ) + return output + + +##### ROTARY API ##### + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], + dim=-1, + ) + + +class ApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = apply_rotary( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def apply_rotary_emb( + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen) + + +# For backward compatibility +apply_rotary_emb_func = apply_rotary_emb + + +class ApplyRotaryEmbQKV(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + ): + batch, seqlen, three, nheads, headdim = qkv.shape + assert three == 3 + if cos_k is None and sin_k is None and qkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") + qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) + apply_rotary(qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + q, k = qkv[:, :, 0], qkv[:, :, 1] + apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True) + apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True) + ctx.save_for_backward(cos, sin, cos_k, sin_k) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cos_k, sin_k) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + return qkv + + @staticmethod + def backward(ctx, dqkv): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cos_k, sin_k = ctx.saved_tensors + if cos_k is None and sin_k is None and dqkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d") + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] + apply_rotary(dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True) + apply_rotary( + dk, + cos_k, + sin_k, + seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + return dqkv, None, None, None, None, None, None + + +def apply_rotary_emb_qkv_( + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + cos_k, sin_k: (seqlen, rotary_dim / 2), optional + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + qkv: (batch_size, seqlen, 3, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of Q and K. + """ + return ApplyRotaryEmbQKV.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets) + + +class ApplyRotaryEmbKV(torch.autograd.Function): + @staticmethod + def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): + batch, seqlen, two, nheads, headdim = kv.shape + assert two == 2 + k = kv[:, :, 0] + apply_rotary(k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + return kv + + @staticmethod + def backward(ctx, dkv): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, seqlen_offsets = ctx.saved_tensors + else: + cos, sin = ctx.saved_tensors + apply_rotary( + dkv[:, :, 0], + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + return dkv, None, None, None, None + + +apply_rotary_emb_kv_ = ApplyRotaryEmbKV.apply + + +def apply_rotary_emb_kv_( + kv, + cos, + sin, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + kv: (batch_size, seqlen, 2, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + kv: (batch_size, seqlen, 2, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of K. + """ + return ApplyRotaryEmbKV.apply(kv, cos, sin, interleaved, seqlen_offsets) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ + + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + pos_idx_in_fp32=True, + device=None, + ): + """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. In most cases this would + be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, + we add this option. + """ + super().__init__() + self.dim = dim + self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = self._compute_inv_freq(device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.interleaved = interleaved + self.scale_base = scale_base + scale = ( + (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None + else None + ) + self.register_buffer("scale", scale, persistent=False) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + def _compute_inv_freq(self, device=None): + return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. + # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward( + self, + qkv: torch.Tensor, + kv: Optional[torch.Tensor] = None, + seqlen_offset: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, + else it's just q of shape (batch, seqlen, nheads, headdim) + kv: (batch, seqlen, 2, nheads, headdim) + seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one + should pass in max_seqlen, which will update the cos / sin cache up to that length. + Apply rotary embedding *inplace* to qkv and / or kv. + """ + seqlen = qkv.shape[1] + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + elif isinstance(seqlen_offset, int): + self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) + if kv is None: + if self.scale is None: + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + q = qkv + q = apply_rotary_emb_func( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + inplace=True, + seqlen_offsets=seqlen_offset, + ) + if self.scale is None: + kv = apply_rotary_emb_kv_( + kv, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + kv = apply_rotary_emb_kv_( + kv, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + return q, kv diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 8a839875de2a2..90d28872d3cc8 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -20,6 +20,7 @@ from bert_padding import pad_input, unpad_input from einops import rearrange, repeat from onnx import TensorProto, helper +from rotary_flash import apply_rotary_emb from onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -184,7 +185,13 @@ def create_multihead_attention_graph(config): def create_group_query_attention_graph_prompt( - config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -193,18 +200,22 @@ def create_group_query_attention_graph_prompt( "GroupQueryAttention", [ "query", - "key", - "value", + "key" if not packed else "", + "value" if not packed else "", "past_key" if share_buffer else "", "past_value" if share_buffer else "", "seqlens_k", "total_sequence_length", + "cos_cache" if rotary else "", + "sin_cache" if rotary else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -218,25 +229,9 @@ def create_group_query_attention_graph_prompt( [ config.batch_size, config.q_sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size), ], ), helper.make_tensor_value_info( @@ -250,6 +245,27 @@ def create_group_query_attention_graph_prompt( [1], ), ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + ] if share_buffer: graph_input += [ helper.make_tensor_value_info( @@ -273,6 +289,25 @@ def create_group_query_attention_graph_prompt( ], ), ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.buffer_sequence_length if share_buffer else config.kv_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.buffer_sequence_length if share_buffer else config.kv_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] graph_output = [ helper.make_tensor_value_info( @@ -334,7 +369,13 @@ def create_group_query_attention_graph_prompt( def create_group_query_attention_graph_past( - config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, ): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( @@ -345,18 +386,22 @@ def create_group_query_attention_graph_past( "GroupQueryAttention", [ "query", - "key", - "value", + "key" if not packed else "", + "value" if not packed else "", "past_key", "past_value", "seqlens_k", "total_sequence_length", + "cos_cache" if rotary else "", + "sin_cache" if rotary else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -370,25 +415,9 @@ def create_group_query_attention_graph_past( [ config.batch_size, config.sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size), ], ), helper.make_tensor_value_info( @@ -411,8 +440,6 @@ def create_group_query_attention_graph_past( config.head_size, ], ), - ] - graph_input += [ helper.make_tensor_value_info( "seqlens_k", TensorProto.INT32, @@ -424,6 +451,46 @@ def create_group_query_attention_graph_past( [1], ), ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] graph_output = [ helper.make_tensor_value_info( @@ -663,21 +730,38 @@ def mha_func(q, k, v, config): def gqa_prompt_func( - q, k, v, config, new_k, new_v, seqlens_k=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + window_size=-1, + past_kv_format=Formats.BSNH, + share_buffer=True, + rotary_interleaved=False, ): onnx_model_str = create_group_query_attention_graph_prompt( - config, past_kv_format, share_buffer, local_window_size=window_size + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if share_buffer else None - new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -686,9 +770,17 @@ def gqa_prompt_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_input( "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) @@ -713,17 +805,23 @@ def gqa_prompt_func( else: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) io_binding.bind_output("output") @@ -737,21 +835,38 @@ def gqa_prompt_func( def gqa_past_func( - q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1 + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + past_kv_format=Formats.BSNH, + share_buffer=True, + window_size=-1, + rotary_interleaved=False, ): onnx_model_str = create_group_query_attention_graph_past( - config, past_kv_format, share_buffer, local_window_size=window_size + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() - new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -763,9 +878,17 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_input( "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) @@ -790,8 +913,6 @@ def gqa_past_func( else: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": past_k.detach().cpu().numpy(), "past_value": past_v.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -805,9 +926,17 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) @@ -1029,9 +1158,12 @@ def parity_check_mha( def parity_check_gqa_prompt( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1080,6 +1212,8 @@ def parity_check_gqa_prompt( dtype=torch.float16, requires_grad=False, ) + # print(k.shape) + # print(new_k.shape) window_size = (-1, -1) left_window_size = -1 @@ -1105,19 +1239,47 @@ def parity_check_gqa_prompt( # device="cuda", # ) # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.q_sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") arange = rearrange(torch.arange(config.buffer_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") kv_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") update_mask = arange < kv_seqlens_expanded - k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1125,13 +1287,47 @@ def parity_check_gqa_prompt( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func( - q, k, v, config, new_k, new_v, cache_seqlens, left_window_size, past_format, True - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_prompt_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + True, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_prompt_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + True, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) + # print((out - out_ref)[0, :, 0, 0]) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1139,10 +1335,16 @@ def parity_check_gqa_prompt( # Compare results print( "KV-buffer", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1171,9 +1373,12 @@ def parity_check_gqa_prompt( def parity_check_gqa_prompt_no_buff( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1229,13 +1434,42 @@ def parity_check_gqa_prompt_no_buff( # device="cuda", # ) # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.q_sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, k_cache_ref + k_cache_ref = k_ro + brange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") new_mask = brange < cache_seqlens_expanded k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1243,9 +1477,39 @@ def parity_check_gqa_prompt_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func( - q, None, None, config, new_k, new_v, cache_seqlens, left_window_size, past_format, False - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_prompt_func( + packed_qkv, + None, + None, + config, + None, + None, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + False, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_prompt_func( + q, + None, + None, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + False, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1256,7 +1520,17 @@ def parity_check_gqa_prompt_no_buff( # Compare results print( - "KV-buffer", + "No buff", + " packed:", + packed, + " causal:", + causal, + " local:", + local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1285,9 +1559,12 @@ def parity_check_gqa_prompt_no_buff( def parity_check_gqa_past( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1336,6 +1613,7 @@ def parity_check_gqa_past( dtype=torch.float16, requires_grad=False, ) + window_size = (-1, -1) left_window_size = -1 if local: @@ -1359,18 +1637,45 @@ def parity_check_gqa_past( dtype=torch.int32, device="cuda", ) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + arange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length ) - k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1378,13 +1683,46 @@ def parity_check_gqa_past( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func( - q, k, v, config, new_k, new_v, cache_seqlens, past_format, True, left_window_size - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_past_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + past_format, + True, + left_window_size, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + past_format, + True, + left_window_size, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, cache_seqlens[0], :]) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1394,10 +1732,16 @@ def parity_check_gqa_past( "KV-buffer", "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, " B:", config.batch_size, " S:", @@ -1427,6 +1771,9 @@ def parity_check_gqa_past_no_buff( causal=False, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1503,18 +1850,47 @@ def parity_check_gqa_past_no_buff( device="cuda", ) cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = ( + torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + ) + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length ) - k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1522,13 +1898,47 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func( - q, k, v, config, new_k, new_v, cache_seqlens, past_format, False, window_size=left_window_size - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_past_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((out - out_ref)[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) + # Make sure past-present buffer updating correctly # assert numpy.allclose( # present_k[:, :, :-1, :], k_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True @@ -1540,10 +1950,16 @@ def parity_check_gqa_past_no_buff( # Compare results print( "NO buff", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1671,10 +2087,25 @@ def test_gqa_no_past(self): for n, n2 in num_h: for h in h_sizes: for local in [False, True]: - for past_kv_format in [Formats.BNSH]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - parity_check_gqa_prompt(config, local=local, past_format=past_kv_format) - parity_check_gqa_prompt_no_buff(config, local=local, past_format=past_kv_format) + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_prompt_no_buff( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) def test_gqa_past(self): if not torch.cuda.is_available(): @@ -1684,7 +2115,6 @@ def test_gqa_past(self): return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" print("-------- TEST GQA PAST (TOKEN GEN) ---------") - print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") batches = [5] if pipeline_mode else [1, 3, 5] seqs = ( [(1, 128), (1, 1024), (1, 2048)] @@ -1706,6 +2136,7 @@ def test_gqa_past(self): num_h = [(32, 32), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) + print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") for b in batches: for s, s2 in seqs: for n, n2 in num_h: @@ -1734,23 +2165,30 @@ def test_gqa_past(self): for n, n2 in num_h: for h in h_sizes: for local in [False, True]: - for past_kv_format in [Formats.BNSH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) if __name__ == "__main__": diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 1034a82cb2854..6e5cd7b57e403 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2046,7 +2046,8 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): numpy_init_version = numpy.__version__ pb_init_version = google.protobuf.__version__ run_subprocess( - [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=SCRIPT_DIR + [sys.executable, "-m", "pip", "install", "-r", "requirements-transformers-test.txt"], + cwd=SCRIPT_DIR, ) run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd) # Restore initial numpy/protobuf version in case other tests use it diff --git a/tools/ci_build/requirements.txt b/tools/ci_build/requirements-transformers-test.txt similarity index 94% rename from tools/ci_build/requirements.txt rename to tools/ci_build/requirements-transformers-test.txt index 57fc8f08336d2..a5279781462a7 100644 --- a/tools/ci_build/requirements.txt +++ b/tools/ci_build/requirements-transformers-test.txt @@ -3,7 +3,8 @@ packaging protobuf==3.20.2 numpy==1.24.0 ; python_version < '3.12' numpy==1.26.0 ; python_version >= '3.12' +torch coloredlogs==15.0 transformers==4.36.0 psutil -einops \ No newline at end of file +einops From 6a424ccf8c2f9cd7f191c843547d5f37ef409493 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Wed, 24 Jan 2024 03:33:49 +0000 Subject: [PATCH 21/45] Fix AMD pipeline test failures (#19250) ### Description Fix amd test failure ### Motivation and Context --- onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu | 5 +++-- onnxruntime/contrib_ops/rocm/bert/multihead_attention.h | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index 6f98312e4067d..09e7d61b71db9 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -68,6 +68,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL) != 0LL; + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; using HipT = typename ToHipType::MappedType; using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; @@ -121,8 +122,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, - &attn, - num_heads_, mask_filter_value_, scale_, + &attn, num_heads_, + mask_filter_value_, scale_, false, /*is_unidirectional_*/ past_present_share_buffer_, false, device_prop.maxThreadsPerBlock)); if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h index 84d8b76bbfebe..1d676d7a7bcac 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h @@ -25,6 +25,7 @@ class MultiHeadAttention final : public RocmKernel { float mask_filter_value_; float scale_; bool past_present_share_buffer_{false}; + bool is_unidirectional_{false}; // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. From c10be1848cafa7575ba298cbcc01e89dcd841851 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Tue, 23 Jan 2024 21:30:22 -0800 Subject: [PATCH 22/45] [TensorRT EP] Avoid calling unavailable function with cpu python package (#19251) C.register_tensorrt_plugins_as_custom_ops() is only available in gpu python package. Add condition to avoid calling it in cpu python package. --- .../python/onnxruntime_inference_collection.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 1a3e22142f80e..09f768f53ea65 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -466,7 +466,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi session_options = self._sess_options if self._sess_options else C.get_default_session_options() - self._register_ep_custom_ops(session_options, providers, provider_options) + self._register_ep_custom_ops(session_options, providers, provider_options, available_providers) if self._model_path: sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model) @@ -510,11 +510,15 @@ def _reset_session(self, providers, provider_options): self._sess_options = self._sess_options_initial self._create_inference_session(providers, provider_options) - def _register_ep_custom_ops(self, session_options, providers, provider_options): + def _register_ep_custom_ops(self, session_options, providers, provider_options, available_providers): for i in range(len(providers)): - if providers[i] == "TensorrtExecutionProvider": + if providers[i] in available_providers and providers[i] == "TensorrtExecutionProvider": C.register_tensorrt_plugins_as_custom_ops(session_options, provider_options[i]) - elif isinstance(providers[i], tuple) and providers[i][0] == "TensorrtExecutionProvider": + elif ( + isinstance(providers[i], tuple) + and providers[i][0] in available_providers + and providers[i][0] == "TensorrtExecutionProvider" + ): C.register_tensorrt_plugins_as_custom_ops(session_options, providers[i][1]) From d7aebf9ea8a4a651088384f219292bae9062439b Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 24 Jan 2024 14:15:07 +0800 Subject: [PATCH 23/45] Move Nuget Test from T4 to A10 to reduce release duration (#19253) ### Description ### Motivation and Context Running release process is very painful and boring because some GPU jobs have to wait so long time. ![image](https://github.com/microsoft/onnxruntime/assets/16190118/1c5c981e-68d4-4678-9758-443fbf362802) ![image](https://github.com/microsoft/onnxruntime/assets/16190118/ba0d79ba-1554-4c7a-93dd-6ea8144c9295) ![image](https://github.com/microsoft/onnxruntime/assets/16190118/36cab833-71c1-4ff5-bca5-f4caa9aee0c9) On the one hand, we could move some T4 from PR process since some jobs are not using T4 any more and on the other hand, we can continue to change some jobs' agent from T4 to A4 too. In the future, T4 will mainly be used for the scenarioes that big GPU memory is needed, multiple GPU cards or some special cases. Test runs: https://dev.azure.com/aiinfra/Lotus/_build/results?buildId=401786&view=logs&j=8048494c-e6eb-5e47-5e87-ff0aa863325d cc @YUNQIUGUO @snnn --- .../c-api-noopenmp-packaging-pipelines.yml | 8 ++++---- .../github/azure-pipelines/cuda-packaging-pipeline.yml | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index aa1a75bfcda45..5a50a9964bead 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -1023,7 +1023,7 @@ stages: - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-T4' + AgentPool : 'onnxruntime-Win2022-GPU-A10' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' @@ -1034,7 +1034,7 @@ stages: - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-T4' + AgentPool : 'onnxruntime-Win2022-GPU-A10' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' @@ -1046,7 +1046,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU + AgentPool : Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' @@ -1055,7 +1055,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU + AgentPool : Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' MoreSuffix: '_Linux' diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml index 1d2ba88652f48..0c24d4897ddf1 100644 --- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml @@ -151,7 +151,7 @@ stages: # Testing - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-T4' + AgentPool : 'onnxruntime-Win2022-GPU-A10' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' @@ -162,7 +162,7 @@ stages: - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-T4' + AgentPool : 'onnxruntime-Win2022-GPU-A10' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' @@ -174,7 +174,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU + AgentPool : Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' @@ -184,7 +184,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU + AgentPool : Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' MoreSuffix: '_Linux' From a39ac4a97976c9bea49be6e646ac1fac64278f65 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Wed, 24 Jan 2024 10:06:31 -0800 Subject: [PATCH 24/45] [DirectML] Register Pad19 (#19175) ### Description Register Pad19 in DirectML --------- Co-authored-by: Sheil Kumar --- .../src/Operators/DmlOperatorPadding.cpp | 7 +++++++ .../src/Operators/OperatorRegistration.cpp | 6 ++++++ .../core/providers/dml/OperatorAuthorHelper/Attributes.h | 1 + .../providers/dml/OperatorAuthorHelper/OperatorHelper.h | 1 + .../providers/dml/OperatorAuthorHelper/OperatorVersions.h | 1 + 5 files changed, 16 insertions(+) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp index a014db5adbe61..b243f7e741a70 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp @@ -51,6 +51,12 @@ class DmlOperatorPadding : public DmlOperator, public PaddingHelper { mode = DML_PADDING_MODE_REFLECTION; } +#if DML_TARGET_VERSION >= 0x6300 + else if (modeString == AttrValue::Wrap) + { + mode = DML_PADDING_MODE_WRAP; + } +#endif else { ML_INVALID_ARGUMENT("Unknown Pad mode attribute."); @@ -116,5 +122,6 @@ DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad18, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Pad19, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 18e29c8b99ced..7b53a1102c5a7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -358,6 +358,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Pad7); DML_OP_EXTERN_CREATION_FUNCTION(Pad11); DML_OP_EXTERN_CREATION_FUNCTION(Pad13); DML_OP_EXTERN_CREATION_FUNCTION(Pad18); +DML_OP_EXTERN_CREATION_FUNCTION(Pad19); DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth); DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace); DML_OP_EXTERN_CREATION_FUNCTION(Sqrt); @@ -747,6 +748,11 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO_VER( 18, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, + +#if DML_TARGET_VERSION >= 0x6300 + {REG_INFO_VER( 19, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, +#endif + {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 13, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index e3df1d00b3e8a..9c5d021f52b36 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -149,5 +149,6 @@ namespace AttrValue static constexpr const char* NearestNeighbor = "NN"; static constexpr const char* NotSet = "NOTSET"; static constexpr const char* Reflect = "reflect"; + static constexpr const char* Wrap = "wrap"; } // namespace AttrValue diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 0d425997e6a6a..d4b44f6fa8a9d 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1589,6 +1589,7 @@ using ShapeInferenceHelper_Pad7 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad11 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad13 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad18 = VersionedOpsetHelper; +using ShapeInferenceHelper_Pad19 = VersionedOpsetHelper; using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper; using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 79efc2d2836fe..57cb009b72ebc 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -413,6 +413,7 @@ namespace OperatorHelper namespace OnnxOperatorSet19 { static const int sc_sinceVer_AveragePool = 19; + static const int sc_sinceVer_Pad = 19; static const int sc_sinceVer_Cast = 19; static const int sc_sinceVer_CastLike = 19; static const int sc_sinceVer_Constant = 19; From a33b5bd1fa5ac6d9aabb23cd8aca16b5fc3fc3c5 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Thu, 25 Jan 2024 01:12:21 +0530 Subject: [PATCH 25/45] [JS/WebGPU] Added Uniforms to SkipLayerNorm. (#18788) ### Description Added Uniforms to SkipLayerNorm ### Motivation and Context Improve performance --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 4 +- .../wasm/jsep/webgpu/ops/skip-layer-norm.ts | 123 ++++++++++-------- 2 files changed, 69 insertions(+), 58 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index cc504093ca0d7..d737a28654220 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -25,7 +25,7 @@ import * as pool from './ops/pool'; import {range} from './ops/range'; import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; -import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; +import {skipLayerNorm} from './ops/skip-layer-norm'; import {parseSliceAttributes, slice} from './ops/slice'; import {parseSoftmaxAttributes, softmax} from './ops/softmax'; import {parseSplitAttributes, split} from './ops/split'; @@ -116,7 +116,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Sin', [unaryOps.sin]], ['Sinh', [unaryOps.sinh]], ['Slice', [slice, parseSliceAttributes]], - ['SkipLayerNormalization', [skipLayerNorm, parseSkipLayerNormAttributes]], + ['SkipLayerNormalization', [skipLayerNorm]], ['Split', [split, parseSplitAttributes]], ['Sqrt', [unaryOps.sqrt]], ['Softmax', [softmax, parseSoftmaxAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index a2fda9f07d09f..509a722f4b52a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -4,10 +4,10 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; export interface SkipLayerNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -86,60 +86,74 @@ const createSkipLayerNormProgramInfo = const hasInputSkipBiasSumOutput = outputCount > 3; const components = getMaxComponents(hiddenSize); - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), - inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), - ]; - if (hasBetaInput) { - variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); - } - if (hasBiasInput) { - variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); - } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); - if (hasMeanOutput) { - variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInvStdDevOutput) { - variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInputSkipBiasSumOutput) { - variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); - } - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const hiddenSize: f32 = ${hiddenSize}; - const hiddenSizeVectorized: u32 = ${hiddenSize / components}; - const epsilon: f32 = ${attributes.epsilon}; - ${shaderHelper.declareVariables(...variables)} + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, + {type: 'uint32', data: components}, + {type: 'uint32', data: hiddenSize}, + {type: 'float32', data: attributes.epsilon}, + ]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniformsArray: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, + {name: 'components', type: 'u32'}, + {name: 'hidden_size', type: 'u32'}, + {name: 'epsilon', type: 'f32'}, + ]; + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), + inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), + ]; + if (hasBetaInput) { + variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); + } + if (hasBiasInput) { + variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanOutput) { + variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdDevOutput) { + variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); + } + if (hasInputSkipBiasSumOutput) { + variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components)); + } + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return ` + + ${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)} - let offset = global_idx * hiddenSizeVectorized; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size / uniforms.hidden_size')} + let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components; + let offset = global_idx * hidden_size_vectorized; var sum = ${fillVector('f32', components)}; var squareSum = ${fillVector('f32', components)}; - for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { - let skipValue = skip[offset + i]; - let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'}; - let inputValue = x[offset + i]; - let value = inputValue + skipValue + biasValue; - ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''} + for (var i: u32 = 0; i < hidden_size_vectorized; i++) { + let skip_value = skip[offset + i]; + let bias_value = ${hasBiasInput ? 'bias[i]' : '0.0'}; + let input_value = x[offset + i]; + let value = input_value + skip_value + bias_value; + ${hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : ''} output[offset + i] = value; - let f32Value = ${castToF32(dataType, components, 'value')}; - sum += f32Value; - squareSum += f32Value * f32Value; + let f32_value = ${castToF32(dataType, components, 'value')}; + sum += f32_value; + squareSum += f32_value * f32_value; } - let mean = ${sumVector('sum', components)} / hiddenSize; - let invStdDev = inverseSqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon); - ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''} - ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = invStdDev;' : ''} - for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { - output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(invStdDev) * gamma[i] - + ${hasBetaInput ? 'beta[i]' : '0.0'}; + let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size); + let inv_std_dev = inverseSqrt(${ + sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon); + ${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''} + ${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''} + for (var i: u32 = 0; i < hidden_size_vectorized; i++) { + output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(inv_std_dev) * gamma[i] + ${ + hasBetaInput ? 'beta[i]' : '0.0'}; } }`; + }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; if (outputCount > 1) { outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); @@ -150,12 +164,14 @@ const createSkipLayerNormProgramInfo = if (outputCount > 3) { outputs.push({dims: inputShape, dataType: inputs[0].dataType}); } - return { name: 'SkipLayerNormalization', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: { + hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`, + inputDependencies: inputs.map((_input, _index) => 'type') + }, getShaderSource, - getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}), + getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}), }; }; @@ -178,8 +194,3 @@ export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNorm context.compute( createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {outputs}); }; - -export const parseSkipLayerNormAttributes = (attributes: Record): SkipLayerNormAttributes => { - const epsilon = attributes.epsilon as number; - return createAttributeWithCacheKey({epsilon}); -}; From a28abeb24100441c76a777f9ce225cb0ea3a59c3 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 24 Jan 2024 14:35:44 -0800 Subject: [PATCH 26/45] Change "#ifdef WIN32" to "#ifdef _WIN32" (#19254) ### Description `_WIN32` is a standard macro listed at https://learn.microsoft.com/en-us/cpp/preprocessor/predefined-macros?view=msvc-170 . But `WIN32` is not. --- .../main/native/ai_onnxruntime_OrtSession_SessionOptions.c | 4 ++-- onnxruntime/core/mlas/lib/amx_common.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 3a1c0d1bb8fa1..4a5e2b7ef3b1e 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -8,7 +8,7 @@ #include "onnxruntime/core/session/onnxruntime_c_api.h" #include "OrtJniUtil.h" #include "ai_onnxruntime_OrtSession_SessionOptions.h" -#ifdef WIN32 +#ifdef _WIN32 #include #else #include @@ -318,7 +318,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeC // Iterate the handles, calling the appropriate close function for (jint i = 0; i < numHandles; i++) { -#ifdef WIN32 +#ifdef _WIN32 FreeLibrary((void*)handles[i]); #else dlclose((void*)handles[i]); diff --git a/onnxruntime/core/mlas/lib/amx_common.h b/onnxruntime/core/mlas/lib/amx_common.h index 3eb0700932faa..caf94af02362d 100644 --- a/onnxruntime/core/mlas/lib/amx_common.h +++ b/onnxruntime/core/mlas/lib/amx_common.h @@ -18,7 +18,7 @@ Module Name: #include "mlasi.h" -#ifdef WIN32 +#ifdef _WIN32 #define tile_dpbssd(dst, src1, src2) _tile_dpbssd(dst, src1, src2) #define tile_dpbsud(dst, src1, src2) _tile_dpbsud(dst, src1, src2) From bc54ad3f03d7ee333f5e0c62ebf892c32f8a51a5 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 24 Jan 2024 14:37:39 -0800 Subject: [PATCH 27/45] Update abseil to a release tag and register neural_speed (#19255) ### Description Update abseil to a release tag and register neural_speed to CG. ### Motivation and Context Now we are using a non-relesed version of abseil. Using a tag is better. --- cgmanifests/generated/cgmanifest.json | 12 +++++++++++- cmake/deps.txt | 3 ++- cmake/external/abseil-cpp.cmake | 2 +- cmake/external/abseil-cpp.natvis | 10 +++++----- cmake/external/neural_speed.cmake | 9 +++------ .../azure-pipelines/templates/download-deps.yml | 4 ++-- 6 files changed, 24 insertions(+), 16 deletions(-) diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index bcd0b2a92a5c3..03e3f84547a68 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -36,7 +36,7 @@ "component": { "type": "git", "git": { - "commitHash": "dcd5bd5fd593e31465af3d9ef291d26c646b0a4f", + "commitHash": "4a2c63365eff8823a5221db86ef490e828306f9d", "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" }, "comments": "abseil_cpp" @@ -192,6 +192,16 @@ "comments": "mp11" } }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c11386eb632eec7c1c2aa323142f73519f946e2a", + "repositoryUrl": "https://github.com/intel/neural-speed.git" + }, + "comments": "neural_speed" + } + }, { "component": { "type": "git", diff --git a/cmake/deps.txt b/cmake/deps.txt index fda27e5e93797..ba9c2bb73cf7a 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -12,7 +12,7 @@ # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI. # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # -abseil_cpp;https://github.com/abseil/abseil-cpp/archive/dcd5bd5fd593e31465af3d9ef291d26c646b0a4f.zip;6cc204586014e189f5c0fe3274f83162fa7c700c +abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240116.0.zip;bc2cec6baaad67fcb6c0c38972b687d4797927e9 cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 @@ -34,6 +34,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 +neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index 3bcd4109e2888..57cfbee4644ef 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -19,7 +19,7 @@ if(WIN32 AND NOT Patch_FOUND) set(ABSL_ENABLE_INSTALL ON) endif() # NB! Advancing Abseil version changes its internal namespace, -# currently absl::lts_20230125 which affects abseil-cpp.natvis debugger +# currently absl::lts_20240116 which affects abseil-cpp.natvis debugger # visualization file, that must be adjusted accordingly, unless we eliminate # that namespace at build time. FetchContent_Declare( diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis index 1e5a36fb9efb9..a4fb63b6a8377 100644 --- a/cmake/external/abseil-cpp.natvis +++ b/cmake/external/abseil-cpp.natvis @@ -1,6 +1,6 @@ - + @@ -24,7 +24,7 @@ - + @@ -51,7 +51,7 @@ - + *($T1 *){value} (*($T1 *){value}) @@ -60,7 +60,7 @@ - + *($T1 *)this (*($T1 *)this) @@ -68,7 +68,7 @@ - + {value.first}, {value.second} ({value.first}, {value.second}) diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake index e66e2acfb209a..ed711351403a7 100644 --- a/cmake/external/neural_speed.cmake +++ b/cmake/external/neural_speed.cmake @@ -7,12 +7,9 @@ endif() if(USE_NEURAL_SPEED) FetchContent_Declare( neural_speed - URL https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip - URL_HASH SHA1=65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 + URL ${DEP_URL_neural_speed} + URL_HASH SHA1=${DEP_SHA1_neural_speed} ) set(BTLA_USE_OPENMP OFF) - FetchContent_MakeAvailable(neural_speed) - if(NOT neural_speed_POPULATED) - FetchContent_Populate(neural_speed) - endif() + onnxruntime_fetchcontent_makeavailable(neural_speed) endif() diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 537175f6bec73..55f6561b7a44a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.129 + version: 1.0.132 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.129 + version: 1.0.132 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. From 591f90c0b9e8d0922fcebabffed8d54b67d7a613 Mon Sep 17 00:00:00 2001 From: Yang Gu Date: Thu, 25 Jan 2024 06:49:37 +0800 Subject: [PATCH 28/45] [js/webgpu] Fix issue of timestamp query (#19258) When we enable webgpu profiling mode between session.create and session.run, current implementation has a problem to create querySet (and also queryResolveBuffer) if we share the commandEncoder with inputs upload. This PR fixes this by moving the querySet creation to the place we set queryType. --- js/web/lib/wasm/jsep/backend-webgpu.ts | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index afef7042a4280..8ca025d66550c 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -222,16 +222,6 @@ export class WebGpuBackend { getCommandEncoder(): GPUCommandEncoder { if (!this.commandEncoder) { this.commandEncoder = this.device.createCommandEncoder(); - - if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { - this.querySet = this.device.createQuerySet({ - type: 'timestamp', - count: this.maxDispatchNumber * 2, - }); - this.queryResolveBuffer = this.device.createBuffer( - // eslint-disable-next-line no-bitwise - {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE}); - } } return this.commandEncoder; } @@ -654,6 +644,16 @@ export class WebGpuBackend { } else if (this.device.features.has('timestamp-query')) { this.queryType = 'at-passes'; } + + if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { + this.querySet = this.device.createQuerySet({ + type: 'timestamp', + count: this.maxDispatchNumber * 2, + }); + this.queryResolveBuffer = this.device.createBuffer( + // eslint-disable-next-line no-bitwise + {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE}); + } } } onRunStart(): void { From c456f19dbaf6b23928a60e8b356a429ae76376a4 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Wed, 24 Jan 2024 15:20:36 -0800 Subject: [PATCH 29/45] remove old quantization tool file (#19247) ### Description remove old python files ### Motivation and Context We have a new op MatMulNBits and this one is deprecated. --- .../python/tools/quantization/__init__.py | 1 - .../quantization/matmul_weight4_quantizer.py | 260 ------------------ .../python/quantization/test_op_matmulfpq4.py | 153 ----------- 3 files changed, 414 deletions(-) delete mode 100644 onnxruntime/python/tools/quantization/matmul_weight4_quantizer.py delete mode 100644 onnxruntime/test/python/quantization/test_op_matmulfpq4.py diff --git a/onnxruntime/python/tools/quantization/__init__.py b/onnxruntime/python/tools/quantization/__init__.py index 170c0928fee23..9d397499d45a4 100644 --- a/onnxruntime/python/tools/quantization/__init__.py +++ b/onnxruntime/python/tools/quantization/__init__.py @@ -5,7 +5,6 @@ MinMaxCalibrater, create_calibrator, ) -from .matmul_weight4_quantizer import MatMulWeight4Quantizer # noqa: F401 from .qdq_quantizer import QDQQuantizer # noqa: F401 from .quant_utils import QuantFormat, QuantType, write_calibration_table # noqa: F401 from .quantize import DynamicQuantConfig # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/matmul_weight4_quantizer.py b/onnxruntime/python/tools/quantization/matmul_weight4_quantizer.py deleted file mode 100644 index 921e02fb69e9b..0000000000000 --- a/onnxruntime/python/tools/quantization/matmul_weight4_quantizer.py +++ /dev/null @@ -1,260 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import argparse -import struct -from pathlib import Path -from typing import List, Tuple - -import numpy as np -import numpy.typing as npt -import onnx -from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto - -from .onnx_model import ONNXModel -from .quant_utils import attribute_to_kwarg, load_model_with_shape_infer - - -def __q4_block_size(quant_type: int) -> int: - # happens to be 32 for now, but future quantization types - # may have bigger block size - return 32 - - -def __q4_blob_size(quant_type: int) -> int: - if quant_type == MatMulWeight4Quantizer.BlkQ4Sym: - # 4b each value, with one fp32 scale - blob_size = 32 // 2 + 4 - elif quant_type == MatMulWeight4Quantizer.BlkQ4Zp8: - # 4b each value, with one fp32 scale and one uint8 zero point - blob_size = 32 // 2 + 4 + 1 - else: - raise ValueError(f"Unsupported quantization type: {quant_type}") - return blob_size - - -def __q4_buf_size(quant_type: int, rows: int, cols: int) -> int: - block_size = __q4_block_size(quant_type) - blob_size = __q4_blob_size(quant_type) - k_blocks = (rows + block_size - 1) // block_size - return k_blocks * cols * blob_size - - -def int4_block_quant(quant_type: int, fp32weight: npt.ArrayLike) -> np.ndarray: - """4b quantize fp32 weight to a blob""" - - if len(fp32weight.shape) != 2: - raise ValueError("Current int4 block quantization only supports 2D tensors!") - rows, cols = fp32weight.shape - - block_size = __q4_block_size(quant_type) - blob_size = __q4_blob_size(quant_type) - k_blocks = (rows + block_size - 1) // block_size - padded_rows = k_blocks * block_size - pad_len = padded_rows - rows - if pad_len > 0: - fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant") - - # block wise quantization, each block comes from a single column - blob_idx = 0 - packed = np.zeros((cols * k_blocks, blob_size), dtype="uint8") - for n in range(cols): - ncol = fp32weight[:, n] - blks = np.split(ncol, k_blocks) - for blk in blks: - packed_blob = packed[blob_idx] - blob_idx += 1 - - if quant_type == MatMulWeight4Quantizer.BlkQ4Sym: - amax_idx = np.argmax(np.abs(blk)) - bmax = blk[amax_idx] - scale = bmax / (-8) - zp = 8 - else: - vmin = np.min(blk) - vmax = np.max(blk) - vmin = min(vmin, 0.0) - vmax = max(vmax, 0.0) - scale = (vmax - vmin) / ((1 << 4) - 1) - zero_point_fp = vmin - if scale != 0.0: - zero_point_fp = 0.0 - vmin / scale - zp = min(15, max(0, round(zero_point_fp))) - - reciprocal_scale = 1.0 / scale if scale != 0 else 0.0 - bf = struct.pack("f", scale) - packed_blob[0] = bf[0] - packed_blob[1] = bf[1] - packed_blob[2] = bf[2] - packed_blob[3] = bf[3] - blob_offset = 4 - if quant_type == MatMulWeight4Quantizer.BlkQ4Zp8: - packed_blob[4] = zp - blob_offset = 5 - - num_segs = block_size // 32 - blk_int = np.clip(np.rint(blk * reciprocal_scale + zp), 0, 15).astype("uint8") - segs = np.split(blk_int, num_segs) - for seg in segs: - packed_blob[blob_offset : (blob_offset + 16)] = np.bitwise_or(seg[0:16], np.left_shift(seg[16:32], 4)) - blob_offset += 16 - return packed.reshape(-1) - - -class MatMulWeight4Quantizer: - """Perform 4b quantization of constant MatMul weights""" - - ################## - # quantization types, must be consistent with native code type - # MLAS_BLK_QUANT_TYPE defined in mlas_q4.h - - # 32 number block, symmetric quantization, with one fp32 as scale, zero point is always 0 - BlkQ4Sym = 0 - - # 32 number block, quantization, with one fp32 as scale, one uint8 zero point - BlkQ4Zp8 = 1 - - def __init__(self, model: ModelProto, quant_type: int): - self.model = ONNXModel(model) - self.quant_type = quant_type - - @staticmethod - def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]: - for gid in range(len(graph_path) - 1, -1, -1): - graph = graph_path[gid] - for tensor in graph.initializer: - if tensor.name == name: - return tensor, graph - return None, None - - def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto: - """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" - - if node.op_type != "MatMul": - return node # only care about MatMul for now - - inputB = node.input[1] # noqa: N806 - B, Bs_graph = MatMulWeight4Quantizer.__get_initializer(inputB, graph_stack) # noqa: N806 - if B is None: - return node # only care about constant weight - - # TODO!! assume B is not used by any other node - B_array = onnx.numpy_helper.to_array(B) # noqa: N806 - if len(B_array.shape) != 2: - return node # can only process 2-D matrix - - rows, cols = B_array.shape - packed = int4_block_quant(self.quant_type, B_array) - B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 - B_quant.name = B.name + "_Q4" - Bs_graph.initializer.remove(B) - for input in Bs_graph.input: - if input.name == inputB: - Bs_graph.input.remove(input) - break - - B_shape = onnx.numpy_helper.from_array(np.array([rows, cols]).astype(np.int64)) # noqa: N806 - B_shape.name = B.name + "_shape" - Bs_graph.initializer.extend([B_quant, B_shape]) - - kwargs = {} - kwargs["blk_quant_type"] = self.quant_type - matmul_q4_node = onnx.helper.make_node( - "MatMulFpQ4", - inputs=[node.input[0], B_quant.name, B_shape.name], - outputs=[node.output[0]], - name=node.name + "_Q4" if node.name else "", - domain="com.microsoft", - **kwargs, - ) - return matmul_q4_node - - def _process_subgraph(self, graph_stack: List[GraphProto]): - new_nodes = [] - graph = graph_stack[-1] - - for node in graph.node: - graph_attrs = [ - attr - for attr in node.attribute - if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS - ] - if len(graph_attrs): - kwargs = {} - for attr in node.attribute: - if attr.type == onnx.AttributeProto.GRAPH: - # recursive call to take care of sub-graph - graph_stack.append(attr.g) - kv = {attr.name: self._process_subgraph(graph_stack)} - elif attr.type == onnx.AttributeProto.GRAPHS: - value = [] - for subgraph in attr.graphs: - # recursive call to take care of sub-graph - graph_stack.append(subgraph) - value.extend([self._process_subgraph(graph_stack)]) - kv = {attr.name: value} - else: - kv = attribute_to_kwarg(attr) - kwargs.update(kv) - node = onnx.helper.make_node( # noqa: PLW2901 - node.op_type, node.input, node.output, name=node.name, **kwargs - ) - - new_nodes.append(self._q4_matmul_node_weight(node, graph_stack)) - - graph.ClearField("node") - graph.node.extend(new_nodes) - graph_stack.pop() - return graph - - def process(self): - # use a stack to keep track of sub-graphs - graph_stack = [self.model.graph()] - opset_import = self.model.opset_import() - - has_ms_domain = False - for opset in opset_import: - if opset.domain == "com.microsoft": - has_ms_domain = True - if not has_ms_domain: - opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) - - self._process_subgraph(graph_stack) - - -def parse_args(): - parser = argparse.ArgumentParser( - description="""Blockwise int4 quantization for MatMul 2D weight matrices. - -A weight matrix is partitioned into into blocks, where each block is a -continguous subset inside each column. Each block is quantized into a -set of 4b integers with a scaling factor and an optional offset. -""" - ) - - parser.add_argument("--input_model", required=True, help="Path to the input model file") - parser.add_argument("--output_model", required=True, help="Path to the output model file") - parser.add_argument( - "--quant_bin_path", - required=True, - help="""Currently quantization code is implemented in a separate binary -(onnxruntime_mlas_q4dq) that is compiled with Onnxruntime native code. -Path to this binary needs to be provided here.""", - ) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - - input_model_path = args.input_model - output_model_path = args.output_model - q4dq_bin_path = args.quant_bin_path - - model = load_model_with_shape_infer(Path(input_model_path)) - quant = MatMulWeight4Quantizer(model, 0) - quant.process() - quant.model.save_model_to_file(output_model_path, False) diff --git a/onnxruntime/test/python/quantization/test_op_matmulfpq4.py b/onnxruntime/test/python/quantization/test_op_matmulfpq4.py deleted file mode 100644 index 170bb09a0fdeb..0000000000000 --- a/onnxruntime/test/python/quantization/test_op_matmulfpq4.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import tempfile -import unittest -from pathlib import Path -from typing import Dict, Tuple, Union - -import numpy as np -import onnx -from onnx import TensorProto, helper -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count - -from onnxruntime.quantization import MatMulWeight4Quantizer, quant_utils - - -class TestOpMatMulFpQ4(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="test_matmulfpq4.") - - @classmethod - def tearDownClass(cls): - cls._tmp_model_dir.cleanup() - - def fill_int4_data(self, shape: Union[int, Tuple[int, ...]], symmetric: bool) -> np.ndarray: - line = np.zeros(shape) - line = line.reshape(-1) - - if symmetric: - v = -2.0 - for i in range(line.shape[0]): - if v == 0 or v == -3 or v == 3: - v += 1 - line[i] = v - v += 1 - if v >= 8: - v = -8 - else: - v = 0.0 - for i in range(line.shape[0]): - line[i] = v - v += 1 - if v >= 16: - v = 0 - - return line.reshape(shape) - - def input_feeds(self, n: int, name2shape: Dict[str, Union[int, Tuple[int, ...]]]) -> TestDataFeeds: - input_data_list = [] - for _i in range(n): - inputs = {} - for name, shape in name2shape.items(): - inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) - input_data_list.extend([inputs]) - dr = TestDataFeeds(input_data_list) - return dr - - def construct_model_matmul(self, output_model_path: str, symmetric: bool) -> None: - # (input) - # | - # MatMul - # | - # (output) - input_name = "input" - output_name = "output" - initializers = [] - - def make_gemm(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str): - weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32) - initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) - return onnx.helper.make_node( - "MatMul", - [input_name, weight_name], - [output_name], - ) - - in_features = 52 - out_features = 288 - # make MatMulFpQ4 node - matmul_node = make_gemm( - input_name, - [in_features, out_features], - "linear1.weight", - output_name, - ) - - # make graph - input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [-1, in_features]) - output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [-1, out_features]) - graph_name = "matmul_test" - graph = helper.make_graph( - [matmul_node], - graph_name, - [input_tensor], - [output_tensor], - initializer=initializers, - ) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version - - onnx.save(model, output_model_path) - - def quant_test( - self, - model_fp32_path: str, - data_reader: TestDataFeeds, - quantization_type: int, # 0: BlkQ4Sym, 1: BlkQ4Zp8 - ): - qtype_str = "BlkQ4Sym" if (quantization_type == 0) else "BlkQ4Zp8" - model_int4_path = str(Path(self._tmp_model_dir.name).joinpath(f"matmulfpq4_{qtype_str}.onnx").absolute()) - - # Quantize fp32 model to int4 model - model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) - quant = MatMulWeight4Quantizer(model, quantization_type) - quant.process() - quant.model.save_model_to_file(model_int4_path, False) - - quant_nodes = {"MatMulFpQ4": 1} - check_op_type_count(self, model_int4_path, **quant_nodes) - - data_reader.rewind() - - try: - check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next()) - except Exception as exception: - if "4b quantization not yet supported on this hardware platform!" in exception.args[0]: - # Currently we don't have int4 quantization support on all platforms, has to tolerate this exception - pass - else: - raise exception - - def test_quantize_matmul_int4_symmetric(self): - np.random.seed(13) - - model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, symmetric=True) - data_reader = self.input_feeds(1, {"input": [100, 52]}) - self.quant_test(model_fp32_path, data_reader, quantization_type=MatMulWeight4Quantizer.BlkQ4Sym) - - def test_quantize_matmul_int4_offsets(self): - model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) - self.construct_model_matmul(model_fp32_path, symmetric=False) - data_reader = self.input_feeds(1, {"input": [100, 52]}) - self.quant_test(model_fp32_path, data_reader, quantization_type=MatMulWeight4Quantizer.BlkQ4Zp8) - - -if __name__ == "__main__": - unittest.main() From 7252c6e747de83b65285601281a9d07aea801fba Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 25 Jan 2024 07:37:35 +0800 Subject: [PATCH 30/45] [WebNN EP] Support WebNN async API with Asyncify (#19145) --- js/web/lib/build-def.d.ts | 4 --- js/web/lib/index.ts | 4 +-- js/web/lib/wasm/binding/ort-wasm.d.ts | 2 +- js/web/lib/wasm/wasm-core-impl.ts | 4 +-- js/web/script/build.ts | 7 +--- js/web/script/test-runner-cli-args.ts | 4 --- .../core/providers/webnn/builders/model.cc | 35 ++++++++----------- .../providers/webnn/builders/model_builder.cc | 12 +++---- .../webnn/webnn_execution_provider.cc | 3 +- onnxruntime/wasm/js_internal_api.js | 4 +++ 10 files changed, 30 insertions(+), 49 deletions(-) diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index b3868871a4753..2c9cd88a375bd 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -21,10 +21,6 @@ interface BuildDefinitions { /** * defines whether to disable the whole WebNN backend in the build. */ - readonly DISABLE_WEBNN: boolean; - /** - * defines whether to disable the whole WebAssembly backend in the build. - */ readonly DISABLE_WASM: boolean; /** * defines whether to disable proxy feature in WebAssembly backend in the build. diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index baf45e74addea..b212c0f49df3b 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -23,12 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) { require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU) { registerBackend('webgpu', wasmBackend, 5); + registerBackend('webnn', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - if (!BUILD_DEFS.DISABLE_WEBNN) { - registerBackend('webnn', wasmBackend, 9); - } } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 68054210e79a7..24d7062c85fcb 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -31,7 +31,7 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void; - _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number; + _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise; _OrtReleaseSession(sessionHandle: number): void; _OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; _OrtGetInputName(sessionHandle: number, index: number): number; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 8768643fa7257..046336dc9cac0 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -84,7 +84,7 @@ export const initRuntime = async(env: Env): Promise => { * @param epName */ export const initEp = async(env: Env, epName: string): Promise => { - if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') { + if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) { // perform WebGPU availability check if (typeof navigator === 'undefined' || !navigator.gpu) { throw new Error('WebGPU is not supported in current environment'); @@ -228,7 +228,7 @@ export const createSession = async( await Promise.all(loadingPromises); } - sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); + sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); if (sessionHandle === 0) { checkLastError('Can\'t create a session.'); } diff --git a/js/web/script/build.ts b/js/web/script/build.ts index ea0c122cb51de..d3652f3820357 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -44,7 +44,6 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // /js/ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_WEBGL': 'false', 'BUILD_DEFS.DISABLE_WEBGPU': 'false', - 'BUILD_DEFS.DISABLE_WEBNN': 'false', 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', 'BUILD_DEFS.DISABLE_WASM_THREAD': 'false', @@ -364,7 +363,6 @@ async function main() { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', }, @@ -397,7 +395,7 @@ async function main() { // ort.webgpu[.min].js await addAllWebBuildTasks({ outputBundleName: 'ort.webgpu', - define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'true'}, + define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'}, }); // ort.wasm[.min].js await addAllWebBuildTasks({ @@ -411,7 +409,6 @@ async function main() { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WASM': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', }, }); // ort.wasm-core[.min].js @@ -421,7 +418,6 @@ async function main() { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', }, @@ -434,7 +430,6 @@ async function main() { 'BUILD_DEFS.DISABLE_TRAINING': 'false', 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', }, }); } diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 8f6c5f6f04122..ed4dd76a6e315 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -396,10 +396,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs const globalEnvFlags = parseGlobalEnvFlags(args); - if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) { - throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.'); - } - // Options: // --log-verbose=<...> // --log-info=<...> diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index eaf549ef4e072..ef807a8c4fa26 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -70,22 +70,13 @@ Status Model::Predict(const InlinedHashMap& inputs, "The input of graph has unsupported type, name: ", name, " type: ", tensor.tensor_info.data_type); } -#ifdef ENABLE_WEBASSEMBLY_THREADS - // Copy the inputs from Wasm SharedArrayBuffer to the pre-allocated ArrayBuffers. + // Copy the inputs from Wasm ArrayBuffer to the WebNN inputs ArrayBuffer. + // As Wasm ArrayBuffer is not detachable. wnn_inputs_[name].call("set", view); -#else - wnn_inputs_.set(name, view); -#endif } -#ifdef ENABLE_WEBASSEMBLY_THREADS - // This vector uses for recording output buffers from WebNN graph compution when WebAssembly - // multi-threads is enabled, since WebNN API only accepts non-shared ArrayBufferView, - // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews - // and at this time the 'view' defined by Emscripten is shared ArrayBufferView, the memory - // address is different from the non-shared one, additional memory copy is required here. InlinedHashMap output_views; -#endif + for (const auto& output : outputs) { const std::string& name = output.first; const struct OnnxTensorData tensor = output.second; @@ -131,21 +122,23 @@ Status Model::Predict(const InlinedHashMap& inputs, name, " type: ", tensor.tensor_info.data_type); } -#ifdef ENABLE_WEBASSEMBLY_THREADS output_views.insert({name, view}); -#else - wnn_outputs_.set(name, view); -#endif } - wnn_context_.call("computeSync", wnn_graph_, wnn_inputs_, wnn_outputs_); -#ifdef ENABLE_WEBASSEMBLY_THREADS - // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm SharedArrayBuffer. + emscripten::val results = wnn_context_.call( + "compute", wnn_graph_, wnn_inputs_, wnn_outputs_) + .await(); + + // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer. for (const auto& output : outputs) { const std::string& name = output.first; emscripten::val view = output_views.at(name); - view.call("set", wnn_outputs_[name]); + view.call("set", results["outputs"][name]); } -#endif + // WebNN compute() method would return the input and output buffers via the promise + // resolution. Reuse the buffers to avoid additional allocation. + wnn_inputs_ = results["inputs"]; + wnn_outputs_ = results["outputs"]; + return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index cf8a0e23db43b..56f7ead8ccf5d 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -386,7 +386,8 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { for (auto& name : output_names_) { named_operands.set(name, wnn_operands_.at(name)); } - emscripten::val wnn_graph = wnn_builder_.call("buildSync", named_operands); + + emscripten::val wnn_graph = wnn_builder_.call("build", named_operands).await(); if (!wnn_graph.as()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph."); } @@ -395,13 +396,10 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { model->SetOutputs(std::move(output_names_)); model->SetScalarOutputs(std::move(scalar_outputs_)); model->SetInputOutputInfo(std::move(input_output_info_)); -#ifdef ENABLE_WEBASSEMBLY_THREADS - // Pre-allocate the input and output tensors for the WebNN graph - // when WebAssembly multi-threads is enabled since WebNN API only - // accepts non-shared ArrayBufferView. - // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + // Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews + // for inputs and outputs because they will be transferred after compute() done. + // https://webmachinelearning.github.io/webnn/#api-mlcontext-async-execution model->AllocateInputOutputBuffers(); -#endif return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 2922cf9540a8e..df7871614b267 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -42,7 +42,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f if (webnn_power_flags.compare("default") != 0) { context_options.set("powerPreference", emscripten::val(webnn_power_flags)); } - wnn_context_ = ml.call("createContextSync", context_options); + + wnn_context_ = ml.call("createContext", context_options).await(); if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 7c70515e73eab..7e9c0a6f99c32 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -160,6 +160,10 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea }; // replace the original functions with asyncified versions + Module['_OrtCreateSession'] = jsepWrapAsync( + Module['_OrtCreateSession'], + () => Module['_OrtCreateSession'], + v => Module['_OrtCreateSession'] = v); Module['_OrtRun'] = runAsync(jsepWrapAsync( Module['_OrtRun'], () => Module['_OrtRun'], From 0c2f0ba90da11ad53c63810e5f3e6fda4e295899 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 25 Jan 2024 07:53:10 +0800 Subject: [PATCH 31/45] [WebNN EP] Support conv1d by reshaping with prepended 1's (#18857) WebNN only supports 4-D inputs for conv2d and convTranspose2d, this PR supports 3-D inputs (i.e. conv1d) by prepending a 1 size dimension and several reshape operations. --- .../core/providers/webnn/builders/helper.h | 9 + .../webnn/builders/impl/conv_op_builder.cc | 221 +++++++++++------- 2 files changed, 141 insertions(+), 89 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 85dafcaf66575..92aa9abc9fdf7 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -54,6 +54,15 @@ std::string GetShapeString(std::vector& shape) { return shape_info.str(); } +inline std::vector GetVecUint32FromVecInt64(const std::vector& int64_vec) { + std::vector uint32_vec; + uint32_vec.reserve(int64_vec.size()); + std::transform(int64_vec.begin(), int64_vec.end(), + std::back_inserter(uint32_vec), + [](int64_t val) -> uint32_t { return SafeInt(val); }); + return uint32_vec; +} + template bool ReadIntArrayFrom1DTensor(const onnx::TensorProto& tensor, std::vector& array, const logging::Logger& logger) { std::vector unpacked_tensor; diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index ceacb7c2b38a3..c74545479e466 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -42,72 +42,61 @@ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod // Helper functions common::Status SetConvBaseOptions(ModelBuilder& model_builder, const Node& node, emscripten::val& options, - const std::vector& strides, - const std::vector& dilations, - std::vector& pads, + const std::vector input_shape, + const std::vector weight_shape, + const std::vector& strides, + const std::vector& dilations, + std::vector& pads, + const bool is_nhwc, + const bool is_conv1d, const logging::Logger& logger) { NodeAttrHelper helper(node); - const auto group = helper.Get("group", static_cast(1)); const auto& input_defs = node.InputDefs(); - std::vector weight_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[1], weight_shape, logger), "Cannot get weight shape"); - options.set("strides", emscripten::val::array(strides)); - options.set("dilations", emscripten::val::array(dilations)); - options.set("groups", group); + // Add Padding. - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); if (node.OpType() == "Conv") { // Calculate explicit padding for autoPad. if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { std::vector pads_out; ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], - helper.Get("pads", std::vector{0, 0, 0, 0}), - helper.Get("strides", std::vector{1, 1}), - helper.Get("dilations", std::vector{1, 1}), - auto_pad_type, - pads_out, - model_builder.GetPreferredLayout() == DataLayout::NCHW)); - std::transform(pads_out.begin(), pads_out.end(), pads.begin(), - [](int64_t pad) -> int32_t { return static_cast(pad); }); + pads, strides, dilations, auto_pad_type, pads_out, !is_nhwc)); + pads = pads_out; } } else if (node.OpType() == "ConvTranspose") { // When the 'output_shape' is specificed, the 'output_padding' values // in options.outputPadding are ignored. - std::vector dim; - std::vector output_padding{0, 0}; + std::vector dims; + std::vector output_padding{0, 0}; if (helper.HasAttr("output_shape")) { - // Default value of 'output_shape' will be ignore as we already check if - // it's existed. - dim = helper.Get("output_shape", std::vector{-1, -1}); + // Default value of 'output_shape' will be ignored as we already check if it existed. + dims = helper.Get("output_shape", std::vector{-1, -1}); // Extract the height and width. - std::vector output_shape; - if (dim.size() == 2) { - output_shape = dim; - } else if (dim.size() == 4) { - output_shape = {dim[2], dim[3]}; + std::vector output_shape; + if (dims.size() == 1 && is_conv1d) { // ConvTranspose 1d + output_shape = {dims[0], 1}; + } else if (dims.size() == 2 && !is_conv1d) { + output_shape = dims; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output shape"); } // Padding values are auto generated. if (helper.HasAttr("kernel_shape")) { - std::vector kernel_shape = helper.Get("kernel_shape", std::vector{-1, -1}); - std::vector total_padding(2); - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + std::vector kernel_shape = helper.Get("kernel_shape", std::vector{-1, -1}); + if (is_conv1d) { // ConvTranspose 1d + kernel_shape.push_back(1); + } + std::vector total_padding(2); for (size_t i = 0; i < 2; i++) { // Get the dimensions of H and W. // For NHWC layout, the dimensions of H and W correspond to index 1 and 2. // For NCHW layout, the dimensions of H and W correspond to index 2 and 3. - if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { - total_padding[i] = strides[i] * (narrow(input_shape[i + 1]) - 1) + - output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; + if (is_nhwc) { + total_padding[i] = strides[i] * (input_shape[i + 1] - 1) + output_padding[i] + + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; } else { - ORT_RETURN_IF_NOT(model_builder.GetPreferredLayout() == DataLayout::NCHW, - "WebNN GPU backend preferred layout should be NCHW."); - total_padding[i] = strides[i] * (narrow(input_shape[i + 2]) - 1) + - output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; + total_padding[i] = strides[i] * (input_shape[i + 2] - 1) + output_padding[i] + + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; } } AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); @@ -122,18 +111,27 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, } } } - options.set("outputSizes", emscripten::val::array(output_shape)); + options.set("outputSizes", emscripten::val::array(GetVecUint32FromVecInt64(output_shape))); } else { - output_padding = helper.Get("output_padding", std::vector{0, 0}); - options.set("outputPadding", emscripten::val::array(output_padding)); + output_padding = helper.Get("output_padding", std::vector{0, 0}); + if (output_padding.size() == 1 && is_conv1d) { // ConvTranspose 1d + output_padding.push_back(0); + } + options.set("outputPadding", emscripten::val::array(GetVecUint32FromVecInt64(output_padding))); } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "conv_op_builder only supports Op Conv and ConvTranspose."); } + + const auto group = helper.Get("group", static_cast(1)); + options.set("groups", group); + options.set("strides", emscripten::val::array(GetVecUint32FromVecInt64(strides))); + options.set("dilations", emscripten::val::array(GetVecUint32FromVecInt64(dilations))); + // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. - const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; - options.set("padding", emscripten::val::array(padding)); + const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; + options.set("padding", emscripten::val::array(GetVecUint32FromVecInt64(padding))); // Add bias if present. if (input_defs.size() > 2) { @@ -151,7 +149,8 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, // Both depthwise Conv and ConvTranspose share the same logic to add the layout. Status AddInitializerInNewLayout(ModelBuilder& model_builder, const std::string& name, - bool is_conv) { + bool is_conv, + bool is_conv1d) { const auto& tensor = *model_builder.GetInitializerTensors().at(name); auto data_type = tensor.data_type(); if (!IsSupportedDataType(data_type, model_builder.GetWebnnDeviceType())) { @@ -161,13 +160,13 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, } const auto& shape = tensor.dims(); - std::vector dims; - std::transform(shape.cbegin(), shape.cend(), - std::back_inserter(dims), - [](int64_t dim) -> int32_t { return SafeInt(dim); }); + std::vector dims = GetVecUint32FromVecInt64(std::vector(std::begin(shape), std::end(shape))); + + if (is_conv1d) { + // Support conv1d by prepending a 1 size dimension. + dims.push_back(1); + } - ORT_RETURN_IF_NOT(dims.size() == 4, - "The initializer is not 4D: ", name, " actual dim ", dims.size()); const uint8_t* src = nullptr; Initializer unpacked_tensor(tensor, model_builder.GetGraphViewer().ModelPath()); src = unpacked_tensor.DataAsByteSpan().data(); @@ -257,57 +256,101 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val output = emscripten::val::object(); - NodeAttrHelper helper(node); - const auto strides = helper.Get("strides", std::vector{1, 1}); - const auto dilations = helper.Get("dilations", std::vector{1, 1}); - auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); + std::vector weight_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[1], weight_shape, logger), "Cannot get weight shape"); const auto& weight_name = input_defs[1]->Name(); + + NodeAttrHelper helper(node); + auto strides = helper.Get("strides", std::vector{1, 1}); + auto dilations = helper.Get("dilations", std::vector{1, 1}); + auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + + const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; + const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3; + // Support conv1d by prepending a 1 or 2 size dimensions. + if (is_conv1d) { + // Reshape input. + if (is_nhwc) { + // For NHWC preferred layout, the input has been transposed. + // For conv1d it is NCD1 -> ND1C, so we need to prepend 1 to the index 2. + input_shape.insert(input_shape.begin() + 2, 1); + } else { + input_shape.push_back(1); + } + std::vector new_shape = GetVecUint32FromVecInt64(input_shape); + input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); + + weight_shape.resize(4, 1); // Ensure 4D by appending 1's if needed. + strides.resize(2, 1); // Ensure 2D by appending 1's if needed. + dilations.resize(2, 1); // Ensure 2D by appending 1's if needed. + if (pads.size() == 2) { + pads.insert(pads.begin() + 1, 0); + pads.push_back(0); + } + } + emscripten::val options = emscripten::val::object(); - ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); + ORT_RETURN_IF_ERROR(SetConvBaseOptions( + model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger)); if (op_type == "Conv" || op_type == "ConvInteger") { int groups = options["groups"].as(); - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + if (is_nhwc) { bool depthwise = (groups == input_shape[3] && groups != 1); options.set("inputLayout", emscripten::val("nhwc")); - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d)); if (!depthwise) { options.set("filterLayout", emscripten::val("ohwi")); } else { options.set("filterLayout", emscripten::val("ihwo")); } } - emscripten::val filter = model_builder.GetOperand(weight_name); - if (op_type == "Conv") { - output = model_builder.GetBuilder().call("conv2d", input, filter, options); - } else { - emscripten::val x_zero_point = emscripten::val::null(); - emscripten::val w_zero_point = emscripten::val::null(); - if (input_defs.size() >= 3) { - x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); - } else { - x_zero_point = model_builder.GetZeroConstant("uint8"); - } - if (input_defs.size() >= 4) { - w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); - } else { - w_zero_point = model_builder.GetZeroConstant("uint8"); - } - output = model_builder.GetBuilder().call("conv2dInteger", - input, x_zero_point, filter, w_zero_point, options); - } - - } else { - if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + } else { // ConvTranspose + if (is_nhwc) { options.set("inputLayout", emscripten::val("nhwc")); options.set("filterLayout", emscripten::val("ohwi")); - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false, is_conv1d)); } - emscripten::val filter = model_builder.GetOperand(input_defs[1]->Name()); + } + + emscripten::val filter = model_builder.GetOperand(weight_name); + if (!is_nhwc && is_conv1d) { + // Reshape weight to 4D for conv1d with NCHW preferred layout. + std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); + filter = model_builder.GetBuilder().call("reshape", filter, emscripten::val::array(new_shape)); + } + + if (op_type == "Conv") { + output = model_builder.GetBuilder().call("conv2d", input, filter, options); + } else if (op_type == "ConvInteger") { + emscripten::val x_zero_point = emscripten::val::null(); + emscripten::val w_zero_point = emscripten::val::null(); + if (input_defs.size() >= 3) { + x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); + } else { + x_zero_point = model_builder.GetZeroConstant("uint8"); + } + if (input_defs.size() >= 4) { + w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); + } else { + w_zero_point = model_builder.GetZeroConstant("uint8"); + } + output = model_builder.GetBuilder().call("conv2dInteger", + input, x_zero_point, filter, w_zero_point, options); + } else { output = model_builder.GetBuilder().call("convTranspose2d", input, filter, options); } + // If it's a conv1d, reshape it back. + if (is_conv1d) { + const auto& output_defs = node.OutputDefs(); + std::vector output_shape; + ORT_RETURN_IF_NOT(GetShape(*output_defs[0], output_shape, logger), "Cannot get output shape"); + std::vector new_shape = GetVecUint32FromVecInt64(output_shape); + output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); + } + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); } @@ -329,9 +372,9 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } const auto input_size = input_shape.size(); - if (input_size != 4) { + if (input_size != 4 && input_size != 3) { LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s input dimension: " << input_size - << ". Only conv 2d is supported."; + << ". Only conv 1d / 2d is supported."; return false; } @@ -342,9 +385,9 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } const auto weight_size = weight_shape.size(); - if (weight_size != 4) { + if (weight_size != 4 && weight_size != 3) { LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s weight dimension: " << weight_size - << ". Only conv 2d is supported."; + << ". Only conv 1d / 2d is supported."; return false; } From 4477f57ee3151287a9759bd09d269f0e258a9eda Mon Sep 17 00:00:00 2001 From: Phoebe Chen Date: Thu, 25 Jan 2024 08:27:05 +0800 Subject: [PATCH 32/45] Enable RISC-V 64-bit Cross-Compiling Support for ONNX Runtime on Linux (#19238) ### Description This pull request introduces the necessary changes to enable RISC-V 64-bit cross-compiling support for the ONNX Runtime on Linux. The RISC-V architecture has gained popularity as an open standard instruction set architecture, and this contribution aims to extend ONNX Runtime's compatibility to include RISC-V, thereby broadening the reach of ONNX models to a wider range of devices. ### Motivation and Context RISC-V is a free and open-source instruction set architecture (ISA) based on established RISC principles. It is provided under open licenses without fees. Due to its extensibility and freedom in both software and hardware, RISC-V is poised for widespread adoption in the future, especially in applications related to AI, parallel computing, and data centers. ### Example Build Command ``` ./build.sh --parallel --config Debug --rv64 --riscv_toolchain_root=/path/to/toolchain/root --skip_tests ``` ### Documentation Updates Relevant sections of the documentation will be updated to reflect the newly supported RISC-V 64-bit cross-compilation feature. https://github.com/microsoft/onnxruntime/pull/19239 --------- Signed-off-by: Phoebe Chen --- cmake/external/xnnpack.cmake | 6 +- cmake/onnxruntime_common.cmake | 4 +- cmake/riscv64.toolchain.cmake | 35 +++++++++ tools/ci_build/build.py | 35 ++++++++- tools/scripts/build_riscv64.sh | 129 +++++++++++++++++++++++++++++++++ 5 files changed, 206 insertions(+), 3 deletions(-) create mode 100644 cmake/riscv64.toolchain.cmake create mode 100755 tools/scripts/build_riscv64.sh diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake index e661aa51bfc17..41f02ce6f22bc 100644 --- a/cmake/external/xnnpack.cmake +++ b/cmake/external/xnnpack.cmake @@ -6,10 +6,14 @@ set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") set(PTHREADPOOL_BUILD_TESTS OFF CACHE INTERNAL "") set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE INTERNAL "") +if(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") + set(XNNPACK_USE_SYSTEM_LIBS OFF) +endif() + # BF16 instructions cause ICE in Android NDK compiler if(CMAKE_ANDROID_ARCH_ABI STREQUAL armeabi-v7a) set(XNNPACK_ENABLE_ARM_BF16 OFF) -ENDIF() +endif() # fp16 depends on psimd FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd}) diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 43d5fa9bdee34..6b8c2560b1714 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -189,6 +189,8 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set(ARM TRUE) elseif(dumpmachine_output MATCHES "^aarch64.*") set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") + set(RISCV64 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") set(X86 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") @@ -198,7 +200,7 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() -if (ARM64 OR ARM OR X86 OR X64 OR X86_64) +if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) if((WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) OR ((ARM64 OR ARM) AND MSVC)) # msvc compiler report syntax error with cpuinfo arm source files # and cpuinfo does not have code for getting arm uarch info under windows diff --git a/cmake/riscv64.toolchain.cmake b/cmake/riscv64.toolchain.cmake new file mode 100644 index 0000000000000..0fda239f9a628 --- /dev/null +++ b/cmake/riscv64.toolchain.cmake @@ -0,0 +1,35 @@ +# Copyright (c) 2024 SiFive, Inc. All rights reserved. +# Copyright (c) 2024, Phoebe Chen +# Licensed under the MIT License. + +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR riscv64) + +list(APPEND CMAKE_TRY_COMPILE_PLATFORM_VARIABLES RISCV_TOOLCHAIN_ROOT) + +if(NOT RISCV_TOOLCHAIN_ROOT) + message(FATAL_ERROR "RISCV_TOOLCHAIN_ROOT is not defined. Please set the RISCV_TOOLCHAIN_ROOT variable.") +endif() + +set(CMAKE_C_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc") +set(CMAKE_ASM_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc") +set(CMAKE_CXX_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-g++") + +set(CMAKE_FIND_ROOT_PATH ${RISCV_TOOLCHAIN_ROOT}) +set(CMAKE_SYSROOT "${RISCV_TOOLCHAIN_ROOT}/sysroot") +set(CMAKE_INCLUDE_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/include/") +set(CMAKE_LIBRARY_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/lib/") +set(CMAKE_PROGRAM_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/bin/") + +if(RISCV_QEMU_PATH) + message(STATUS "RISCV_QEMU_PATH=${RISCV_QEMU_PATH} is defined during compilation.") + set(CMAKE_CROSSCOMPILING_EMULATOR "${RISCV_QEMU_PATH};-L;${CMAKE_SYSROOT}") +endif() + +set(CMAKE_CROSSCOMPILING TRUE) + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) + diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 6e5cd7b57e403..186bb699ad209 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -328,6 +328,12 @@ def convert_arg_line_to_args(self, arg_line): help="[cross-compiling] Create Windows x86 makefiles. Requires --update and no existing cache " "CMake setup. Delete CMakeCache.txt if needed", ) + parser.add_argument( + "--rv64", + action="store_true", + help="[cross-compiling] Create riscv64 makefiles. Requires --update and no existing cache " + "CMake setup. Delete CMakeCache.txt if needed", + ) parser.add_argument( "--arm", action="store_true", @@ -351,6 +357,18 @@ def convert_arg_line_to_args(self, arg_line): action="store_true", help="[cross-compiling] Create ARM64X Binary.", ) + parser.add_argument( + "--riscv_toolchain_root", + type=str, + default="", + help="Path to RISC-V toolchain root dir. e.g. --riscv_toolchain_root=$HOME/riscv-tools/", + ) + parser.add_argument( + "--riscv_qemu_path", + type=str, + default="", + help="Path to RISC-V qemu. e.g. --riscv_qemu_path=$HOME/qemu-dir/qemu-riscv64", + ) parser.add_argument("--msvc_toolset", help="MSVC toolset to use. e.g. 14.11") parser.add_argument("--windows_sdk_version", help="Windows SDK version to use. e.g. 10.0.19041.0") parser.add_argument("--android", action="store_true", help="Build for Android") @@ -1077,6 +1095,19 @@ def generate_build_tree( "-Donnxruntime_DISABLE_OPTIONAL_TYPE=" + ("ON" if disable_optional_type else "OFF"), ] + if args.rv64: + add_default_definition(cmake_extra_defines, "onnxruntime_CROSS_COMPILING", "ON") + if not args.riscv_toolchain_root: + raise BuildError("The --riscv_toolchain_root option is required to build for riscv64.") + if not args.skip_tests and not args.riscv_qemu_path: + raise BuildError("The --riscv_qemu_path option is required for testing riscv64.") + + cmake_args += [ + "-DRISCV_TOOLCHAIN_ROOT:PATH=" + args.riscv_toolchain_root, + "-DRISCV_QEMU_PATH:PATH=" + args.riscv_qemu_path, + "-DCMAKE_TOOLCHAIN_FILE=" + os.path.join(source_dir, "cmake", "riscv64.toolchain.cmake"), + ] + # By default on Windows we currently support only cross compiling for ARM/ARM64 # (no native compilation supported through this script). if args.arm64 or args.arm64ec or args.arm: @@ -1553,7 +1584,9 @@ def generate_build_tree( ] if is_linux() and platform.machine() == "x86_64": # The following flags needs GCC 8 and newer - cflags += ["-fstack-clash-protection", "-fcf-protection"] + cflags += ["-fstack-clash-protection"] + if not args.rv64: + cflags += ["-fcf-protection"] cxxflags = cflags.copy() if args.use_cuda: cudaflags = cflags.copy() diff --git a/tools/scripts/build_riscv64.sh b/tools/scripts/build_riscv64.sh new file mode 100755 index 0000000000000..65681c0b6307d --- /dev/null +++ b/tools/scripts/build_riscv64.sh @@ -0,0 +1,129 @@ +#!/bin/bash +# Copyright (c) 2024 SiFive, Inc. All rights reserved. +# Copyright (c) 2024, Phoebe Chen +# Licensed under the MIT License. + + +# The script is a sample for RISC-V 64-bit cross compilation in +# GNU/Linux, and you should ensure that your environment meets +# ORT requirements. You may need to make changes before using it. + +set -e +set -o pipefail + +# Get directory this script is in +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +OS=$(uname -s) + +if [ "$OS" == "Linux" ]; then + LINUX_DISTRO=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') + if [[ "${LINUX_DISTRO}" == "ubuntu" ]] ;then + DIR_OS="Linux" + else + echo "${LINUX_DISTRO} is not supported" + return 1 + fi +else + echo "$OS is not supported" + return 1 +fi + +function cleanup { + if [ -d "$WORK_DIR" ]; then + rm -rf "$WORK_DIR" + fi +} + +# The riscv toolchain, qemu and other platform related settings. +ORT_ROOT_DIR=$DIR/../.. + +PREBUILT_DIR="${ORT_ROOT_DIR}/riscv_tools" + +read -rp "Enter the riscv tools root path(press enter to use default path:${PREBUILT_DIR}): " INPUT_PATH +if [[ "${INPUT_PATH}" ]]; then + PREBUILT_DIR=${INPUT_PATH} +fi +echo "The riscv tool prefix path: ${PREBUILT_DIR}" + +WORK_DIR=$DIR/.prebuilt + +# The prebuit toolchain download from riscv-collab works with Ubuntu. +RISCV_GNU_TOOLCHAIN_URL="https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download" +TOOLCHAIN_VERSION="2023.11.20" +RISCV_TOOLCHAIN_FILE_NAME="riscv64-glibc-ubuntu-22.04-llvm-nightly-2023.11.20-nightly.tar.gz" +RISCV_TOOLCHAIN_FILE_SHA="98d6531b757fac01e065460c19abe8974976c607a8d88631cc5c1529d90ba7ba" + +TOOLCHAIN_PATH_PREFIX=${PREBUILT_DIR} + +execute () { + if ! eval "$1"; then + echo "command:\"$1\" error" + exit 1 + fi +} + +execute "mkdir -p $WORK_DIR" + +# Call the cleanup function when this tool exits. +trap cleanup EXIT + +# Download and install the toolchain from +# https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download +download_file() { + local file_name="$1" + local install_path="$2" + local file_sha="$3" + + echo "Install $1 to $2" + if [[ "$(ls -A "$2")" ]]; then + read -rp "The file already exists. Keep it (y/n)? " replaced + case ${replaced:0:1} in + y|Y ) + echo "Skip download $1." + return + ;; + * ) + rm -rf "$2" + ;; + esac + fi + + echo "Download ${file_name} ..." + mkdir -p "$install_path" + wget --progress=bar:force:noscroll --directory-prefix="${WORK_DIR}" \ + "${RISCV_GNU_TOOLCHAIN_URL}/${TOOLCHAIN_VERSION}/${file_name}" && \ + echo "${file_sha} ${WORK_DIR}/${file_name}" | sha256sum -c - + echo "Extract ${file_name} ..." + tar -C "${install_path}" -xf "${WORK_DIR}/${file_name}" --no-same-owner \ + --strip-components=1 +} + + +read -rp "Install RISCV toolchain(y/n)? " answer +case ${answer:0:1} in + y|Y ) + download_file "${RISCV_TOOLCHAIN_FILE_NAME}" \ + "${TOOLCHAIN_PATH_PREFIX}" \ + "${RISCV_TOOLCHAIN_FILE_SHA}" + ;; + * ) + echo "Skip install RISCV toolchain." + ;; +esac +echo "download finished." + + +# RISC-V cross compilation in GNU/Linux +RISCV_TOOLCHAIN_ROOT=${TOOLCHAIN_PATH_PREFIX} +RISCV_QEMU_PATH=${TOOLCHAIN_PATH_PREFIX}/bin/qemu-riscv64 +python3 "${ORT_ROOT_DIR}"/tools/ci_build/build.py \ + --build_dir "${ORT_ROOT_DIR}/build/${DIR_OS}" \ + --rv64 \ + --parallel \ + --skip_tests \ + --config RelWithDebInfo \ + --cmake_generator=Ninja \ + --riscv_qemu_path="${RISCV_QEMU_PATH}" \ + --riscv_toolchain_root="${RISCV_TOOLCHAIN_ROOT}" "$@" + + From 7dd1f4b8e27f38b55f2430f84ddaae1128bef9f4 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 24 Jan 2024 18:12:04 -0800 Subject: [PATCH 33/45] Pad-18 Cuda implementation (#19211) ### Description Implement Pad-18 for Cuda. ### Motivation and Context Latest models converted by Dynamo fall back on CPU for Pad with performance degradation. This contributes to https://github.com/microsoft/onnx-rewriter/issues/126 --- docs/OperatorKernels.md | 3 +- .../core/providers/cpu/cpu_provider_shared.cc | 8 +- .../core/providers/cpu/cpu_provider_shared.h | 8 +- onnxruntime/core/providers/cpu/tensor/pad.cc | 252 +++++++++--------- .../core/providers/cpu/tensor/padbase.h | 77 +++++- .../providers/cuda/cuda_execution_provider.cc | 38 +-- onnxruntime/core/providers/cuda/tensor/pad.cc | 37 ++- .../providers/rocm/rocm_execution_provider.cc | 26 +- .../provider_bridge_provider.cc | 9 +- 9 files changed, 287 insertions(+), 171 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 31cca232fde34..9d9b266355335 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -682,7 +682,8 @@ Do not modify directly.* |PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(double), tensor(float), tensor(float16)| |||[9, 15]|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +|||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)| |ParametricSoftplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index 9c55d37f550f4..bf73c59fb78ca 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -87,7 +87,13 @@ struct ProviderHostCPUImpl : ProviderHostCPU { const TensorShape& indice_shape, const TensorShape& update_shape) override { return ScatterND::ValidateShapes(input_shape, indice_shape, update_shape); } // From cpu/tensor/padbase.h (direct) - Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); } + Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); } + + void PadBase__ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads) override { + PadBase::ComputePads(ctx, data_rank, pads_data, pads); + } + // From cpu/tensor/split.h (direct) Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 8dee1cd620282..f33eec4b93e98 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -25,6 +25,8 @@ class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Pr class contrib__AdamWOptimizerBase__Prepare; class contrib__SGDOptimizerV2Base__Prepare; +using PadsVector = InlinedVector; + struct ProviderHostCPU { // From cpu/tensor/gatherbase.h virtual Status GatherBase__PrepareForCompute(const GatherBase* p, OpKernelContext* context, GatherBase__Prepare& prepare) = 0; @@ -44,7 +46,11 @@ struct ProviderHostCPU { const TensorShape& indice_shape, const TensorShape& update_shape) = 0; // From cpu/tensor/padbase.h - virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) = 0; + virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) = 0; + + virtual void PadBase__ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads) = 0; + // From cpu/tensor/split.h virtual Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, diff --git a/onnxruntime/core/providers/cpu/tensor/pad.cc b/onnxruntime/core/providers/cpu/tensor/pad.cc index fe5267f20712b..912280687e229 100644 --- a/onnxruntime/core/providers/cpu/tensor/pad.cc +++ b/onnxruntime/core/providers/cpu/tensor/pad.cc @@ -9,6 +9,8 @@ #include "core/providers/op_kernel_type_control.h" #include "core/util/math.h" +#include + // there's no way to use a raw pointer as the copy destination with std::copy_n // (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset // without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning. @@ -167,47 +169,7 @@ ONNX_CPU_OPERATOR_KERNEL( using PadsVector = PadBase::PadsVector; -// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values) -template -static void PadAxis(T* output, T* input, ptrdiff_t input_delta, ptrdiff_t input_pitch, - size_t block_size, size_t block_count) { - for (size_t block_index = 0; block_index < block_count; block_index++) { - for (size_t i = 0; i < block_size; i++) { - *output++ = *input; - input += input_delta; - } - input += input_pitch; - } -} - -// These are optimizations of PadAxis. The inner loop is removed since the innermost axis has a blockSize of 1, -// and inputPitch and inputDelta are just a single value added each iteration. -template -static void PadInnermostAxis(T* output, T* input, ptrdiff_t input_delta, size_t block_count) { - for (size_t block_index = 0; block_index < block_count; block_index++) { - *output++ = *input; - input += input_delta; - } -} - -// For constant padding, there is no input, just a size to write the constant to -template -static void PadAxisConstant(T* output, T constant, size_t size) { - if (size == 1) { - *output = constant; - } else if (size == 2) { - *output = constant; - *(output + 1) = constant; - } else { - // This would be faster with SSE instructions. - // That would mean to have an implementation for each type (uint8, uint32, uint64). - T* end = output + size; - for (; output != end;) - *output++ = constant; - } -} - -Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) { +Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) { switch (mode) { case Mode::Constant: { // default behavior is fine @@ -242,34 +204,66 @@ Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_sh return Status::OK(); } -// special handling for edge case where the input has one or more dims with value of 0 -template -static Status PadInputWithDimValueOfZero(OpKernelContext* ctx, - const Mode& mode, - const TensorShape& input_shape, - TensorShapeVector& output_dims, - T value) { - TensorShape output_shape(output_dims); - ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode, input_shape, output_shape)); - - auto& output_tensor = *ctx->Output(0, output_shape); - - // we need to add pads if mode is constant, otherwise the output has one or more dim values of 0 so is empty - if (mode == Mode::Constant) { - // we add pads with the default value to all dims including those with a value of 0 - auto* output = reinterpret_cast(output_tensor.MutableDataRaw()); - std::fill_n(output, output_shape.Size(), value); +static void ComputePadWithAxes( + gsl::span pads_tensor_raw_data, + std::function get_axis, + size_t axes_size, + size_t data_rank, + PadsVector& pads) { + for (size_t i = 0; i < axes_size; ++i) { + const size_t axis = onnxruntime::narrow(HandleNegativeAxis(get_axis(i), data_rank)); + pads[axis] = pads_tensor_raw_data[i]; // xi_begin + pads[data_rank + axis] = pads_tensor_raw_data[axes_size + i]; // xi_end } +} - return Status::OK(); +void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads) { + pads.reserve(2 * data_rank); + const Tensor* axes_tensor = ctx.Input(3); + if (axes_tensor) { + const size_t num_axes_dims = axes_tensor->Shape().NumDimensions(); + ORT_ENFORCE(num_axes_dims == 1, "Axes tensor should be a 1D tensor "); + + const int64_t num_axes = axes_tensor->Shape().Size(); + ORT_ENFORCE(pads_data.size() == narrow(2 * num_axes), + "Pads tensor size should be equal to twice the number of explicitly provided axes."); + + pads.resize(2 * data_rank, 0); + if (axes_tensor->IsDataType()) { + auto axes_data = axes_tensor->DataAsSpan(); + ComputePadWithAxes( + pads_data, + [axes_data](size_t idx) -> int64_t { + return axes_data[idx]; + }, + axes_data.size(), + data_rank, + pads); + } else if (axes_tensor->IsDataType()) { + auto axes_data = axes_tensor->DataAsSpan(); + ComputePadWithAxes( + pads_data, + [axes_data](size_t idx) { + return axes_data[idx]; + }, + axes_data.size(), + data_rank, + pads); + } + } else { + ORT_ENFORCE(pads_data.size() == 2 * data_rank, + "Pads tensor size should be equal to twice the input dimension count "); + pads.assign(pads_data.begin(), pads_data.end()); + } } // Flatten no padding inner most Axis, so one memcpy cover multiple Axis. // For example, for a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0], can be flatten as // [1,224,224*3] with padding [0,3,3*3,0,3,3*3]. -static void FlattenInnerShape(const TensorShapeVector& input_dims, const PadsVector& pads, - const PadsVector& slices, TensorShapeVector& reshaped_dims) { - size_t dims_count = input_dims.size(); +void PadBase::FlattenInnerShape(gsl::span input_dims, gsl::span pads, + gsl::span slices, TensorShapeVector& reshaped_dims) { + const size_t dims_count = input_dims.size(); size_t inner_axis = dims_count - 1; size_t inner_size = 1; @@ -288,14 +282,14 @@ static void FlattenInnerShape(const TensorShapeVector& input_dims, const PadsVec } while (inner_axis-- > 0); reshaped_dims.reserve(inner_axis + 1); - std::copy(input_dims.cbegin(), input_dims.cbegin() + inner_axis + 1, std::back_inserter(reshaped_dims)); + std::copy(input_dims.begin(), input_dims.begin() + inner_axis + 1, std::back_inserter(reshaped_dims)); // Flatten inner axis. reshaped_dims[inner_axis] = inner_size; } -static void ReshapePads(const PadsVector& src_pad, size_t src_dim_count, size_t new_dim_count, - size_t inner_no_pad_size, PadsVector& reshaped_pad) { +void PadBase::ReshapePads(gsl::span src_pad, size_t src_dim_count, size_t new_dim_count, + size_t inner_no_pad_size, PadsVector& reshaped_pad) { size_t inner_axis = new_dim_count - 1; std::copy(src_pad.begin(), src_pad.begin() + inner_axis, reshaped_pad.begin()); std::copy(src_pad.begin() + src_dim_count, src_pad.begin() + src_dim_count + inner_axis, @@ -306,6 +300,68 @@ static void ReshapePads(const PadsVector& src_pad, size_t src_dim_count, size_t reshaped_pad[inner_axis + new_dim_count] = src_pad[inner_axis + src_dim_count] * inner_no_pad_size; } +// special handling for edge case where the input has one or more dims with value of 0 +template +static Status PadInputWithDimValueOfZero(OpKernelContext* ctx, + const Mode& mode, + const TensorShape& input_shape, + TensorShapeVector& output_dims, + T value) { + TensorShape output_shape(output_dims); + ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode, input_shape, output_shape)); + + auto& output_tensor = *ctx->Output(0, output_shape); + + // we need to add pads if mode is constant, otherwise the output has one or more dim values of 0 so is empty + if (mode == Mode::Constant) { + // we add pads with the default value to all dims including those with a value of 0 + auto* output = reinterpret_cast(output_tensor.MutableDataRaw()); + std::fill_n(output, output_shape.Size(), value); + } + + return Status::OK(); +} + +// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values) +template +static void PadAxis(T* output, T* input, ptrdiff_t input_delta, ptrdiff_t input_pitch, + size_t block_size, size_t block_count) { + for (size_t block_index = 0; block_index < block_count; block_index++) { + for (size_t i = 0; i < block_size; i++) { + *output++ = *input; + input += input_delta; + } + input += input_pitch; + } +} + +// These are optimizations of PadAxis. The inner loop is removed since the innermost axis has a blockSize of 1, +// and inputPitch and inputDelta are just a single value added each iteration. +template +static void PadInnermostAxis(T* output, T* input, ptrdiff_t input_delta, size_t block_count) { + for (size_t block_index = 0; block_index < block_count; block_index++) { + *output++ = *input; + input += input_delta; + } +} + +// For constant padding, there is no input, just a size to write the constant to +template +static void PadAxisConstant(T* output, T constant, size_t size) { + if (size == 1) { + *output = constant; + } else if (size == 2) { + *output = constant; + *(output + 1) = constant; + } else { + // This would be faster with SSE instructions. + // That would mean to have an implementation for each type (uint8, uint32, uint64). + T* end = output + size; + for (; output != end;) + *output++ = constant; + } +} + template static Status PadImpl(OpKernelContext* ctx, const PadsVector& pads, @@ -327,7 +383,7 @@ static Status PadImpl(OpKernelContext* ctx, // Reshape input dims TensorShapeVector reshaped_input_dims; - FlattenInnerShape(output_dims, pads, slices, reshaped_input_dims); + PadBase::FlattenInnerShape(output_dims, pads, slices, reshaped_input_dims); // Reshape padding size_t new_dims_count = reshaped_input_dims.size(); @@ -336,8 +392,8 @@ static Status PadImpl(OpKernelContext* ctx, ? reshaped_input_dims[inner_axis] / output_dims[inner_axis] : 0); PadsVector reshaped_pad(2 * new_dims_count), reshaped_slice(2 * new_dims_count); - ReshapePads(pads, data_rank, new_dims_count, inner_no_pad_size, reshaped_pad); - ReshapePads(slices, data_rank, new_dims_count, inner_no_pad_size, reshaped_slice); + PadBase::ReshapePads(pads, data_rank, new_dims_count, inner_no_pad_size, reshaped_pad); + PadBase::ReshapePads(slices, data_rank, new_dims_count, inner_no_pad_size, reshaped_slice); TensorShapeVector reshaped_output_dims = reshaped_input_dims; TensorShapeVector input_starts; @@ -575,20 +631,6 @@ static PadValue PadValueFromFloat(float value, MLDataType data_type) { return result; } -template -void ComputePadWithAxes( - gsl::span pads_tensor_raw_data, - gsl::span axes_tensor_raw_data, - size_t data_rank, - PadsVector& pads) { - size_t axes_size = axes_tensor_raw_data.size(); - for (size_t i = 0; i < axes_size; ++i) { - int64_t axis = HandleNegativeAxis(onnxruntime::narrow(axes_tensor_raw_data[i]), data_rank); - pads[onnxruntime::narrow(axis)] = pads_tensor_raw_data[i]; // xi_begin - pads[data_rank + onnxruntime::narrow(axis)] = pads_tensor_raw_data[axes_size + i]; // xi_end - } -} - Status Pad::Compute(OpKernelContext* ctx) const { const Tensor& input_tensor = *ctx->Input(0); MLDataType data_type = input_tensor.DataType(); @@ -608,48 +650,14 @@ Status Pad::Compute(OpKernelContext* ctx) const { ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1), "Pads tensor should be a 1D tensor of shape [2 * num_axes] " "or a 2D tensor of shape [1, 2 * num_axes]"); - const int64_t* pads_tensor_raw_data = pads_tensor.Data(); - size_t pads_size = static_cast(pads_tensor.Shape().Size()); - pads.reserve(2 * data_rank); - - const Tensor* axes_tensor = ctx->Input(3); - if (axes_tensor) { - const auto& axes_tensor_dims = axes_tensor->Shape().GetDims(); - ORT_ENFORCE(axes_tensor_dims.size() == 1, "Axes tensor should be a 1D tensor "); - int64_t axes_size = axes_tensor_dims[0]; - - pads.resize(2 * data_rank, 0); - if (axes_tensor->IsDataType()) { - const int32_t* axes_tensor_raw_data = axes_tensor->Data(); - ComputePadWithAxes( - {pads_tensor_raw_data, onnxruntime::narrow(2 * axes_size)}, - {axes_tensor_raw_data, onnxruntime::narrow(axes_size)}, - data_rank, - pads); - } else if (axes_tensor->IsDataType()) { - const int64_t* axes_tensor_raw_data = axes_tensor->Data(); - ComputePadWithAxes( - {pads_tensor_raw_data, onnxruntime::narrow(2 * axes_size)}, - {axes_tensor_raw_data, onnxruntime::narrow(axes_size)}, - data_rank, - pads); - } - } else { - ORT_ENFORCE(pads_size == 2 * data_rank, - "Pads tensor size should be equal to twice the input dimension count "); - for (size_t i = 0; i < pads_size; ++i) { - pads.push_back(pads_tensor_raw_data[i]); - } - } + + const auto pads_data = pads_tensor.DataAsSpan(); + + // Compute Pads by applying axes if specified otherwise copy the supplied pads. + PadBase::ComputePads(*ctx, data_rank, pads_data, pads); // Separate out any negative pads into the slices array - slices.assign(pads.size(), 0); - for (size_t index = 0; index < pads.size(); index++) { - if (pads[index] < 0) { - slices[index] = pads[index]; - pads[index] = 0; - } - } + PadBase::SeparateNegativeToSlices(pads, slices); value.u64 = 0U; const Tensor* value_tensor = ctx->Input(2); diff --git a/onnxruntime/core/providers/cpu/tensor/padbase.h b/onnxruntime/core/providers/cpu/tensor/padbase.h index d869ed1a6dda2..43f9cbfc9f9a4 100644 --- a/onnxruntime/core/providers/cpu/tensor/padbase.h +++ b/onnxruntime/core/providers/cpu/tensor/padbase.h @@ -19,9 +19,80 @@ class PadBase { // Pads and slices are usually about twice the shapes involved using PadsVector = InlinedVector; - // Update the output_shape to make it consistent with numpy handling where there are one or more dimensions - // in the input_shape with a value of zero. - static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape); + // The following several functions are shared among the providers + + /// + /// Handle the case when the input shape has zero dim values. + /// Depending on the mode, the input dim with zero value must match the output dim value. + /// + /// + /// Padding mode enum value + /// actual input shape + /// output_shape + /// Error if current mode padding can not be achieved with zero dim values + static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape); + + /// + /// Compute Pads by applying axes if specified otherwise copy the supplied pads. + /// + /// The function queries optional axes input (since version 18) and if present, + /// applies it as a mask to the pads. If axes is not present, the pads are copied as is. + /// If axes are present, they are used as a mask over pads, so only those axes are being padded. + /// + /// kernel context to query axes input + /// input rank + /// pads data from pads input + /// resulting pads + static void ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads); + + /// + /// Separates negative pad values to slices and zeros them out in original pads. + /// Leaving the rest of slices values as zero. + /// + /// This function is used inline in the Pad CUDA implementation and is not exposed via a provider + /// interfaces. + /// + /// pad values + /// slices output + static void SeparateNegativeToSlices(gsl::span pads, PadsVector& slices) { + slices.assign(pads.size(), 0); + for (size_t index = 0, lim = pads.size(); index < lim; index++) { + if (pads[index] < 0) { + slices[index] = pads[index]; + pads[index] = 0; + } + } + } + + // End provider shared + + /// + /// Flatten no padding inner most Axis, so one memcpy cover multiple Axis. + /// For example, for a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0], can be flatten as + /// [1,224,224*3] with padding [0,3,3*3,0,3,3*3]. + /// + /// This is a helper function pads are expected to be twice the rank + /// + /// original input dims + /// pad values + /// slices + /// result dims + static void FlattenInnerShape(gsl::span input_dims, gsl::span pads, + gsl::span slices, TensorShapeVector& reshaped_dims); + + /// + /// Used after the inner shape is flattened, so we can apply this function to pads and slices + /// to reshape them as well. + /// + /// pads + /// original dim count + /// expected flattended dim count + /// is the left most dimension that was flattened. + /// In the example above, that would be 224, reverse computed from 224*3 + /// resulting reshaped pads or slices + static void ReshapePads(gsl::span src_pad, size_t src_dim_count, size_t new_dim_count, + size_t inner_no_pad_size, PadsVector& reshaped_pad); protected: PadBase(const OpKernelInfo& info) : value_(info.GetAttrOrDefault("value", 0.f)) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 644bcaaa24cd4..3fc4ed355a12b 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1121,10 +1121,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign); @@ -1269,6 +1269,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad); // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast); @@ -2008,10 +2012,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2091,13 +2095,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2150,11 +2147,22 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { // Opset 18 BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc index 4584e5fd8272c..bdd6567d2ef34 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad.cc +++ b/onnxruntime/core/providers/cuda/tensor/pad.cc @@ -29,15 +29,27 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 2) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ Pad); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ + 13, 17, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Pad, \ kOnnxDomain, \ - 13, \ + 18, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .InputMemoryType(OrtMemTypeCPUInput, 1) \ .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ Pad); @@ -94,28 +106,15 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { if (is_dynamic_) { const Tensor& pads_tensor = *ctx->Input(1); const auto pads_tensor_dims = pads_tensor.Shape().GetDims(); - ORT_ENFORCE(utils::IsPrimitiveDataType(pads_tensor.DataType()), - "Pads tensor should be an INT64 tensor"); ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1), - "Pads tensor should be a 1D tensor of shape [2 * input_rank] or a 2D tensor of shape [1, 2 * input_rank]"); + "Pads tensor should be a 1D tensor of shape [2 * num_axes] or a 2D tensor of shape [1, 2 * num_axes]"); - const int64_t* pads_tensor_raw_data = pads_tensor.Data(); - size_t pads_size = static_cast(pads_tensor.Shape().Size()); - ORT_ENFORCE(pads_size == 2 * static_cast(dimension_count), - "Pads tensor size should be equal to twice the input dimension count "); + const auto pads_data = pads_tensor.DataAsSpan(); + + PadBase::ComputePads(*ctx, input_shape.NumDimensions(), pads_data, pads); - pads.reserve(2LL * dimension_count); - for (size_t i = 0; i < pads_size; ++i) { - pads.push_back(pads_tensor_raw_data[i]); - } // Separate out any negative pads into the slices array - slices.resize(pads.size(), 0); - for (size_t index = 0; index < pads.size(); index++) { - if (pads[index] < 0) { - slices[index] = pads[index]; - pads[index] = 0; - } - } + PadBase::SeparateNegativeToSlices(pads, slices); T raw_value{}; const Tensor* value_tensor = ctx->Input(2); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index d7bec337a6be4..fff3d14b763d5 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1158,10 +1158,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Identity); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign); @@ -1298,6 +1298,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization); // Opset 18 +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); // Opset 19 @@ -2088,10 +2093,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2228,6 +2233,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 18 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index a3155fe6b86cf..e1d0e310425c5 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -547,7 +547,14 @@ Status ScatterND::ValidateShapes(const TensorShape& input_shape, const TensorShape& indice_shape, const TensorShape& update_shape) { return g_host_cpu.ScatterNDBase__ValidateShapes(input_shape, indice_shape, update_shape); } -Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) { return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape); } +Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) { + return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape); +} + +void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads) { + g_host_cpu.PadBase__ComputePads(ctx, data_rank, pads_data, pads); +} Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, const ConcatBase::InlinedTensorsVector& input_tensors, Prepare& p) const { From 2b87dd373a3567c2c426e2f090b201b8b051a346 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 25 Jan 2024 10:16:41 +0800 Subject: [PATCH 34/45] [ORTModule] Remove Mod from Hash to Avoid Conflict for Triton Code-gen (#19256) Remove mod (10**8) from hash to avoid conflict for Triton code-gen. --- .../python/training/ort_triton/kernel/_mm.py | 20 +++++++++---------- .../training/ort_triton/triton_op_executor.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py b/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py index ed92923589d48..a3681a13699a0 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py @@ -11,7 +11,7 @@ import torch from .._cache import ModuleCache, PyCodeCache -from .._utils import next_power_of_2 +from .._utils import gen_unique_name, next_power_of_2 _DEBUG_MODE = "ORTMODULE_TRITON_DEBUG" in os.environ and int(os.getenv("ORTMODULE_TRITON_DEBUG")) == 1 @@ -305,18 +305,18 @@ def _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name): def _gen_mm_key(dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float) -> int: - return hash(f"mm|{dtype}|{m}|{n}|{k}|{trans_a}|{trans_b}|{alpha}") % (10**8) + return hash(f"mm|{dtype}|{m}|{n}|{k}|{trans_a}|{trans_b}|{alpha}") def _gen_mm_module( dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float ) -> Tuple[str, ModuleType]: - func_name = f"mm_{_gen_mm_key(dtype, m, n, k, trans_a, trans_b, alpha)}" + func_name = gen_unique_name("mm") kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) src_code = _MM_TEMPLATE.format(**kwargs) if _DEBUG_MODE: os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) - with open(f"triton_debug/{func_name}.py", "w") as f: + with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f: f.write(src_code) return func_name, PyCodeCache().load(src_code) @@ -333,7 +333,7 @@ def _gen_gemm_key( alpha: float, beta: float, ) -> int: - return hash(f"gemm|{dtype}|{m}|{n}|{k}|{stride_cm}|{stride_cn}|{trans_a}|{trans_b}|{alpha}|{beta}") % (10**8) + return hash(f"gemm|{dtype}|{m}|{n}|{k}|{stride_cm}|{stride_cn}|{trans_a}|{trans_b}|{alpha}|{beta}") def _gen_gemm_module( @@ -348,7 +348,7 @@ def _gen_gemm_module( alpha: float, beta: float, ) -> Tuple[str, ModuleType]: - func_name = f"gemm_{_gen_gemm_key(dtype, m, n, k, stride_cm, stride_cn, trans_a, trans_b, alpha, beta)}" + func_name = gen_unique_name("gemm") kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) kwargs["stride_cm"] = stride_cm kwargs["stride_cn"] = stride_cn @@ -356,7 +356,7 @@ def _gen_gemm_module( src_code = _GEMM_TEMPLATE.format(**kwargs) if _DEBUG_MODE: os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) - with open(f"triton_debug/{func_name}.py", "w") as f: + with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f: f.write(src_code) return func_name, PyCodeCache().load(src_code) @@ -364,13 +364,13 @@ def _gen_gemm_module( def _gen_bmm_key( dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float ) -> int: - return hash(f"bmm|{dtype}|{m}|{n}|{k}|{batch_a}|{batch_b}|{trans_a}|{trans_b}|{alpha}") % (10**8) + return hash(f"bmm|{dtype}|{m}|{n}|{k}|{batch_a}|{batch_b}|{trans_a}|{trans_b}|{alpha}") def _gen_bmm_module( dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float ) -> Tuple[str, ModuleType]: - func_name = f"bmm_{_gen_bmm_key(dtype, m, n, k, batch_a, batch_b, trans_a, trans_b, alpha)}" + func_name = gen_unique_name("bmm") kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) batch = batch_a if batch_a >= batch_b else batch_b kwargs["stride_aq"] = m * k if batch_a == batch else 0 @@ -379,7 +379,7 @@ def _gen_bmm_module( src_code = _BMM_TEMPLATE.format(**kwargs) if _DEBUG_MODE: os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) - with open(f"triton_debug/{func_name}.py", "w") as f: + with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f: f.write(src_code) return func_name, PyCodeCache().load(src_code) diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py index 1fe61750e651e..f16abc71251ed 100644 --- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py +++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py @@ -67,7 +67,7 @@ def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[in def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int: # pylint: disable=unused-argument - return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") % (10**8) + return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]: From 1c92e56dc0f906a43128e2f0c4c6729349aac92b Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 25 Jan 2024 22:28:47 +0800 Subject: [PATCH 35/45] [Cuda] Refactor GroupNorm (#19146) Split GroupNorm implementation into multiple files, to make ROCm EP can reuse cuda code. Related PR: https://github.com/microsoft/onnxruntime/pull/19158 --------- Co-authored-by: Peixuan Zuo --- cmake/onnxruntime_rocm_hipify.cmake | 3 + .../cuda/diffusion/group_norm_common_base.cc | 101 ++++ .../cuda/diffusion/group_norm_common_base.h | 186 ++++++ .../cuda/diffusion/group_norm_impl.cu | 529 +----------------- .../cuda/diffusion/group_norm_impl_kernel.cuh | 355 ++++++++++++ 5 files changed, 653 insertions(+), 521 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h create mode 100644 onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index f70961a66329a..d485abe6bb1a6 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -47,6 +47,9 @@ set(contrib_ops_excluded_files "diffusion/group_norm.cc" "diffusion/group_norm_impl.cu" "diffusion/group_norm_impl.h" + "diffusion/group_norm_impl_kernel.cuh" + "diffusion/group_norm_common_base.h" + "diffusion/group_norm_common_base.cc" "diffusion/nhwc_conv.cc" "math/gemm_float8.cc" "math/gemm_float8.cu" diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc new file mode 100644 index 0000000000000..5dec690528847 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc @@ -0,0 +1,101 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/diffusion/group_norm_common_base.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +int NextSize(int x) { + for (size_t i = 0; i < kNumOfSizes; ++i) { + if (x <= kSizes[i]) { + return kSizes[i]; + } + } + + return x; +} + +int32_t GetThreadsPerBlock(int32_t channels_per_block, int32_t channels_per_thread) { + return NextSize(channels_per_block) / channels_per_thread; +} + +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { + int32_t max_divisor = -1; + for (int32_t i = 1; i <= std::sqrt(n); i++) { + if (n % i == 0) { + int32_t divisor1 = n / i; + int32_t divisor2 = i; + + if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { + max_divisor = divisor1; + } + if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { + max_divisor = divisor2; + } + } + } + return max_divisor; +} + +// Find proper channels per block based on a cost function: The cost is number of channels corresponding to +// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has +// work to do so it is ideal case. +int FindChannelsPerBlock(int num_channels, int channels_per_group) { + int min_cost = -1; + int best_candidate = -1; + for (size_t i = kNumOfSizes; i > 0; --i) { + if (kSizes[i - 1] < channels_per_group) { + break; + } + + int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; + int blocks = (num_channels + channels_per_block - 1) / channels_per_block; + int cost = blocks * kSizes[i - 1] - num_channels; + if (cost == 0) { + return channels_per_block; + } + + if (min_cost == -1 || cost < min_cost) { + min_cost = cost; + best_candidate = channels_per_block; + } + } + + return best_candidate; +} + +int GetChannelsPerBlock(int num_channels, int num_groups) { + int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_block = channels_per_group; + if (channels_per_group < kMaxSize / 2) { + channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); + } + return channels_per_block; +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h new file mode 100644 index 0000000000000..84f3403b8d5ae --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h @@ -0,0 +1,186 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include "core/providers/cuda/cuda_common.h" +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. +constexpr static int32_t CHANNELS_PER_THREAD = 2; + +constexpr static int kSizes[] = {128, 256, 320, 384, 512}; +constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; + +int32_t GetThreadsPerBlock(int32_t channels_per_block, int32_t channels_per_thread); + +static inline int32_t DivUp(int32_t m, int32_t n) { + return (m + n - 1) / n; +} + +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor); + +int GetChannelsPerBlock(int num_channels, int num_groups); + +template +struct GroupNormNHWCParams { + // The output buffer. Shape is (n, h, w, c). + T* dst; + + // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). + T* add_out; + + // The input buffer. Shape is (n, h, w, c). + T const* src; + + // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). + T const* skip; + + // Optional input buffer for bias tensor. Shape is (c). + T const* bias; + + // The gamma scaling factor. + float const* gamma; + + // The beta term to add in GN. + float const* beta; + + // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. + float* group_sum_buffer; + + // The number of instances in the batch. + int32_t n; + + // The height and width of each activation map. + int32_t h; + int32_t w; + + // Number of channels. + int32_t c; + + // Number of groups. + int32_t groups; + + // Do we apply the SiLU activation function? + bool use_silu; + + // Precomputed values and parameters to control the execution of the kernels. + + // Number of activations per instance (h * w) + int32_t hw; + + // Number of activations per block + int32_t hw_per_block; + + // Number of channels per block in the C dimension. + int32_t channels_per_block; + + // Number of channels per group in the C dimension. + int32_t channels_per_group; + + // The precomputed stride between instances. + int32_t hwc; + // The inverse of hw*channels_per_group to compute mean of a group. + float inv_hw_channels_per_group; + // The precomputed number of groups per block. + int32_t groups_per_block; + + // Number of threads per block + int32_t threads_per_block; + + // Epsilon to get stable variance in normalization. + float epsilon; + + // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. + bool broadcast_skip; + + // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. + T* skip_workspace; + + GroupNormNHWCParams(T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + int32_t channels_per_group = num_channels / num_groups; + // channels_per_block is computed in PrePack. + // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. + if (channels_per_block < channels_per_group) { + channels_per_block = GetChannelsPerBlock(num_channels, num_groups); + } + + this->use_silu = use_silu; + this->dst = output; + this->add_out = add_out; + this->src = input; + this->skip = skip; + this->bias = bias; + this->gamma = gamma; + this->beta = beta; + this->group_sum_buffer = reinterpret_cast(workspace); + this->n = batch_size; + this->h = height; + this->w = width; + this->c = num_channels; + this->groups = num_groups; + this->hw = this->h * this->w; + + // This will allocate as many blocks as possible to partition HW. + // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. + // TODO: tune this logic to find proper blocks when hw is small. + constexpr int32_t max_blocks_per_hw = 1024; + const int32_t blocks_per_hw = FindMaxDivisor(this->hw, max_blocks_per_hw); + this->hw_per_block = DivUp(this->hw, blocks_per_hw); + + this->channels_per_block = channels_per_block; + this->channels_per_group = channels_per_group; + this->hwc = this->hw * this->c; + this->inv_hw_channels_per_group = 1.F / (float)(this->hw * this->channels_per_group); + this->groups_per_block = channels_per_block / this->channels_per_group; + this->epsilon = epsilon; + this->broadcast_skip = broadcast_skip; + + // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. + this->skip_workspace = (this->add_out != nullptr) ? this->add_out : this->dst; + + this->threads_per_block = GetThreadsPerBlock(channels_per_block, CHANNELS_PER_THREAD); + } +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 48b161552ce0c..d7b2cc2379f4f 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -27,6 +27,8 @@ #include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cuda/diffusion/group_norm_common_base.h" +#include "contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh" using namespace onnxruntime::cuda; @@ -34,329 +36,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -namespace { - -// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. -constexpr static int32_t CHANNELS_PER_THREAD = 2; - -constexpr static int kSizes[] = {128, 256, 320, 384, 512}; -constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); -constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; - -int NextSize(int x) { - for (size_t i = 0; i < kNumOfSizes; ++i) { - if (x <= kSizes[i]) { - return kSizes[i]; - } - } - - return x; -} -} // namespace - -static inline int32_t DivUp(int32_t m, int32_t n) { - return (m + n - 1) / n; -} - -static inline __device__ __host__ float sigmoid(float x) { - return 1.F / (1.F + expf(-x)); -} - -struct GroupSums { - // Is it the 1st element of the group? - int32_t flag; - // The sum. - float sum; - // The sum of squares. - float sum_sq; -}; - -struct GroupSumsOp { - inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { - GroupSums dst; - dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); - dst.flag = a.flag + b.flag; - return dst; - } -}; - -template -struct GroupNormNHWCParams { - // The output buffer. Shape is (n, h, w, c). - T* dst; - - // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). - T* add_out; - - // The input buffer. Shape is (n, h, w, c). - T const* src; - - // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). - T const* skip; - - // Optional input buffer for bias tensor. Shape is (c). - T const* bias; - - // The gamma scaling factor. - float const* gamma; - - // The beta term to add in GN. - float const* beta; - - // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. - float* group_sum_buffer; - - // The number of instances in the batch. - int32_t n; - - // The height and width of each activation map. - int32_t h; - int32_t w; - - // Number of channels. - int32_t c; - - // Number of groups. - int32_t groups; - - // Do we apply the SiLU activation function? - bool use_silu; - - // Precomputed values and parameters to control the execution of the kernels. - - // Number of activations per instance (h * w) - int32_t hw; - - // Number of activations per block - int32_t hw_per_block; - - // Number of channels per block in the C dimension. - int32_t channels_per_block; - - // Number of channels per group in the C dimension. - int32_t channels_per_group; - - // The precomputed stride between instances. - int32_t hwc; - // The inverse of hw*channels_per_group to compute mean of a group. - float inv_hw_channels_per_group; - // The precomputed number of groups per block. - int32_t groups_per_block; - - // Number of threads per block - int32_t threads_per_block; - - // Epsilon to get stable variance in normalization. - float epsilon; - - // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. - bool broadcast_skip; - - // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. - T* skip_workspace; -}; - -template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); - -template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { - // Fetch two channels per thread. - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - - float2 f2 = __half22float2(h2); - - // Update the sum. - sum += f2.x + f2.y; - - // Update the sum of squares. - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { - // Fetch two channels per thread. - float2 f2 = *reinterpret_cast(&src[offset]); - - // Update the sum. - sum += f2.x + f2.y; - - // Update the sum of squares. - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] -template -inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); - -template <> -inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { - // Fetch two channels per thread. - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); - __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); - h2 = h2 + b; - h2 = h2 + s; - - *reinterpret_cast<__half2*>(&add_out[offset]) = h2; - - float2 f2 = __half22float2(h2); - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template <> -inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { - float2 f2 = *reinterpret_cast(&src[offset]); - float2 s = *reinterpret_cast(&skip[skip_offset]); - float2 b = *reinterpret_cast(&bias[bias_offset]); - f2.x += s.x + b.x; - f2.y += s.y + b.y; - - *reinterpret_cast(&add_out[offset]) = f2; - - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] -template -inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); - -template <> -inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); - h2 = h2 + s; - - *reinterpret_cast<__half2*>(&add_out[offset]) = h2; - - float2 f2 = __half22float2(h2); - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template <> -inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { - float2 f2 = *reinterpret_cast(&src[offset]); - float2 s = *reinterpret_cast(&skip[skip_offset]); - f2.x += s.x; - f2.y += s.y; - *reinterpret_cast(&add_out[offset]) = f2; - sum += f2.x + f2.y; - sum_sq += f2.x * f2.x + f2.y * f2.y; -} - -template -__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { - // The object in charge of doing the sums for the different blocks. - typedef cub::BlockScan BlockScan; - - // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage temp_storage; - - // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. - __shared__ float2 smem[THREADS_PER_BLOCK]; - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The channel loaded by that thread. - int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; - - if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { - return; - } - - // The first activation loaded by that block. - int32_t hw_begin = blockIdx.y * params.hw_per_block; - // The last activation loaded by that block. - int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - - // The sums. - float sum = 0.F; - float sum_sq = 0.F; - - // Iterate over the activations to compute the sums. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - if (params.skip != nullptr) { - // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) - const int64_t bias_offset = static_cast(ci); - T* add_out = params.skip_workspace; - if (params.broadcast_skip) { - const int64_t skip_offset = static_cast(ni) * params.c + ci; - - if (params.bias != nullptr) { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); - } - } else { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); - } - } - } else { - if (params.bias != nullptr) { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); - } - } else { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); - } - } - } - } else { // GroupNorm - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - UpdateSum(params.src, offset, sum, sum_sq); - } - } - - // The group index relative to the first group within the same block. - int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; - // The channel in the group. - int32_t cj = ci % params.channels_per_group; - - // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; - - // Do the segmented scan. InclusiveScan is not deterministic. - GroupSums out; - BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); - - // Store the results for the groups in shared memory (to produce coalesced stores later). - // For each group, only the last thread of that group is picked to save sum to shared memory. - if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { - smem[gi] = make_float2(out.sum, out.sum_sq); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groups_per_block) { - return; - } - - // The global group index. - // Use neighboring threads for coalesced write. - int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; - - if (gj < params.groups) { - float2 sums = smem[threadIdx.x]; - const int index = (2 * ni) * params.groups + gj; - atomicAdd(¶ms.group_sum_buffer[index], sums.x); - atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); - } -} - template void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; @@ -390,102 +69,6 @@ void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) } } -template -__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, - float2& gamma_f2, float2& beta_f2, bool silu); - -template <> -__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, - float2& gamma_f2, float2& beta_f2, bool silu) { - // Fetch two channels per thread. - __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); - - // Extract the two half values. - float2 f2 = __half22float2(h2); - - // Normalize the channels. - f2.x = (f2.x - mean) * inv_std_dev; - f2.y = (f2.y - mean) * inv_std_dev; - - // Scale by gamma and add beta. - f2.x = gamma_f2.x * f2.x + beta_f2.x; - f2.y = gamma_f2.y * f2.y + beta_f2.y; - - // Apply SiLU activation if needed. - if (silu) { - f2.x = f2.x * sigmoid(f2.x); - f2.y = f2.y * sigmoid(f2.y); - } - - *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); -} - -template <> -__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, - float2& gamma_f2, float2& beta_f2, bool silu) { - // Fetch two channels per thread. - float2 f2 = *reinterpret_cast(&src[offset]); - - // Normalize the channels. - f2.x = (f2.x - mean) * inv_std_dev; - f2.y = (f2.y - mean) * inv_std_dev; - - // Scale by gamma and add beta. - f2.x = gamma_f2.x * f2.x + beta_f2.x; - f2.y = gamma_f2.y * f2.y + beta_f2.y; - - // Apply SiLU activation if needed. - if (silu) { - f2.x = f2.x * sigmoid(f2.x); - f2.y = f2.y * sigmoid(f2.y); - } - - *reinterpret_cast(&dst[offset]) = f2; -} - -template -__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { - // The channel loaded by that thread. - int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; - if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { - return; - } - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The group that thread works on. - int32_t gi = ci / params.channels_per_group; - - // Load the sum and sum of squares for the group. - float sum = 0.F, sum_sq = 0.F; - if (gi < params.groups) { - const int index = (2 * ni) * params.groups + gi; - sum = params.group_sum_buffer[index]; - sum_sq = params.group_sum_buffer[index + params.groups]; - } - - // Load gamma/beta. Fetch two per thread. - float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); - - // Compute the mean. - float mean = sum * params.inv_hw_channels_per_group; - // Compute the variance. - float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); - // Compute the inverse of the stddev. - float inv_std_dev = rsqrtf(var + params.epsilon); - - int32_t hw_begin = blockIdx.y * params.hw_per_block; - int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - - const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; - int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); - } -} - template void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; @@ -517,60 +100,6 @@ void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t strea } } -int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { - int32_t max_divisor = -1; - for (int32_t i = 1; i <= std::sqrt(n); i++) { - if (n % i == 0) { - int32_t divisor1 = n / i; - int32_t divisor2 = i; - - if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { - max_divisor = divisor1; - } - if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { - max_divisor = divisor2; - } - } - } - return max_divisor; -} - -// Find proper channels per block based on a cost function: The cost is number of channels corresponding to -// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has -// work to do so it is ideal case. -int FindChannelsPerBlock(int num_channels, int channels_per_group) { - int min_cost = -1; - int best_candidate = -1; - for (size_t i = kNumOfSizes; i > 0; --i) { - if (kSizes[i - 1] < channels_per_group) { - break; - } - - int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; - int blocks = (num_channels + channels_per_block - 1) / channels_per_block; - int cost = blocks * kSizes[i - 1] - num_channels; - if (cost == 0) { - return channels_per_block; - } - - if (min_cost == -1 || cost < min_cost) { - min_cost = cost; - best_candidate = channels_per_block; - } - } - - return best_candidate; -} - -int GetChannelsPerBlock(int num_channels, int num_groups) { - int32_t channels_per_group = num_channels / num_groups; - int32_t channels_per_block = channels_per_group; - if (channels_per_group < kMaxSize / 2) { - channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); - } - return channels_per_block; -} - template Status LaunchGroupNormKernel( cudaStream_t stream, @@ -591,19 +120,13 @@ Status LaunchGroupNormKernel( bool use_silu, bool broadcast_skip, int channels_per_block) { - GroupNormNHWCParams params; - - int32_t channels_per_group = num_channels / num_groups; - // channels_per_block is computed in PrePack. - // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. - if (channels_per_block < channels_per_group) { - channels_per_block = GetChannelsPerBlock(num_channels, num_groups); - } + GroupNormNHWCParams params(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, + batch_size, num_channels, height, width, num_groups, use_silu, + broadcast_skip, channels_per_block); - // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases - if (channels_per_block % channels_per_group != 0 || - channels_per_block > kMaxSize || - (channels_per_group % CHANNELS_PER_THREAD != 0)) { + if (params.channels_per_block % params.channels_per_group != 0 || + params.channels_per_block > kMaxSize || + (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "GroupNorm in CUDA does not support the input: n=", batch_size, " h=", height, @@ -612,42 +135,6 @@ Status LaunchGroupNormKernel( " groups=", num_groups); } - params.use_silu = use_silu; - params.dst = output; - params.add_out = add_out; - params.src = input; - params.skip = skip; - params.bias = bias; - params.gamma = gamma; - params.beta = beta; - params.group_sum_buffer = reinterpret_cast(workspace); - params.n = batch_size; - params.h = height; - params.w = width; - params.c = num_channels; - params.groups = num_groups; - params.hw = params.h * params.w; - - // This will allocate as many blocks as possible to partition HW. - // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. - // TODO: tune this logic to find proper blocks when hw is small. - constexpr int32_t max_blocks_per_hw = 1024; - const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw); - params.hw_per_block = DivUp(params.hw, blocks_per_hw); - - params.channels_per_block = channels_per_block; - params.channels_per_group = channels_per_group; - params.hwc = params.hw * params.c; - params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group); - params.groups_per_block = channels_per_block / params.channels_per_group; - params.epsilon = epsilon; - params.broadcast_skip = broadcast_skip; - - // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. - params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst; - - params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD; - CUDA_RETURN_IF_ERROR(cudaMemsetAsync( params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh new file mode 100644 index 0000000000000..081e9a3de578c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh @@ -0,0 +1,355 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/diffusion/group_norm_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +static inline __device__ __host__ float sigmoid(float x) { + return 1.F / (1.F + expf(-x)); +} + +struct GroupSums { + // Is it the 1st element of the group? + int32_t flag; + // The sum. + float sum; + // The sum of squares. + float sum_sq; +}; + +struct GroupSumsOp { + inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { + GroupSums dst; + dst.sum = b.flag ? b.sum : (a.sum + b.sum); + dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); + dst.flag = a.flag + b.flag; + return dst; + } +}; + +template +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); + +template <> +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + float2 f2 = __half22float2(h2); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Update the sum. + sum += f2.x + f2.y; + + // Update the sum of squares. + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] +template +inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); + h2 = h2 + b; + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + float2 b = *reinterpret_cast(&bias[bias_offset]); + f2.x += s.x + b.x; + f2.y += s.y + b.y; + + *reinterpret_cast(&add_out[offset]) = f2; + + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] +template +inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + f2.x += s.x; + f2.y += s.y; + *reinterpret_cast(&add_out[offset]) = f2; + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template +__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { + // The object in charge of doing the sums for the different blocks. + typedef cub::BlockScan BlockScan; + + // Allocate shared memory for BlockScan. + __shared__ typename BlockScan::TempStorage temp_storage; + + // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. + __shared__ float2 smem[THREADS_PER_BLOCK]; + + // The instance in the batch. + int32_t ni = blockIdx.z; + + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } + + // The first activation loaded by that block. + int32_t hw_begin = blockIdx.y * params.hw_per_block; + // The last activation loaded by that block. + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); + + // The sums. + float sum = 0.F; + float sum_sq = 0.F; + + // Iterate over the activations to compute the sums. + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + if (params.skip != nullptr) { + // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) + const int64_t bias_offset = static_cast(ci); + T* add_out = params.skip_workspace; + if (params.broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * params.c + ci; + + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); + } + } + } else { + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); + } + } + } + } else { // GroupNorm + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + UpdateSum(params.src, offset, sum, sum_sq); + } + } + + // The group index relative to the first group within the same block. + int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; + // The channel in the group. + int32_t cj = ci % params.channels_per_group; + + // The data for the summations. + GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; + + // Do the segmented scan. InclusiveScan is not deterministic. + GroupSums out; + BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); + + // Store the results for the groups in shared memory (to produce coalesced stores later). + // For each group, only the last thread of that group is picked to save sum to shared memory. + if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { + smem[gi] = make_float2(out.sum, out.sum_sq); + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Threads that have nothing left to do, exit. + if (threadIdx.x >= params.groups_per_block) { + return; + } + + // The global group index. + // Use neighboring threads for coalesced write. + int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; + + if (gj < params.groups) { + float2 sums = smem[threadIdx.x]; + const int index = (2 * ni) * params.groups + gj; + atomicAdd(¶ms.group_sum_buffer[index], sums.x); + atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); + } +} + +template +__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu); + +template <> +__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + + // Extract the two half values. + float2 f2 = __half22float2(h2); + + // Normalize the channels. + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; + + // Scale by gamma and add beta. + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; + + // Apply SiLU activation if needed. + if (silu) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); +} + +template <> +__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { + // Fetch two channels per thread. + float2 f2 = *reinterpret_cast(&src[offset]); + + // Normalize the channels. + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; + + // Scale by gamma and add beta. + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; + + // Apply SiLU activation if needed. + if (silu) { + f2.x = f2.x * sigmoid(f2.x); + f2.y = f2.y * sigmoid(f2.y); + } + + *reinterpret_cast(&dst[offset]) = f2; +} + +template +__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } + + // The instance in the batch. + int32_t ni = blockIdx.z; + + // The group that thread works on. + int32_t gi = ci / params.channels_per_group; + + // Load the sum and sum of squares for the group. + float sum = 0.F, sum_sq = 0.F; + if (gi < params.groups) { + const int index = (2 * ni) * params.groups + gi; + sum = params.group_sum_buffer[index]; + sum_sq = params.group_sum_buffer[index + params.groups]; + } + + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); + + // Compute the mean. + float mean = sum * params.inv_hw_channels_per_group; + // Compute the variance. + float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); + // Compute the inverse of the stddev. + float inv_std_dev = rsqrtf(var + params.epsilon); + + int32_t hw_begin = blockIdx.y * params.hw_per_block; + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); + + const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime From 5b065050734e6bc397dc38ba0df246aeb57ac508 Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Fri, 26 Jan 2024 00:25:35 +0800 Subject: [PATCH 36/45] [js/webgpu] Fix Tanh explosion (#19201) ### Description ```math \tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}= \left\{ \begin{array}{cc} -\frac{1-e^{-2\cdot(-x)}}{1+e^{-2\cdot(-x)}}, & x<0 \\ 0, & x=0 \\ \frac{1-e^{-2x}}{1+e^{-2x}}, & x>0 \end{array} \right. ``` ### Motivation and Context On some platforms, $$\tanh(1000)=\frac{e^{1000}-e^{-1000}}{e^{1000}+e^{-1000}}$$ would produce NaN instead of 0.999... or 1 (imagine $e^{1000}=\infty$ and $\frac{\infty}{\infty}$ explodes). --- js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 4 +++- js/web/test/data/ops/tanh.jsonc | 26 +++++++++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + 3 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 js/web/test/data/ops/tanh.jsonc diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 82311d72e58b9..76929efb32537 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -279,7 +279,9 @@ export const tan = (context: ComputeContext): void => { }; export const tanh = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', 'tanh')); + // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`)); }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { diff --git a/js/web/test/data/ops/tanh.jsonc b/js/web/test/data/ops/tanh.jsonc new file mode 100644 index 0000000000000..f7691535bd71c --- /dev/null +++ b/js/web/test/data/ops/tanh.jsonc @@ -0,0 +1,26 @@ +[ + { + "name": "tanh with no attributes", + "operator": "Tanh", + "attributes": [], + "cases": [ + { + "name": "T[2,4]", + "inputs": [ + { + "data": [-1000, -1, 0, 0.1, 0.2, 0.3, 0.4, 1000], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-1, -0.761594, 0, 0.099668, 0.197375, 0.291313, 0.379949, 1], + "dims": [2, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 373b3c645df57..56db28b0a379c 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1389,6 +1389,7 @@ "sub.jsonc", "sub_int32.jsonc", "tan.jsonc", + "tanh.jsonc", "tile.jsonc", "transpose.jsonc", "transpose_int32_uint32.jsonc", From 2b285cd78a629971a9e465036e94a431e6fef17b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 25 Jan 2024 09:30:15 -0800 Subject: [PATCH 37/45] [CUDA] Add functions to dump bfloat16 tensors (#19266) ### Description GroupQueryAttention add BFloat16 in https://github.com/microsoft/onnxruntime/pull/19095, and there is build error when enable dumping. This supports print bfloat16 tensor to console. --- .../cuda/transformers/dump_cuda_tensor.cc | 88 ++++++++++++------- .../cuda/transformers/dump_cuda_tensor.h | 27 ++++-- 2 files changed, 75 insertions(+), 40 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc index b31f5d243e001..4cfa89a4d58c2 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc @@ -203,23 +203,19 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } -void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); + DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); + DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); -} - -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1) const { - Print(name, reinterpret_cast(tensor), dim0, dim1); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); } void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const { @@ -227,9 +223,14 @@ void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { +void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, true); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); +} + +void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, true); } void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const { @@ -242,6 +243,11 @@ void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int d DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); } +void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, true); +} + void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); @@ -252,22 +258,31 @@ void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, i DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); } -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const { - Print(name, reinterpret_cast(tensor), dim0, dim1, dim2); +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, true); } -void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const { - Print(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3); +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { +void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); } -void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { - if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1) const { + Print(name, reinterpret_cast(tensor), dim0, dim1); +} + +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const { + Print(name, reinterpret_cast(tensor), dim0, dim1, dim2); +} + +void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const { + Print(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3); } void CudaTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const { @@ -301,43 +316,52 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, } #else -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const { } void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int) const { } +void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { +} + void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int, int) const { } -void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { +void CudaTensorConsoleDumper::Print(const char*, const half*, int, int) const { +} + +void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int) const { } void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int, int) const { diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h index 264ecd7cfe2f5..773401f79531a 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h @@ -16,20 +16,31 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::transformers::ICons public: CudaTensorConsoleDumper() = default; virtual ~CudaTensorConsoleDumper() {} - void Print(const char* name, const float* tensor, int dim0, int dim1) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; + void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; - void Print(const char* name, const half* tensor, int dim0, int dim1) const; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; + void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override; + void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; + + void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; + void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; + + void Print(const char* name, const float* tensor, int dim0, int dim1) const override; void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override; void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + + void Print(const char* name, const half* tensor, int dim0, int dim1) const; void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const; void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const; - void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; - void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; + + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + + void Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const; + void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const; + void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; + void Print(const char* name, const Tensor& value) const override; void Print(const char* name, const OrtValue& value) const override; void Print(const char* name, int index, bool end_line) const override; From a2867b911e67146218b4fc0b32721e5cdbade49b Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Thu, 25 Jan 2024 11:51:39 -0800 Subject: [PATCH 38/45] [TensorRT EP] Fix mem leak for TRT plugins custom ops (#19248) TRT EP's GetTensorRTCustomOpDomainList() will create vector of OrtCustomOpDomain objects and release the ownership of those objects. But, thoses objects are not released forever. In session level, we need to make TRT EP remember what OrtCustomOpDomain objects it created and release them at EP destruction time. --- .../tensorrt/tensorrt_execution_provider.cc | 18 +++++-- .../tensorrt_execution_provider_custom_ops.cc | 37 +++++--------- .../core/session/provider_bridge_ort.cc | 49 +++---------------- .../python/onnxruntime_pybind_state.cc | 6 +-- 4 files changed, 35 insertions(+), 75 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index fe6b959b962de..39e5f5be000e5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1834,13 +1834,21 @@ nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const { } void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { - if (info_.custom_op_domain_list.empty()) { - common::Status status = CreateTensorRTCustomOpDomainList(info_); - if (!status.IsOK()) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; + std::string extra_plugin_lib_paths{""}; + if (info_.has_trt_options) { + if (!info_.extra_plugin_lib_paths.empty()) { + extra_plugin_lib_paths = info_.extra_plugin_lib_paths; } + } else { + const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); + if (!extra_plugin_lib_paths_env.empty()) { + extra_plugin_lib_paths = extra_plugin_lib_paths_env; + } + } + auto status = CreateTensorRTCustomOpDomainList(custom_op_domain_list, extra_plugin_lib_paths); + if (status != Status::OK()) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; } - custom_op_domain_list = info_.custom_op_domain_list; } // Check the graph is the subgraph of control flow op diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 4e466a5d568a6..eb340ba1e64b6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -27,8 +27,12 @@ extern TensorrtLogger& GetTensorrtLogger(); * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. */ common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) { - std::unique_ptr custom_op_domain = std::make_unique(); - custom_op_domain->domain_ = "trt.plugins"; + static std::unique_ptr custom_op_domain = std::make_unique(); + static std::vector> created_custom_op_list; + if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { + domain_list.push_back(custom_op_domain.get()); + return Status::OK(); + } // Load any extra TRT plugin library if any. // When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry. @@ -69,38 +73,19 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& continue; } - std::unique_ptr trt_custom_op = std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr); - trt_custom_op->SetName(plugin_creator->getPluginName()); - custom_op_domain->custom_ops_.push_back(trt_custom_op.release()); + created_custom_op_list.push_back(std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr)); // Make sure TensorRTCustomOp object won't be cleaned up + created_custom_op_list.back().get()->SetName(plugin_creator->getPluginName()); + custom_op_domain->custom_ops_.push_back(created_custom_op_list.back().get()); registered_plugin_names.insert(plugin_name); } - domain_list.push_back(custom_op_domain.release()); + custom_op_domain->domain_ = "trt.plugins"; + domain_list.push_back(custom_op_domain.get()); } catch (const std::exception&) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins"; } return Status::OK(); } -common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) { - std::vector domain_list; - std::string extra_plugin_lib_paths{""}; - if (info.has_trt_options) { - if (!info.extra_plugin_lib_paths.empty()) { - extra_plugin_lib_paths = info.extra_plugin_lib_paths; - } - } else { - const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); - if (!extra_plugin_lib_paths_env.empty()) { - extra_plugin_lib_paths = extra_plugin_lib_paths_env; - } - } - auto status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths); - if (!domain_list.empty()) { - info.custom_op_domain_list = domain_list; - } - return Status::OK(); -} - void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) { if (domain != nullptr) { for (auto ptr : domain->custom_ops_) { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 3178c13d30eec..f48110aa7ee5b 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1713,17 +1713,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessi ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id) { API_IMPL_BEGIN - auto factory = onnxruntime::TensorrtProviderFactoryCreator::Create(device_id); - if (!factory) { - return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); - } - - options->provider_factories.push_back(factory); - - std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths"); - AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths); - - return nullptr; + OrtTensorRTProviderOptionsV2 tensorrt_options; + tensorrt_options.device_id = device_id; + return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &tensorrt_options); API_IMPL_END } @@ -1741,33 +1733,8 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtS ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) { API_IMPL_BEGIN - - std::shared_ptr factory; - -#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) - auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; - // If EP context configs are provided in session options, we need to propagate them to provider options - if (ep_context_cache_enabled_from_sess_options) { - OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); - - onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &trt_options_converted); - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted); - } else { - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); - } -#else - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); -#endif - - if (!factory) { - return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); - } - - options->provider_factories.push_back(factory); - - AddTensorRTCustomOpDomainToSessionOption(options, ""); - - return nullptr; + OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); + return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &trt_options_converted); API_IMPL_END } @@ -1906,11 +1873,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, // if provider options already have the EP context configs provided, the configs in session options will be ignored // since provider options has higher priority than session options. if (!ep_context_cache_enabled_from_provider_options && ep_context_cache_enabled_from_sess_options) { - // We need to create a new provider options V2 object and copy from provider_options, due to the "const" object pointed by provider_options can't be modified. - // Note: No need to worry about tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will + // This function might need to update the "const" OrtTensorRTProviderOptionsV2 object which can't be modified. + // Therefore, we need to create a new OrtTensorRTProviderOptionsV2 object and copy from tensorrt_options and use this new object to create the factory instead. + // Note: No need to worry about new_tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will // create a factory object that copies any provider options from tensorrt_options including "const char*" provider options. OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options - onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &new_tensorrt_options); factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options); } else { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f7ed5520727db..8e13982ca6861 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -443,9 +443,9 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti if (it != options.end()) { trt_extra_plugin_lib_paths = it->second; } - std::vector domain_list; - tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths); - for (auto ptr : domain_list) { + std::vector custom_op_domains; + tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths); + for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) { so.custom_op_domains_.push_back(ptr); } else { From 656ca66186c7fd362abd8f33915bd0f96483bf43 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 26 Jan 2024 07:37:05 +0800 Subject: [PATCH 39/45] [js/webgpu] Support uniforms for conv, conv transpose, conv grouped (#18753) --- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 125 +++++++------ .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 154 ++++++++-------- .../ops/3rd-party/conv_backprop_webgpu.ts | 174 +++++++++++------- .../ops/3rd-party/matmul_packed_webgpu.ts | 108 +++++------ .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 86 +++++---- .../wasm/jsep/webgpu/ops/conv-transpose.ts | 15 +- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 18 +- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 39 ++-- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 43 +++-- 9 files changed, 418 insertions(+), 344 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 3638938df7dbe..1a03621512888 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; import {getActivationSnippet} from '../fuse-utils'; @@ -88,10 +88,10 @@ const conv2dCommonSnippet = let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; - let WRow = ${col} / (filterDims[1] * inChannels); - let WCol = ${col} / inChannels % filterDims[1]; - let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; - let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; + let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels); + let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]); + let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0]; + let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1]; let xCh = ${col} % inChannels; var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0); // The bounds checking is always needed since we use it to pad zero for @@ -108,7 +108,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : @@ -117,7 +117,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); @@ -129,9 +129,8 @@ const conv2dCommonSnippet = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType); + const applyActivation = getActivationSnippet(attributes, resType); const userCode = ` - ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { ${isChannelsLast ? sampleX : sampleW} } @@ -142,7 +141,7 @@ const conv2dCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) { let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueIn; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; @@ -181,31 +180,46 @@ export const createConv2DMatMulProgramInfo = LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; - const tileAOuter = workGroupSize[1] * elementsPerThread[1]; const tileBOuter = workGroupSize[0] * elementsPerThread[0]; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); - const fitAOuter = dimAOuter % tileAOuter === 0; const fitBOuter = dimBOuter % tileBOuter === 0; const fitInner = dimInner % tileInner === 0; - const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; - const t = tensorTypeToWsglStorageType(inputs[0].dataType); - // TODO: support component 2, 3. - const components = isVec4 ? 4 : 1; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = - inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); - const inputVariables = [x, w]; + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, + {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides}, + {type: 'int32', data: attributes.dilations} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, + {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2}, + {name: 'dilation', type: 'i32', length: 2} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } - let declareFunctions = ` + // TODO: support component 2, 3. + const components = isVec4 ? 4 : 1; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); + let declareFunctions = ` fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); } @@ -213,51 +227,50 @@ export const createConv2DMatMulProgramInfo = let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; - if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); - - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - - declareFunctions += ` + const x = inputVariable( + 'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); + const inputVariables = [x, w]; + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; - } - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - programUniforms.push(...createTensorShapeVariables(outputShape)); - return { - name: 'Conv2DMatMul', - shaderCache: {hint: attributes.cacheKey}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms, - }), - getShaderSource: (shaderHelper: ShaderHelper) => ` + } + + return ` ${utilFunctions('uniforms.result_strides')} //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; - ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .declareVariables(...inputVariables, output)} - const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); - const pad : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); - const stride : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} ${ conv2dCommonSnippet( isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1], elementsSize[2], t)} - ${ + ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, - sequentialAccessByThreads)}` + sequentialAccessByThreads)}`; + }; + return { + name: 'Conv2DMatMul', + shaderCache: { + hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${ + tileAOuter};${tileBOuter};${tileInner}`, + inputDependencies + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms, + }), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index d425155857e14..33e50a9a39cb9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {getActivationSnippet} from '../fuse-utils'; @@ -74,21 +74,21 @@ const conv2dTransposeCommonSnippet = col % outWidth); `; - const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]'; - const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; + const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; const row = isChannelsLast ? 'row' : 'col'; const col = isChannelsLast ? 'col' : 'row'; const readASnippet = ` - let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; - let WRow = ${col} / (filterDims[1] * inChannels); - let WCol = ${col} / inChannels % filterDims[1]; - let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); - let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); + let WRow = ${col} / (uniforms.filter_dims[1] * inChannels); + let WCol = ${col} / inChannels % uniforms.filter_dims[1]; + let xR = f32(outRow - uniforms.pads[0] + uniforms.dilations[0] * WRow) / f32(uniforms.strides[0]); + let xC = f32(outCol - uniforms.pads[1] + uniforms.dilations[1] * WCol) / f32(uniforms.strides[1]); if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { return ${type}(0.0); } @@ -103,25 +103,25 @@ const conv2dTransposeCommonSnippet = const sampleA = isChannelsLast ? ` let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readASnippet} } return ${type}(0.0);` : ` let col = colIn * ${innerElementSize}; - if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readASnippet} } return ${type}(0.0);`; const sampleW = ` let col = colIn * ${innerElementSize}; - let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; - let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); - let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; + let coordX = uniforms.filter_dims[0] - 1 - row / (uniforms.filter_dims[1] * inChannels); + let coordY = uniforms.filter_dims[1] - 1 - (row / inChannels) % uniforms.filter_dims[1]; if (${ - isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' : - 'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) { + isChannelsLast ? 'row < uniforms.dim_inner && col < uniforms.dim_b_outer' : + 'row < uniforms.dim_inner && col < uniforms.dim_a_outer'} && coordX >= 0 && coordY >= 0) { let rowInner = row % inChannels; let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} @@ -129,9 +129,8 @@ const conv2dTransposeCommonSnippet = return ${type}(0.0); `; - const {activationFunction, applyActivation} = getActivationSnippet(attributes, type); + const applyActivation = getActivationSnippet(attributes, type); const userCode = ` - ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleA : sampleW} } @@ -142,7 +141,7 @@ const conv2dTransposeCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueInput; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; ${coordResSnippet} @@ -186,65 +185,64 @@ export const createConv2DTransposeMatMulProgramInfo = const innerElementSize = isVec4 ? 4 : 1; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); const components = isVec4 ? 4 : 1; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - const inputVariables = [x, w]; - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + const filterDims = + [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const effectiveFilterDims = [ + filterDims[0] + (attributes.dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + (attributes.dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (attributes.dilations[1] - 1)) + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2) + ]; - let declareFunctions = ''; + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, + {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations}, + {type: 'int32', data: filterDims}, {type: 'int32', data: pads} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - - declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { - return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; - }`; + inputDependencies.push('rank'); } - programUniforms.push(...createTensorShapeVariables(outputShape)); - return { - name: 'Conv2DTransposeMatMul', - shaderCache: {hint: attributes.cacheKey}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms - }), - getShaderSource: (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const inputVariables = [x, w]; + + let declareFunctions = ''; + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; + }`; + } + + const uniforms: UniformsArrayType = [ + {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, + {name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2}, + {name: 'filter_dims', type: 'i32', length: filterDims.length}, + {name: 'pads', type: 'i32', length: pads.length} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + return ` ${utilFunctions('uniforms.result_strides')} - ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .declareVariables(...inputVariables, output)}; - const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ - attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const effectiveFilterDims : vec2 = filterDims + vec2( - ${ - attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, - ${ - attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); - const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ - attributes.pads[0] + attributes.pads[2]})/2, - i32(effectiveFilterDims[1]) - 1 - (${ - attributes.pads[1] + attributes.pads[3]})/2); - const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const dimAOuter : i32 = ${dimAOuter}; - const dimBOuter : i32 = ${dimBOuter}; - const dimInner : i32 = ${dimInner}; + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ${declareFunctions} ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} ${ @@ -252,6 +250,18 @@ export const createConv2DTransposeMatMulProgramInfo = elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, - undefined, sequentialAccessByThreads)}` + undefined, sequentialAccessByThreads)}`; + }; + + return { + name: 'Conv2DTransposeMatMul', + shaderCache: + {hint: `${attributes.cacheKey};${elementsPerThread};${workGroupSize};${isVec4}`, inputDependencies}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms + }), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 50b0841a0200a..380efc8bc577a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -20,24 +20,18 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = - (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, - dataType: string): string => { - const isChannelsLast = attributes.format === 'NHWC'; + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], outputShape: readonly number[], hasBias: boolean, + is1DimensionDispatch: boolean, isVec4 = false, dataType: string, uniforms: UniformsArrayType, + isChannelsLast = false): string => { const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; const channelDim = isChannelsLast ? 3 : 1; - const outputSize = ShapeUtil.size(outputShape); const workPerThread = isVec4 ? 2 : 1; - const group = attributes.group; - const wShape = inputs[1].dims; - const inputChannelsPerGroup = wShape[0] / group; - const outputChannelsPerGroup = wShape[1]; let declareFunctions = ` fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { @@ -50,20 +44,21 @@ const createConvTranspose2DOpProgramShaderSource = }`; } const components = isVec4 ? 4 : 1; - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims, components); - const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims, components); + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); + const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components); const inputVariables = [dy, w]; if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]], components)); + inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); } - const output = outputVariable('result', inputs[0].dataType, outputShape, components); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const codeSnippet4 = `{ - let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / outShape[1]; - let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % outShape[1]; + let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1]; + let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1]; let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4; - let dyCorner = vec2(i32(r), i32(c)) - vec2(pads); + let dyCorner = vec2(i32(r), i32(c)) - vec2(uniforms.pads); // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. @@ -71,29 +66,29 @@ const createConvTranspose2DOpProgramShaderSource = for (var i = 0; i < ${workPerThread}; i++) { dotProd[i] = vec4<${dataType}>(0.0); } - for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); - let wRPerm = filterDims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || + for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) { + var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x); + let wRPerm = uniforms.filter_dims[0] - 1 - wR; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); - for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); - let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); - let wCPerm = filterDims[1] - 1 - wC; + for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) { + let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims[1] - 1 - wC; if (wCPerm < 0) { continue; } var bDyCVal = true; var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) || fract(dyC) > 0.0) { bDyCVal = false; } - if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || + if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) || fract(dyC2) > 0.0) { bDyCVal2 = false; } @@ -101,7 +96,7 @@ const createConvTranspose2DOpProgramShaderSource = let idyC: u32 = u32(dyC); let idyC2: u32 = u32(dyC2); if (bDyCVal && bDyCVal2) { - let d2Length = outBackprop[3]; + let d2Length = uniforms.Dy_shape[3]; for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -123,7 +118,7 @@ const createConvTranspose2DOpProgramShaderSource = dot(xValue, wValue3)); } } else if (bDyCVal) { - let d2Length = outBackprop[${channelDim}]; + let d2Length = uniforms.Dy_shape[${channelDim}]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -138,7 +133,7 @@ const createConvTranspose2DOpProgramShaderSource = dotProd[0] = dotProd[0] + tmpval; } } else if (bDyCVal2) { - let d2Length = outBackprop[3]; + let d2Length = uniforms.Dy_shape[3]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -167,39 +162,39 @@ const createConvTranspose2DOpProgramShaderSource = let d1 = ${output.indicesGet('outputIndices', channelDim)}; let r = ${output.indicesGet('outputIndices', rowDim)}; let c = ${output.indicesGet('outputIndices', colDim)}; - let dyCorner = vec2(i32(r), i32(c)) - pads; + let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; let dyRCorner = dyCorner.x; let dyCCorner = dyCorner.y; - let groupId = d1 / ${outputChannelsPerGroup}; - let wOutChannel = d1 - groupId * ${outputChannelsPerGroup}; + let groupId = d1 / uniforms.output_channels_per_group; + let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. var dotProd = ${dataType}(0.0); - for (var wR: u32 = 0; wR < effectiveFilterDims.x; wR = wR + 1) { - if (wR % dilations.x != 0) { + for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { + if (wR % uniforms.dilations.x != 0) { continue; } - let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); - let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); + let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); - for (var wC: u32 = 0; wC < effectiveFilterDims.y; wC = wC + 1) { - if (wC % dilations.y != 0) { + for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { + if (wC % uniforms.dilations.y != 0) { continue; } - let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y); - let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } let idyC: u32 = u32(dyC); - var inputChannel = groupId * ${inputChannelsPerGroup}; - for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { + var inputChannel = groupId * uniforms.input_channels_per_group; + for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { let xValue = ${ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; @@ -214,27 +209,11 @@ const createConvTranspose2DOpProgramShaderSource = `; return ` - ${shaderHelper.declareVariables(...inputVariables, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ - attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const dilations : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const effectiveFilterDims : vec2 = filterDims + vec2( - ${ - attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, - ${ - attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); - const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2, - i32(effectiveFilterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2); + ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; ${isVec4 ? codeSnippet4 : codeSnippet}}`; }; @@ -257,19 +236,72 @@ export const createConvTranspose2DProgramInfo = ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const isChannelsLast = attributes.format === 'NHWC'; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + const strides = [attributes.strides[0], attributes.strides[1]]; + const filterDims = + [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const dilations = [attributes.dilations[0], attributes.dilations[1]]; + const effectiveFilterDims = [ + filterDims[0] + + (attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + + (attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)) + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2 + ]; + + const isVec4 = false; + const group = attributes.group; + const wShape = inputs[1].dims; + const inputChannelsPerGroup = wShape[0] / group; + const outputChannelsPerGroup = wShape[1]; + + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims}, + {type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads}, + {type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup}, + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims) + ]; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); + + const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'strides', type: 'u32', length: strides.length}, + {name: 'filter_dims', type: 'u32', length: filterDims.length}, + {name: 'dilations', type: 'u32', length: filterDims.length}, + {name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length}, + {name: 'pads', type: 'i32', length: pads.length}, {name: 'input_channels_per_group', type: 'u32'}, + {name: 'output_channels_per_group', type: 'u32'} + ]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return `${ + createConvTranspose2DOpProgramShaderSource( + shaderHelper, inputs, outputShape, hasBias, is1DimensionDispatch, isVec4, dataType, uniforms, + isChannelsLast)}`; + }; return { name: 'ConvTranspose2D', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: `${attributes.cacheKey};`, inputDependencies}, getRunData: () => ({ dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, outputs: [{ dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, dataType: inputs[0].dataType - }] + }], + programUniforms }), - getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, - dataType), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 47ec16a296712..ee71110245252 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,7 +22,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -112,14 +112,14 @@ fn main(@builtin(local_invocation_id) localId : vec3, ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + let num_tiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc: array, rowPerThread>; // Loop over shared dimension. let tileRowB = localRow * ${rowPerThreadB}; - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { let inputRow = tileRow + innerRow; @@ -204,7 +204,7 @@ export const makeMatMulPackedSource = let globalColStart = i32(workgroupId.x) * ${tileBOuter}; // Loop over shared dimension. - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var inputRow = localRow; inputRow < ${tileAHight}; inputRow = inputRow + ${workgroupSize[1]}) { for (var inputCol = localCol; inputCol < ${tileAWidth}; inputCol = inputCol + ${workgroupSize[0]}) { @@ -260,7 +260,7 @@ let tileRowA = i32(localId.y) * ${rowPerThreadA}; let tileColA = i32(localId.x) * ${colPerThreadA}; let tileRowB = i32(localId.y) * ${rowPerThreadB}; // Loop over shared dimension. -for (var t = 0; t < numTiles; t = t + 1) { +for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var innerRow = 0; innerRow < ${rowPerThreadA}; innerRow = innerRow + 1) { for (var innerCol = 0; innerCol < ${colPerThreadA}; innerCol = innerCol + 1) { @@ -322,7 +322,8 @@ fn main(@builtin(local_invocation_id) localId : vec3, @builtin(workgroup_id) workgroupId : vec3) { let batch = ${splitK ? '0' : 'i32(globalId.z)'}; ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + let num_tiles = ${ + splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc : array, rowPerThread>; @@ -379,7 +380,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < uniforms.dimAOuter && col < uniforms.dimInner) + if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${getAIndices()} value = ${aVariable.getByIndices('aIndices')}; @@ -391,7 +392,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < uniforms.dimInner && col < uniforms.dimBOuter) + if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${getBIndices()} value = ${bVariable.getByIndices('bIndices')}; @@ -401,7 +402,7 @@ const matMulReadWriteFnSource = fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueIn; let coords = vec3(batch, row, colIn); ${ @@ -422,16 +423,10 @@ export const createMatmulProgramInfo = isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { const aShape = inputs[0].dims; const bShape = inputs[1].dims; - const outerDimsA = aShape.slice(0, -2); const outerDimsB = bShape.slice(0, -2); - const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const enableBatchUniforms = enableShapesUniforms(outerDims.length); - const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims; - const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); const batchSize = ShapeUtil.size(outerDims); - const dimAOuter = aShape[aShape.length - 2]; const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; @@ -446,72 +441,67 @@ export const createMatmulProgramInfo = Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) ]; - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; - const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; - const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length); - const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp; - + const aShapeOrRank = aShapeTemp.length; const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; - const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length); - const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp; - + const bShapeOrRank = bShapeTemp.length; const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; - - const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); - const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); - const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); - const inputVariables = [A, B]; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - if (enableBatchUniforms) { - programUniforms.push(...createTensorShapeVariables(outerDims)); + if (activationAttributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: activationAttributes.clipMax!}, + {type: 'float32', data: activationAttributes.clipMin!}); } - if (enableAShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(aShapeTemp)); - } - if (enableBShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(bShapeTemp)); - } - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims'); - inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims'); + programUniforms.push( + ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), + ...createTensorShapeVariables(bShapeTemp)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const hasBias = inputs.length > 2; - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); - const declareFunctions = matMulReadWriteFnSource( - components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], - isChannelsLast); if (hasBias) { - const biasComponents = isChannelsLast ? components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - inputDependencies.push('rank'); } programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => { + const batchShapeOrRank = outerDims.length; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + + const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); + const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); + const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); + const inputVariables = [A, B]; + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + } + const uniforms: UniformsArrayType = + [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; + if (activationAttributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + const applyActivation = getActivationSnippet(activationAttributes, output.type.value); + const declareFunctions = matMulReadWriteFnSource( + components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], + isChannelsLast); + return ` ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${activationFunction} + shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( + ...inputVariables, output)} ${declareFunctions} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : + makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} `; - // TODO: turn clipMax and clipMin to uniforms. + }; return { name: 'MatMul', shaderCache: { - hint: activationAttributes.activationCacheKey + `${elementsPerThread}` + - `${isVec4}` + - `${isChannelsLast}`, + hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${isChannelsLast}`, inputDependencies }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 21b4953d3f90c..f81d6577890c5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -3,9 +3,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ProgramInfo, ProgramUniform} from '../types'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; import {getActivationSnippet} from './fuse-utils'; @@ -27,52 +27,75 @@ export const createGroupedConvProgramInfo = xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast); const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', inputs[0].dataType, outputShape); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); - const x = inputVariable('x', inputs[0].dataType, xShape); - const w = inputVariable('w', inputs[1].dataType, wShape); - const inputVars = [x, w]; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations}, + {type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]}, + {type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), + ...createTensorShapeVariables(outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { - inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const strides: vec2 = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u); - const pads: vec2 = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u); - - ${shaderHelper.declareVariables(...inputVars, output)} + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const applyActivation = getActivationSnippet(attributes, output.type.value); + const x = inputVariable('x', inputs[0].dataType, xShape.length); + const w = inputVariable('w', inputs[1].dataType, wShape.length); + const inputVars = [x, w]; + if (hasBias) { + inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); + } - ${activationFunction} + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'dilations', type: 'u32', length: attributes.dilations.length}, + {name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2}, + {name: 'output_channels_per_group', type: 'u32'} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let outputIndices = ${output.offsetToIndices('global_idx')}; let batch: u32 = outputIndices[0]; let output_channel: u32 = outputIndices[${isChannelLast ? 3 : 1}]; let xRCCorner: vec2 = vec2(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${ - isChannelLast ? 2 : 3}]) * strides - pads; - let group_id: u32 = output_channel / ${outputChannelsPerGroup}u; + isChannelLast ? 2 : 3}]) * uniforms.strides - uniforms.pads; + let group_id: u32 = output_channel / uniforms.output_channels_per_group; var value: ${output.type.value} = ${output.type.value}(0); - for (var wInChannel: u32 = 0u; wInChannel < ${wShape[1]}u; wInChannel++) { - let input_channel = group_id * ${wShape[1]}u + wInChannel; - for (var wHeight: u32 = 0u; wHeight < ${wShape[2]}u; wHeight++) { - let xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]}u; + for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) { + let input_channel = group_id * uniforms.w_shape[1] + wInChannel; + for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) { + let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0]; - if (xHeight < 0u || xHeight >= ${xShape[isChannelLast ? 1 : 2]}u) { + if (xHeight < 0u || xHeight >= uniforms.x_shape[${isChannelLast ? 1 : 2}]) { continue; } - for (var wWidth: u32 = 0u; wWidth < ${wShape[3]}u; wWidth++) { - let xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]}u; - if (xWidth < 0u || xWidth >= ${xShape[isChannelLast ? 2 : 3]}u) { + for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) { + let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1]; + if (xWidth < 0u || xWidth >= uniforms.x_shape[${isChannelLast ? 2 : 3}]) { continue; } let xVal = ${ - isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : - x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; + isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : + x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')}; value += xVal*wVal; } @@ -82,15 +105,17 @@ export const createGroupedConvProgramInfo = ${applyActivation} ${output.setByOffset('global_idx', 'value')} }`; + }; return { name: 'GroupedConv', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: attributes.cacheKey, inputDependencies}, getRunData: () => ({ outputs: [{ dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, dataType: inputs[0].dataType }], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }; @@ -114,7 +139,7 @@ export const createGroupedConvVectorizeProgramInfo = const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); + const applyActivation = getActivationSnippet(attributes, output.type.value); const x = inputVariable('x', inputs[0].dataType, xShape.length, components); const w = inputVariable('w', inputs[1].dataType, wShape.length, components); const inputVars = [x, w]; @@ -129,7 +154,6 @@ export const createGroupedConvVectorizeProgramInfo = .registerUniform('strides', 'i32', 2) .registerUniform('pads', 'i32', 2) .declareVariables(...inputVars, output)} - ${activationFunction} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let width0 = uniforms.output_shape[3]; @@ -179,7 +203,7 @@ export const createGroupedConvVectorizeProgramInfo = return { name: 'GroupedConv-Vectorize', shaderCache: { - hint: `${attributes.activationCacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, + hint: `${attributes.cacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'] }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 32b1d52ed94ca..33d16754c737a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -2,7 +2,6 @@ // Licensed under the MIT License. import {TensorView} from '../../tensor-view'; -import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; @@ -59,7 +58,6 @@ export interface ConvTransposeAttributes extends ConvAttributes { readonly outputShape: readonly number[]; } - const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); @@ -96,11 +94,7 @@ const getAdjustedConvTransposeAttributes = // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - const cacheKey = attributes.cacheKey + [ - kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), - dilations.join(',') - ].join('_'); - Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); + Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides}); return newAttributes; }; @@ -119,7 +113,7 @@ export const parseConvTransposeAttributes = (attributes: Record const wIsConst = (attributes.wIsConst as () => boolean)(); const outputPadding = attributes.outputPadding as [number, number, number, number]; const outputShape = attributes.outputShape as [number, number]; - return createAttributeWithCacheKey({ + return { autoPad, format, dilations, @@ -130,8 +124,9 @@ export const parseConvTransposeAttributes = (attributes: Record pads, strides, wIsConst, - ...activationAttributes - }); + ...activationAttributes, + cacheKey: `${attributes.format};${activationAttributes.activation};` + }; }; const validateInputs = (inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 7af2c5db49f40..5afec0389fac8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -3,7 +3,7 @@ import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; @@ -110,7 +110,7 @@ const getAdjustedConvAttributes = (attributes: T, inpu // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign(newAttributes, {kernelShape, pads, cacheKey: attributes.cacheKey}); + Object.assign(newAttributes, {kernelShape, pads}); return newAttributes; }; @@ -126,8 +126,18 @@ export const parseConvAttributes = (attributes: Record): ConvAt const strides = attributes.strides as [number, number]; const wIsConst = (attributes.w_is_const as () => boolean)(); - return createAttributeWithCacheKey( - {autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes}); + return { + autoPad, + format, + dilations, + group, + kernelShape, + pads, + strides, + wIsConst, + ...activationAttributes, + cacheKey: `${attributes.format};${activationAttributes.activation};` + }; }; const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 0b5c0db2b5112..2e0aa33a957dc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -7,30 +7,21 @@ export interface InternalActivationAttributes { readonly activation: string; readonly clipMin?: number; readonly clipMax?: number; - readonly activationCacheKey: string; } -export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): - {activationFunction: string; applyActivation: string} => { - switch (attributes.activation) { - case 'Relu': - return {activationFunction: '', applyActivation: `value = max(value, ${valueType}(0.0));`}; - case 'Sigmoid': - return { - activationFunction: '', - applyActivation: `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));` - }; - case 'Clip': - return { - activationFunction: `const clip_min_=${valueType}(${attributes.clipMin!});const clip_max_=${valueType}(${ - attributes.clipMax!});`, - applyActivation: 'value = clamp(value, clip_min_, clip_max_);' - }; - // TODO: adding other activations that can be fused. - default: - return {activationFunction: '', applyActivation: ''}; - } - }; +export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => { + switch (attributes.activation) { + case 'Relu': + return `value = max(value, ${valueType}(0.0));`; + case 'Sigmoid': + return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; + case 'Clip': + return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`; + // TODO: adding other activations that can be fused. + default: + return ''; + } +}; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { @@ -38,7 +29,7 @@ export const parseInternalActivationAttributes = if (activation === 'Clip') { const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; - return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`}; + return {activation, clipMax, clipMin}; } - return {activation, activationCacheKey: activation}; + return {activation}; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index de9309d1e436f..c946ea6366123 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -6,7 +6,7 @@ import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common'; +import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common'; import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; export const createNaiveMatmulProgramInfo = @@ -27,11 +27,19 @@ export const createNaiveMatmulProgramInfo = const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); const batchSize = ShapeUtil.size(outerDims); const outputShapeInShader = [batchSize, M, N]; + const programUniforms: ProgramUniform[] = [ {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, - {type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), - ...createTensorShapeVariables(bShape) + {type: 'uint32', data: K} ]; + if (activationAttributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: activationAttributes.clipMax!}, + {type: 'float32', data: activationAttributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), + ...createTensorShapeVariables(bShape)); if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); } @@ -42,7 +50,7 @@ export const createNaiveMatmulProgramInfo = const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); const b = inputVariable('b', inputs[1].dataType, bShape.length, components); const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value); const inputVariables = [a, b]; let processBias = ''; if (hasBias) { @@ -57,6 +65,14 @@ export const createNaiveMatmulProgramInfo = const outerDimsB = bShape.slice(0, -2); const broadCastADims = getBroadcastDims(outerDimsA, outerDims); const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'K', type: 'u32'} + ]; + if (activationAttributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { const rank = variable.rank; const name = variable.name; @@ -96,15 +112,10 @@ export const createNaiveMatmulProgramInfo = return ` ${ - shaderHelper.registerUniform('outputSize', 'u32') - .registerUniform('M', 'u32') - .registerUniform('N', 'u32') - .registerUniform('K', 'u32') - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${activationFunction} + shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( + ...inputVariables, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let col = (global_idx % (uniforms.N / ${components})) * ${components}; var index1 = global_idx / (uniforms.N / ${components}); let stride1 = uniforms.M / ${outputNumber}; @@ -134,8 +145,7 @@ export const createNaiveMatmulProgramInfo = return { name: 'MatMulNaive', shaderCache: { - hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${ - isChannelsLast}`, + hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'] }, getRunData: () => ({ @@ -166,9 +176,8 @@ export const matMul = (context: ComputeContext): void => { const N = outputShape[outputShape.length - 1]; const K = context.inputs[0].dims[context.inputs[0].dims.length - 1]; if (N < 8 && K < 8) { - context.compute( - createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createNaiveMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); } else { - context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); } }; From 8b4517218b52285efaaf8badd303c00b0e514238 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 25 Jan 2024 16:57:58 -0800 Subject: [PATCH 40/45] Remove USE_CUTLASS flag (#19271) ### Description Since Cutlass can be built with CUDA 11.4 (The minimum CUDA version for onnxruntime CUDA build), there is no need to have a flag to disable cutlass. Changes: (1) Reverted https://github.com/microsoft/onnxruntime/pull/18761 (2) remove the condition to build cutlass. (3) Fix a few build errors or warnings during testing CUDA 11.4 build. Note that SM 89 and 90 (including fp8) requires CUDA 11.8 or later. Flash attention and cutlass fused multihead attention will not be built for CUDA < 11.6. It is recommended to use CUDA 11.8 or above to build if you want to support latest GPUs. It is better to include it in 1.17.0 (otherwise, the release branch might encounter build failure with CUDA 11.4). Tests: (1) Build with flash attention and efficient attention off: **passed** (2) Build with CUDA 11.4: **passed** Example build command used in Ubuntu 20.04: ``` export CUDA_HOME=/usr/local/cuda-11.4 export CUDNN_HOME=/usr/lib/x86_64-linux-gnu/ export CUDACXX=/usr/local/cuda-11.4/bin/nvcc sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_version 11.4 \ --cuda_home $CUDA_HOME --cudnn_home $CUDNN_HOME --build_wheel --skip_tests \ --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \ --disable_types float8 ``` ### Motivation and Context --- cmake/CMakeLists.txt | 23 ++++++------------- cmake/external/cutlass.cmake | 20 ++++++++-------- .../cuda/collective/sharded_moe.cc | 4 ---- .../contrib_ops/cuda/collective/sharded_moe.h | 4 ---- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 8 ------- .../cuda/moe/ft_moe/compute_occupancy.h | 5 ---- .../cuda/moe/ft_moe/cutlass_heuristic.cc | 11 ++++----- .../cuda/moe/ft_moe/cutlass_heuristic.h | 2 -- .../cuda/moe/ft_moe/epilogue_helpers.h | 4 ---- .../cuda/moe/ft_moe/ft_gemm_configs.h | 4 ---- .../moe/ft_moe/gemm_moe_problem_visitor.h | 4 ---- .../cuda/moe/ft_moe/layout_traits_helper.h | 6 +---- .../cuda/moe/ft_moe/moe_cutlass_kernel.h | 4 ---- .../cuda/moe/ft_moe/moe_gemm_kernels.h | 4 ---- .../moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu | 4 ---- .../moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu | 4 ---- .../moe/ft_moe/moe_gemm_kernels_template.h | 4 ---- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 4 ---- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.h | 6 +---- .../cuda/moe/ft_moe/moe_problem_visitor.h | 4 ---- .../cuda/moe/ft_moe/tile_interleaved_layout.h | 5 ---- onnxruntime/contrib_ops/cuda/moe/moe.cc | 4 ---- onnxruntime/contrib_ops/cuda/moe/moe.h | 4 ---- onnxruntime/contrib_ops/cuda/moe/moe_base.h | 4 ---- .../cuda/quantization/matmul_nbits.cu | 6 ++--- onnxruntime/test/contrib_ops/moe_test.cc | 4 ---- 26 files changed, 25 insertions(+), 131 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 7d7304630c00e..0eb224623f678 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -97,7 +97,6 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) -cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF) cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF) option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) @@ -707,20 +706,16 @@ if (onnxruntime_USE_CUDA) enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") + if (onnxruntime_DISABLE_CONTRIB_OPS) + set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) + endif() if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) - message( STATUS "Turn off cutlass since CUDA compiler version < 11.6") - set(onnxruntime_USE_CUTLASS OFF) + message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") + set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() else() - set(onnxruntime_USE_CUTLASS OFF) -endif() - -if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS) - if (onnxruntime_DISABLE_CONTRIB_OPS) - message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled") - else() - message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled") - endif() set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() @@ -906,10 +901,6 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() - if (onnxruntime_USE_CUTLASS) - target_compile_definitions(${target_name} PRIVATE USE_CUTLASS) - endif() - if(USE_NEURAL_SPEED) target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED) endif() diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index efc708bd681c0..f04f4bec76cd5 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,13 +1,11 @@ -if (onnxruntime_USE_CUTLASS) - include(FetchContent) - FetchContent_Declare( - cutlass - URL ${DEP_URL_cutlass} - URL_HASH SHA1=${DEP_SHA1_cutlass} - ) +include(FetchContent) +FetchContent_Declare( + cutlass + URL ${DEP_URL_cutlass} + URL_HASH SHA1=${DEP_SHA1_cutlass} +) - FetchContent_GetProperties(cutlass) - if(NOT cutlass_POPULATED) - FetchContent_Populate(cutlass) - endif() +FetchContent_GetProperties(cutlass) +if(NOT cutlass_POPULATED) + FetchContent_Populate(cutlass) endif() diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 9b989dac9a94b..40a667ffd5d83 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" @@ -204,5 +202,3 @@ Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h index cbd483fddab78..5ea4ae59c4020 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" @@ -36,5 +34,3 @@ class ShardedMoE final : public NcclKernel, public MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index fa73950c9c6f5..8f368251f12c7 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -70,10 +70,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); -#ifdef USE_CUTLASS class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE); -#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); @@ -169,10 +167,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllR class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll); -#ifdef USE_CUTLASS class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE); -#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul); @@ -272,10 +268,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, -#ifdef USE_CUTLASS BuildKernelCreateInfo, BuildKernelCreateInfo, -#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -377,10 +371,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, -#ifdef USE_CUTLASS BuildKernelCreateInfo, BuildKernelCreateInfo, -#endif BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h index 9b97690fe70fd..86136ea244e23 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h @@ -13,9 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifdef USE_CUTLASS - #pragma once #include @@ -52,5 +49,3 @@ inline int compute_occupancy_for_kernel() { } } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc index f0abd46572a90..adc043e5689e2 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifdef USE_CUTLASS #include "cutlass_heuristic.h" @@ -66,9 +65,9 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, } // Check that the workspace has sufficient space for this split-k factor - const int ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); - const int ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); - const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + const size_t ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + const size_t ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; if (required_ws_bytes > workspace_bytes) { return false; @@ -128,7 +127,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector= multi_processor_count * 256 ? 1 : split_k_limit; - for (int ii = 0; ii < candidate_configs.size(); ++ii) { + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { CutlassGemmConfig candidate_config = candidate_configs[ii]; TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); int occupancy = occupancies[ii]; @@ -186,5 +185,3 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector @@ -64,5 +62,3 @@ class MoeGemmRunner { }; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu index 1d0dfe7c5a647..1d9a249db4237 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu @@ -14,12 +14,8 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - #include "moe_gemm_kernels_template.h" namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu index 7a5d97902ee8f..7b250e6ca9060 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu @@ -14,12 +14,8 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - #include "moe_gemm_kernels_template.h" namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index 3fd0fc47055a5..66950c9b65970 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -14,8 +14,6 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - // Ignore CUTLASS warnings about type punning #ifdef __GNUC__ #pragma GCC diagnostic push @@ -428,5 +426,3 @@ void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, con } } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 9232e8d012933..f4f2b49032d23 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -16,8 +16,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include #include #include @@ -900,5 +898,3 @@ template void finalize_moe_routing_kernelLauncher(const half*, half*, const half cudaStream_t); } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index f09471de1cc2e..5cc2a3f79f003 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -16,8 +16,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "moe_gemm_kernels.h" @@ -174,6 +172,4 @@ class CutlassMoeFCRunner> { } // namespace layout } // namespace cutlass - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index 0da06192e266b..3f26a274109ad 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "moe.h" @@ -119,5 +117,3 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h index 710b914f0633d..c4d8c4dc64c57 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" @@ -26,5 +24,3 @@ class MoE final : public CudaKernel, public MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index dc8b9d57f79f6..f55a7cde2e208 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "core/common/common.h" @@ -172,5 +170,3 @@ class MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index 67384957d8dd2..d4d583906b7f4 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -89,7 +89,7 @@ __device__ __forceinline__ void Convert8xInt4To8xHalfs(uint32_t value, half2* ha asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(kOneSixteenth), "r"(kNeg64)); } -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -120,7 +120,7 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, sums_half2[3] = sums_half2[3] + v3 * (*(reinterpret_cast(&(vec_permuted.w)))); } #else -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -144,7 +144,7 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, } #endif -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { float4 a_vec_0 = *(reinterpret_cast(a)); float4 a_vec_1 = *(reinterpret_cast(a + 4)); diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 844cc877f2568..ebb0261deefa5 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include "gtest/gtest.h" #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" @@ -423,5 +421,3 @@ TEST(MoETest, MoETest_Relu) { } // namespace test } // namespace onnxruntime - -#endif From a3f0e2422b5eb2968e3f11e93414aa1661b32e2f Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 26 Jan 2024 08:58:22 +0800 Subject: [PATCH 41/45] [js/webgpu] Support f16 uniform (#19098) ### Description ### Motivation and Context --- js/web/lib/wasm/jsep/backend-webgpu.ts | 26 +++++++++--- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 40 +++++++++++++------ js/web/lib/wasm/jsep/webgpu/ops/pad.ts | 4 +- js/web/lib/wasm/jsep/webgpu/types.ts | 2 +- .../core/providers/js/operators/pad.cc | 10 ++--- 5 files changed, 56 insertions(+), 26 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 8ca025d66550c..a48fe99570abf 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -428,13 +428,26 @@ export class WebGpuBackend { return; } // https://www.w3.org/TR/WGSL/#alignof - const baseAlignment = data.length <= 2 ? data.length * 4 : 16; + const sizeOfElement = v.type === 'float16' ? 2 : 4; + let sizeOfVecOrMat; + let baseAlignment; + if (v.type === 'float16') { + baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); + sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length; + } else { + baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16; + sizeOfVecOrMat = 16; + } currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment; offsets.push(currentOffset); - // When data.length > 4, the uniform variable is of type array,N>, where N = - // Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * - // SizeOf(vec4). - currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4; + // For non-float16 type, when data.length > 4, the uniform variable is of type array,N>, where + // N = Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * + // SizeOf(vec4). For float16 type, when data.length > 4, the uniform variable is of type + // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte + // length is N * SizeOf(mat2x4). + const elementPerVecOrMat = v.type === 'float16' ? 8 : 4; + currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : + data.length * sizeOfElement; }); // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set @@ -449,6 +462,9 @@ export class WebGpuBackend { new Int32Array(arrayBuffer, offset, data.length).set(data); } else if (v.type === 'uint32') { new Uint32Array(arrayBuffer, offset, data.length).set(data); + } else if (v.type === 'float16') { + // TODO: use Float16Array. + new Uint16Array(arrayBuffer, offset, data.length).set(data); } else { new Float32Array(arrayBuffer, offset, data.length).set(data); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index bc3265be955f0..643744108c0f4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -330,18 +330,28 @@ export const sumVector = (name: string, components: number) => { * @param name - the name of variable. * @param index - the index of variable element. * @param length - the length of variable. + * @param type - the type of variable, optional. */ -export const getElementAt = (name: string, index: number|string, length: number): string => { - if (name.startsWith('uniforms.') && length > 4) { - if (typeof (index) === 'string') { - return `${name}[(${index}) / 4][(${index}) % 4]`; - } else { - return `${name}[${Math.floor(index / 4)}][${index % 4}]`; - } - } else { - return length > 1 ? `${name}[${index}]` : name; - } -}; +export const getElementAt = + (name: string, index: number|string, length: number, type?: UniformDataElementType): string => { + if (name.startsWith('uniforms.') && length > 4) { + if (typeof (index) === 'string') { + if (type === 'f16') { + return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`; + } else { + return `${name}[(${index}) / 4][(${index}) % 4]`; + } + } else { + if (type === 'f16') { + return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`; + } else { + return `${name}[${Math.floor(index / 4)}][${index % 4}]`; + } + } + } else { + return length > 1 ? `${name}[${index}]` : name; + } + }; /** * A helper function to get a IndicesHelper for a given input or output. @@ -688,7 +698,7 @@ export const internalVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'internal', components); -export type UniformDataElementType = 'u32'|'f32'|'i32'; +export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32'; export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>; /** @@ -861,7 +871,11 @@ class ShaderHelperImpl implements ShaderHelper { const uniformSnippets: string[] = []; for (const {name, type, length} of this.uniforms) { if (length && length > 4) { - uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + if (type === 'f16') { + uniformSnippets.push(`@align(16) ${name}:array, ${Math.ceil(length / 8)}>`); + } else { + uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + } } else { const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`; uniformSnippets.push(`${name}:${typeTemp}`); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index eca3fa7d944bb..c65b741e1105a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -19,8 +19,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length < 1) { throw new Error('Too few inputs'); } - if (inputs[0].dataType !== DataType.float) { - throw new Error('Input type must be float.'); + if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) { + throw new Error('Input type must be float or float16.'); } if (inputs.length >= 2) { diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index e55bfb6ba9f16..789ac70a6913a 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -24,7 +24,7 @@ export interface TensorInfo { } export interface ProgramUniform { - type: 'int32'|'float32'|'uint32'; + type: 'int32'|'float16'|'float32'|'uint32'; data: number|readonly number[]; } diff --git a/onnxruntime/core/providers/js/operators/pad.cc b/onnxruntime/core/providers/js/operators/pad.cc index 24ba85cbf6e0d..83fee35481aa6 100644 --- a/onnxruntime/core/providers/js/operators/pad.cc +++ b/onnxruntime/core/providers/js/operators/pad.cc @@ -14,7 +14,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 2, 10, kJsExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Pad); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -24,7 +24,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 17, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -50,7 +50,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 18, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -62,7 +62,7 @@ ONNX_OPERATOR_KERNEL_EX( 19, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), From 358650d4415d930ba3ea4de159b8191cb1696dc4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 25 Jan 2024 17:19:04 -0800 Subject: [PATCH 42/45] Fix BigModel stable diffusion pipeline (#19277) ### Description Fix two issues: (1) We can only use single quote inside `bash -c "..."`. Current pipeline job stopped at `python3 demo_txt2img.py astronaut` and skip the following commands. In this change, we remove the remaining commands to get same effect (otherwise, the pipeline runtime might be 2 hours instead of 15 minutes). (2) Fix a typo of Stable. --- .../github/azure-pipelines/bigmodels-ci-pipeline.yml | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index ff2e7c0468a21..b767b7276b428 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -136,11 +136,11 @@ stages: - template: templates/explicitly-defined-final-tasks.yml -- stage: Stale_Diffusion +- stage: Stable_Diffusion dependsOn: - Build_Onnxruntime_Cuda jobs: - - job: Stale_Diffusion + - job: Stable_Diffusion variables: skipComponentGovernanceDetection: true CCACHE_DIR: $(Pipeline.Workspace)/ccache @@ -171,12 +171,7 @@ stages: python3 -m pip install -r requirements-cuda11.txt; \ python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com; \ echo Generate an image guided by a text prompt; \ - python3 demo_txt2img.py "astronaut riding a horse on mars"; \ - echo Generate an image with Stable Diffusion XL guided by a text prompt; \ - python3 demo_txt2img_xl.py 'starry night over Golden Gate Bridge by van gogh'; \ - python3 demo_txt2img_xl.py --enable-refiner 'starry night over Golden Gate Bridge by van gogh'; \ - echo Generate an image guided by a text prompt using LCM LoRA; \ - python3 demo_txt2img_xl.py --scheduler LCM --lora-weights latent-consistency/lcm-lora-sdxl --denoising-steps 4 "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"; \ + python3 demo_txt2img.py 'astronaut riding a horse on mars'; \ popd; \ " displayName: 'Run stable diffusion demo' From fc44f96ad523526b23d5e6851bd89f888e0de2bc Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 25 Jan 2024 21:55:36 -0800 Subject: [PATCH 43/45] Add support for a collection of OrtValue as inputs and outputs to C# TrainingSession (#19048) --- .../Training/TrainingSession.shared.cs | 107 ++++++++++++++++++ .../TrainingTest.cs | 75 ++++++++++++ 2 files changed, 182 insertions(+) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 877677dcad57b..fec0d46e96dfb 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -282,6 +282,48 @@ public IDisposableReadOnlyCollection TrainStep( } } + /// + /// This function performs a training step that computes the outputs of the training model and the gradients + /// of the trainable parameters for the given OrtValue inputs. The train step is performed based on the training model + /// that was provided to the training session. + /// The TrainStep method is equivalent of running forward propagation and backward propagation in a single + /// step. + /// The gradients computed are stored inside the training session state so they can be later consumed + /// by the OptimizerStep function. + /// The gradients can be lazily reset by invoking the LazyResetGrad function. + /// Example usage: + /// + /// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...); + /// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...); + /// List inputValues = new List { x, label }; + /// using (var loss = trainingSession.TrainStep(inputValues)) + /// { + /// // process output values + /// } + /// + /// + /// Specify a collection of that indicates the input values to the training model. + /// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output. + public IDisposableReadOnlyCollection TrainStep(IReadOnlyCollection inputValues) + { + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues); + IntPtr[] outputValuesArray = new IntPtr[(int)_trainOutputCount]; + + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count, + inputValuesArray, (UIntPtr)_trainOutputCount, outputValuesArray)); + + + var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray); + try + { + return CreateDisposableResult(disposableHandles); + } + finally + { + disposableHandles.Dispose(); + } + } + /// /// Convert native OrtValue handles to OrtValue instances /// in an exceptions safe manner. @@ -370,6 +412,42 @@ public void EvalStep( inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); } + /// + /// This function performs an eval step that computes the outputs of the eval model for the given inputs. + /// Inputs are expected to be of type OrtValue. The eval step is performed based on the eval model that was + /// provided to the training session. + /// Example usage: + /// + /// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...); + /// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...); + /// List inputValues = new List { x, label }; + /// using (var loss = trainingSession.EvalSteps(inputValues)) + /// { + /// // process output values + /// } + /// + /// + /// Specify a collection of that indicates the input values to the eval model. + public IDisposableReadOnlyCollection EvalStep(IReadOnlyCollection inputValues) + { + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues); + IntPtr[] outputValuesArray = new IntPtr[(int)_evalOutputCount]; + + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count, + inputValuesArray, (UIntPtr)_evalOutputCount, outputValuesArray)); + + + var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray); + try + { + return CreateDisposableResult(disposableHandles); + } + finally + { + disposableHandles.Dispose(); + } + } + /// /// Sets the learning rate for this training session. @@ -702,6 +780,35 @@ private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection v return valuesArray; } + private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection inputValues) + { + var valuesArray = new IntPtr[inputValues.Count]; + for (int index = 0; index < inputValues.Count; ++index) + { + valuesArray[index] = inputValues.ElementAt(index).Handle; + } + return valuesArray; + } + + private static IDisposableReadOnlyCollection CreateDisposableResult(DisposableOrtValueHandleArray disposableHandles) + { + var outputValues = new DisposableList(disposableHandles.Span.Length); + try + { + for (int i = 0; i < disposableHandles.Span.Length; i++) + { + outputValues.Add(new OrtValue(disposableHandles.Span[i])); + disposableHandles.Span[i] = IntPtr.Zero; + } + return outputValues; + } + catch (Exception) + { + outputValues.Dispose(); + throw; + } + } + private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection names, DisposableList cleanupList) { cleanupList.Capacity += names.Count; diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index 68b1d5bcc6147..9b72326201322 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -612,6 +612,81 @@ public void TestUpdateParameter() } } + [Fact(DisplayName = "TestTrainingSessionTrainStepWithOrtValues")] + public void TestTrainingSessionTrainStepWithOrtValues() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, optimizerPath); + cleanUp.Add(trainingSession); + + float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out"); + var expectedOutputDimensions = new int[] { 1 }; + float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in"); + long[] inputShape = { 2, 784 }; + Int32[] labelsData = { 1, 1 }; + long[] labelsShape = { 2 }; + + using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape); + using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory(labelsData, labelsShape); + var inputValues = new List { inputOrtValue, labelsOrtValue }; + + using (var results = trainingSession.TrainStep(inputValues)) + { + Assert.Single(results); + var outputOrtValue = results[0]; + Assert.True(outputOrtValue.IsTensor); + var resultSpan = outputOrtValue.GetTensorDataAsSpan().ToArray(); + Assert.Equal(expectedOutput, resultSpan, new FloatComparer()); + } + } + } + + [Fact(DisplayName = "TestTrainingSessionEvalStepWithOrtValues")] + public void TestTrainingSessionEvalStepWithOrtValues() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + using (var cleanUp = new DisposableListTest()) + { + var state = CheckpointState.LoadCheckpoint(checkpointPath); + cleanUp.Add(state); + Assert.NotNull(state); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + + var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); + cleanUp.Add(trainingSession); + + float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out"); + var expectedOutputDimensions = new int[] { 1 }; + float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in"); + long[] inputShape = { 2, 784 }; + Int32[] labelsData = { 1, 1 }; + long[] labelsShape = { 2 }; + + using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape); + using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory(labelsData, labelsShape); + var inputValues = new List { inputOrtValue, labelsOrtValue }; + + using (var results = trainingSession.EvalStep(inputValues)) + { + Assert.Single(results); + var outputOrtValue = results[0]; + Assert.True(outputOrtValue.IsTensor); + var resultSpan = outputOrtValue.GetTensorDataAsSpan().ToArray(); + Assert.Equal(expectedOutput, resultSpan, new FloatComparer()); + } + } + } + internal class FloatComparer : IEqualityComparer { private float atol = 1e-3f; From 7d4dc66846aadb2daf63fea3504aff0c596d1d38 Mon Sep 17 00:00:00 2001 From: cao lei Date: Fri, 26 Jan 2024 07:39:08 -0800 Subject: [PATCH 44/45] ExecutionProvider API refactor - make GenerateMetaDefId a standalone function, decouple it from EP (#18977) ### Description Make EP's member function, GenerateMetaDefId, a standalone function which decouples from EP ### Motivation and Context This change is for ExecutionProvider API refactoring, we will make a clean ExecutionProvider API first for later EPv2 work --- .../core/framework/execution_provider.h | 35 +-------- .../core/framework/execution_provider.cc | 73 ------------------ .../framework/model_metadef_id_generator.cc | 75 +++++++++++++++++++ .../framework/model_metadef_id_generator.h | 31 ++++++++ .../providers/cann/cann_execution_provider.cc | 6 +- .../providers/cann/cann_execution_provider.h | 1 + .../coreml/coreml_execution_provider.cc | 4 +- .../coreml/coreml_execution_provider.h | 2 + .../providers/dnnl/dnnl_execution_provider.cc | 13 ++-- .../providers/dnnl/dnnl_execution_provider.h | 1 + .../providers/js/js_execution_provider.cc | 2 +- .../migraphx/migraphx_execution_provider.cc | 6 +- .../migraphx/migraphx_execution_provider.h | 1 + .../nnapi_builtin/nnapi_execution_provider.cc | 4 +- .../nnapi_builtin/nnapi_execution_provider.h | 2 + .../core/providers/partitioning_utils.h | 2 +- .../providers/qnn/qnn_execution_provider.cc | 4 +- .../providers/qnn/qnn_execution_provider.h | 2 + .../providers/shared_library/provider_api.h | 3 +- .../provider_bridge_provider.cc | 4 - .../shared_library/provider_interfaces.h | 7 +- .../shared_library/provider_wrappedtypes.h | 10 ++- .../tensorrt/tensorrt_execution_provider.cc | 2 +- .../tensorrt_execution_provider_utils.h | 10 ++- .../webnn/webnn_execution_provider.cc | 4 +- .../webnn/webnn_execution_provider.h | 2 + .../xnnpack/xnnpack_execution_provider.cc | 2 +- .../core/session/provider_bridge_ort.cc | 10 ++- .../test/framework/execution_provider_test.cc | 8 +- onnxruntime/test/framework/tunable_op_test.cc | 2 +- .../internal_testing_execution_provider.cc | 4 +- .../internal_testing_execution_provider.h | 2 + 32 files changed, 187 insertions(+), 147 deletions(-) create mode 100644 onnxruntime/core/framework/model_metadef_id_generator.cc create mode 100644 onnxruntime/core/framework/model_metadef_id_generator.h diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 1de0217c7e1fa..31c988f500779 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -59,14 +59,11 @@ enum class DataLayout { class IExecutionProvider { protected: - IExecutionProvider(const std::string& type, bool use_metadef_id_creator = false) - : IExecutionProvider(type, OrtDevice(), use_metadef_id_creator) {} + IExecutionProvider(const std::string& type) + : IExecutionProvider(type, OrtDevice()) {} - IExecutionProvider(const std::string& type, OrtDevice device, bool use_metadef_id_creator = false) + IExecutionProvider(const std::string& type, OrtDevice device) : default_device_(device), type_{type} { - if (use_metadef_id_creator) { - metadef_id_generator_ = std::make_unique(); - } } /* @@ -274,19 +271,6 @@ class IExecutionProvider { return logger_; } - /** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance. - The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models. - @param graph_viewer[in] Graph viewer that GetCapability was called with. Can be for the main graph or nested graph. - @param model_hash[out] Returns the hash for the main (i.e. top level) graph in the model. - This is created using the model path if available, - or the model input names and the output names from all nodes in the main graph. - @remarks e.g. the TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches - compiled kernels, so the name must be unique and deterministic across models and sessions. - NOTE: Ideally this would be a protected method, but to work across the EP bridge it has to be public and - virtual, and ModelMetadefIdGenerator but be defined in the header as well. - */ - virtual int GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const; - virtual std::unique_ptr GetProfiler() { return {}; } @@ -340,18 +324,5 @@ class IExecutionProvider { // It will be set when this object is registered to a session const logging::Logger* logger_ = nullptr; - - // helper to generate ids that are unique to model and deterministic, even if the execution provider is shared across - // multiple sessions. - class ModelMetadefIdGenerator { - public: - int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash); - - private: - std::unordered_map main_graph_hash_; // map graph instance hash to model contents hash - std::unordered_map model_metadef_id_; // current unique id for model - }; - - std::unique_ptr metadef_id_generator_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index 7f8009216ce3a..b39924d4c3ff9 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -35,77 +35,4 @@ common::Status IExecutionProvider::Compile(const std::vector& } #endif - -int IExecutionProvider::ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_viewer, - HashValue& model_hash) { - model_hash = 0; - - // find the top level graph - const Graph* cur_graph = &graph_viewer.GetGraph(); - while (cur_graph->IsSubgraph()) { - cur_graph = cur_graph->ParentGraph(); - } - - uint32_t instance_hash[4] = {0, 0, 0, 0}; - - const Graph& main_graph = *cur_graph; - - // hash the bytes in the Graph instance. we can't just use the address as a new Graph instance may use - // the same memory (unit tests prove this can occur). the raw bytes of the Graph instance should be a unique - // fingerprint for the instance that can use used as the key to the hash of the model path/contents. - MurmurHash3::x86_128(&main_graph, gsl::narrow_cast(sizeof(Graph)), instance_hash[0], &instance_hash); - HashValue graph_instance_hash = instance_hash[0] | (uint64_t(instance_hash[1]) << 32); - - // if we've already hashed this main graph instance use the cached value - auto entry = main_graph_hash_.find(graph_instance_hash); - if (entry != main_graph_hash_.cend()) { - model_hash = entry->second; - } else { - uint32_t hash[4] = {0, 0, 0, 0}; - - // prefer path the model was loaded from - // this may not be available if the model was loaded from a stream or in-memory bytes - const auto& model_path_str = main_graph.ModelPath().ToPathString(); - if (!model_path_str.empty()) { - MurmurHash3::x86_128(model_path_str.data(), gsl::narrow_cast(model_path_str.size()), hash[0], &hash); - } else { - auto hash_str = [&hash](const std::string& str) { - MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); - }; - - // fingerprint the main graph by hashing graph inputs and the ordered outputs from each node - for (const auto* node_arg : main_graph.GetInputsIncludingInitializers()) { - hash_str(node_arg->Name()); - } - - // note: process nodes in order defined in model to be deterministic - for (const auto& node : main_graph.Nodes()) { - for (const auto* node_arg : node.OutputDefs()) { - if (node_arg->Exists()) { - hash_str(node_arg->Name()); - } - } - } - } - - model_hash = hash[0] | (uint64_t(hash[1]) << 32); - - main_graph_hash_[graph_instance_hash] = model_hash; - } - - // return the current unique id, and increment to update - return model_metadef_id_[model_hash]++; -} - -int IExecutionProvider::GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const { - ORT_ENFORCE(metadef_id_generator_, - "IExecutionProvider constructor must be called with true for use_metadef_id_creator"); - - // if the EP is shared across multiple sessions there's a very small potential for concurrency issues. - // use a lock when generating an id to be paranoid - static OrtMutex mutex; - std::lock_guard lock(mutex); - return metadef_id_generator_->GenerateId(graph_viewer, model_hash); -} - } // namespace onnxruntime diff --git a/onnxruntime/core/framework/model_metadef_id_generator.cc b/onnxruntime/core/framework/model_metadef_id_generator.cc new file mode 100644 index 0000000000000..e51c6ebc29975 --- /dev/null +++ b/onnxruntime/core/framework/model_metadef_id_generator.cc @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include +#include "model_metadef_id_generator.h" +#include "core/platform/ort_mutex.h" +#include "core/graph/graph_viewer.h" +#include "core/framework/murmurhash3.h" + +namespace onnxruntime { +int ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_viewer, + HashValue& model_hash) const { + // if the EP is shared across multiple sessions there's a very small potential for concurrency issues. + // use a lock when generating an id to be paranoid + static OrtMutex mutex; + std::lock_guard lock(mutex); + model_hash = 0; + + // find the top level graph + const Graph* cur_graph = &graph_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + + uint32_t instance_hash[4] = {0, 0, 0, 0}; + + const Graph& main_graph = *cur_graph; + + // hash the bytes in the Graph instance. we can't just use the address as a new Graph instance may use + // the same memory (unit tests prove this can occur). the raw bytes of the Graph instance should be a unique + // fingerprint for the instance that can use used as the key to the hash of the model path/contents. + MurmurHash3::x86_128(&main_graph, gsl::narrow_cast(sizeof(Graph)), instance_hash[0], &instance_hash); + HashValue graph_instance_hash = instance_hash[0] | (uint64_t(instance_hash[1]) << 32); + + // if we've already hashed this main graph instance use the cached value + auto entry = main_graph_hash_.find(graph_instance_hash); + if (entry != main_graph_hash_.cend()) { + model_hash = entry->second; + } else { + uint32_t hash[4] = {0, 0, 0, 0}; + + // prefer path the model was loaded from + // this may not be available if the model was loaded from a stream or in-memory bytes + const auto& model_path_str = main_graph.ModelPath().ToPathString(); + if (!model_path_str.empty()) { + MurmurHash3::x86_128(model_path_str.data(), gsl::narrow_cast(model_path_str.size()), hash[0], &hash); + } else { + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + // fingerprint the main graph by hashing graph inputs and the ordered outputs from each node + for (const auto* node_arg : main_graph.GetInputsIncludingInitializers()) { + hash_str(node_arg->Name()); + } + + // note: process nodes in order defined in model to be deterministic + for (const auto& node : main_graph.Nodes()) { + for (const auto* node_arg : node.OutputDefs()) { + if (node_arg->Exists()) { + hash_str(node_arg->Name()); + } + } + } + } + + model_hash = hash[0] | (uint64_t(hash[1]) << 32); + + main_graph_hash_[graph_instance_hash] = model_hash; + } + + // return the current unique id, and increment to update + return model_metadef_id_[model_hash]++; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/model_metadef_id_generator.h b/onnxruntime/core/framework/model_metadef_id_generator.h new file mode 100644 index 0000000000000..82f68c42b5c35 --- /dev/null +++ b/onnxruntime/core/framework/model_metadef_id_generator.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include "core/common/basic_types.h" +namespace onnxruntime { +class GraphViewer; + +/// +/// helper to generate ids that are unique to model and deterministic, even if the execution provider is shared across +/// multiple sessions. +/// +class ModelMetadefIdGenerator { + public: + /** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance. + The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models. + @param graph_viewer[in] Graph viewer that GetCapability was called with. Can be for the main graph or nested graph. + @param model_hash[out] Returns the hash for the main (i.e. top level) graph in the model. + This is created using the model path if available, + or the model input names and the output names from all nodes in the main graph. + */ + int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const; + + private: + // mutable as these are caches so we can minimize the hashing required on each usage of GenerateId + mutable std::unordered_map main_graph_hash_; // map graph instance hash to model contents hash + mutable std::unordered_map model_metadef_id_; // current unique id for model +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 127c37bd84d0f..752b742805a7c 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -9,7 +9,6 @@ #include #include -#include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/cann/cann_execution_provider.h" @@ -1029,13 +1028,14 @@ Status RegisterCANNKernels(KernelRegistry& kernel_registry) { } // namespace cann CANNExecutionProvider::CANNExecutionProvider(const CANNExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kCannExecutionProvider, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, info.device_id), true}, info_{info} { + : IExecutionProvider{onnxruntime::kCannExecutionProvider, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_{info} { InitProviderOrtApi(); CANN_CALL_THROW(aclrtSetDevice(info_.device_id)); soc_name_ = aclrtGetSocName(); ORT_ENFORCE(soc_name_ != nullptr, "aclrtGetSocName return nullptr"); + metadef_id_generator_ = ModelMetadefIdGenerator::Create(); } CANNExecutionProvider::~CANNExecutionProvider() { @@ -1197,7 +1197,7 @@ std::unique_ptr CANNExecutionProvider::GetSubGraph( // Generate unique kernel name for CANN subgraph HashValue model_hash = 0; - int id = GenerateMetaDefId(graph_viewer, model_hash); + int id = metadef_id_generator_->GenerateId(graph_viewer, model_hash); auto meta_def = IndexedSubGraph_MetaDef::Create(); meta_def->name() = graph_viewer.Name() + "_" + std::to_string(model_hash) + "_" + std::to_string(id); diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index 76d3d9c331563..63ae980869c65 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -81,6 +81,7 @@ class CANNExecutionProvider : public IExecutionProvider { std::unordered_map modelIDs_; std::unordered_map models_; std::unordered_map> names_; + std::unique_ptr metadef_id_generator_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index c9973671ffa28..c133f7b82aba4 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -24,7 +24,7 @@ namespace onnxruntime { constexpr const char* COREML = "CoreML"; CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags) - : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider, true}, + : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider}, coreml_flags_(coreml_flags) { } @@ -54,7 +54,7 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const auto gen_metadef_name = [&]() { HashValue model_hash; - int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); return MakeString(COREML, "_", model_hash, "_", metadef_id); }; diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h index 67050e8079cf9..0201739547dd1 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h @@ -4,6 +4,7 @@ #pragma once #include "core/framework/execution_provider.h" +#include "core/framework/model_metadef_id_generator.h" #include "core/providers/coreml/coreml_provider_factory.h" namespace onnxruntime { @@ -34,5 +35,6 @@ class CoreMLExecutionProvider : public IExecutionProvider { #ifdef __APPLE__ std::unordered_map> coreml_models_; #endif + ModelMetadefIdGenerator metadef_id_generator_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index 05eb0091a8c83..3271dab13f675 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -5,8 +5,6 @@ #pragma warning(disable : 4996) #endif -#include "core/providers/dnnl/dnnl_execution_provider.h" - #include #include #include @@ -16,6 +14,7 @@ #include "core/platform/ort_mutex.h" #include "core/providers/shared_library/provider_api.h" +#include "core/providers/dnnl/dnnl_execution_provider.h" #include "core/providers/dnnl/dnnl_fwd.h" #include "core/providers/dnnl/dnnl_node_capability.h" @@ -30,7 +29,7 @@ constexpr const char* DNNL = "Dnnl"; constexpr const char* DNNL_CPU = "DnnlCpu"; DnnlExecutionProvider::DnnlExecutionProvider(const DnnlExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kDnnlExecutionProvider, true}, + : IExecutionProvider{onnxruntime::kDnnlExecutionProvider}, info_(info) { InitProviderOrtApi(); @@ -77,8 +76,8 @@ DnnlExecutionProvider::DnnlExecutionProvider(const DnnlExecutionProviderInfo& in // Log the number of threads used LOGS_DEFAULT(INFO) << "Allocated " << omp_get_max_threads() << " OpenMP threads for oneDNN ep\n"; #endif // defined(DNNL_OPENMP) - -} // namespace onnxruntime + metadef_id_generator_ = ModelMetadefIdGenerator::Create(); +} DnnlExecutionProvider::~DnnlExecutionProvider() { } @@ -229,7 +228,7 @@ std::vector> DnnlExecutionProvider::GetCapabi // Assign inputs and outputs to subgraph's meta_def HashValue model_hash; - int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + int metadef_id = metadef_id_generator_->GenerateId(graph_viewer, model_hash); auto meta_def = ::onnxruntime::IndexedSubGraph_MetaDef::Create(); meta_def->name() = "DNNL_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id); meta_def->domain() = kMSDomain; @@ -264,7 +263,7 @@ std::vector> DnnlExecutionProvider::GetCapabi graph_viewer.ToProto(*model_proto->mutable_graph(), false, true); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); HashValue model_hash; - int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + int metadef_id = metadef_id_generator_->GenerateId(graph_viewer, model_hash); std::fstream dump("DNNL_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id) + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); model_proto->SerializeToOstream(dump); } diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h index 41062ccb4bc1b..b7fcbb7765180 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h @@ -41,6 +41,7 @@ class DnnlExecutionProvider : public IExecutionProvider { bool debug_log_ = false; // enable fusion by default bool enable_fusion_ = true; + std::unique_ptr metadef_id_generator_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index af9658271d210..0448487e6faec 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -682,7 +682,7 @@ std::unique_ptr RegisterKernels() { using namespace js; JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info) - : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true}, + : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)}, preferred_data_layout_{info.data_layout} { } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 8bfa66710e2fc..40e76a0a67782 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -102,7 +102,7 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c } MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id), true}, device_id_(info.device_id) { + : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, device_id_(info.device_id) { InitProviderOrtApi(); // Set GPU device to be used HIP_CALL_THROW(hipSetDevice(device_id_)); @@ -165,6 +165,8 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv MIOPEN_CALL_THROW(miopenCreate(&external_miopen_handle_)); MIOPEN_CALL_THROW(miopenSetStream(external_miopen_handle_, stream_)); + metadef_id_generator_ = ModelMetadefIdGenerator::Create(); + LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " << "device_id: " << device_id_ << ", migraphx_fp16_enable: " << fp16_enable_ @@ -757,7 +759,7 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st // Generate unique kernel name for MIGraphX subgraph uint64_t model_hash = 0; - int id = GenerateMetaDefId(graph, model_hash); + int id = metadef_id_generator_->GenerateId(graph, model_hash); std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(id); auto meta_def = IndexedSubGraph_MetaDef::Create(); const std::string graph_type = graph.IsSubgraph() ? "subgraph" : "graph"; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index c094be51012e4..d582338c7e067 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -98,6 +98,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { AllocatorPtr allocator_; miopenHandle_t external_miopen_handle_ = nullptr; rocblas_handle external_rocblas_handle_ = nullptr; + std::unique_ptr metadef_id_generator_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index 727917ad9232e..b04703d7611ee 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -50,7 +50,7 @@ std::unordered_set GetPartitioningStopOps(const optional& partitioning_stop_ops_list) - : IExecutionProvider{onnxruntime::kNnapiExecutionProvider, true}, + : IExecutionProvider{onnxruntime::kNnapiExecutionProvider}, nnapi_flags_(nnapi_flags), partitioning_stop_ops_(GetPartitioningStopOps(partitioning_stop_ops_list)) { nnapi_handle_ = NnApiImplementation(); @@ -176,7 +176,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view const auto gen_metadef_name = [&]() { HashValue model_hash; - int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); return MakeString(NNAPI, "_", model_hash, "_", metadef_id); }; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h index e4911511e6db0..460616c41991f 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h @@ -6,6 +6,7 @@ #include "core/common/inlined_containers_fwd.h" #include "core/common/optional.h" #include "core/framework/execution_provider.h" +#include "core/framework/model_metadef_id_generator.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h" #include "core/providers/nnapi/nnapi_provider_factory.h" @@ -48,5 +49,6 @@ class NnapiExecutionProvider : public IExecutionProvider { const NnApi* nnapi_handle_ = nullptr; nnapi::DeviceWrapperVector nnapi_target_devices_; nnapi::TargetDeviceOption target_device_option_; + ModelMetadefIdGenerator metadef_id_generator_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/partitioning_utils.h b/onnxruntime/core/providers/partitioning_utils.h index f9d5f7403f17b..136725c2f7250 100644 --- a/onnxruntime/core/providers/partitioning_utils.h +++ b/onnxruntime/core/providers/partitioning_utils.h @@ -40,7 +40,7 @@ using OnGroupClosedFn = std::function& group /** Called to create a metadef name. -Most likely should call IExecutionProvider::GenerateMetaDefId. +Most likely should call ModelMetadefIdGenerator.GenerateId. See onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc for example usage. @return The metadef name. diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 0310cc2bc8f26..5f4e2e62f063e 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -129,7 +129,7 @@ static void ParseHtpArchitecture(const std::string& htp_arch_string, QnnHtpDevic QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map, const SessionOptions* session_options) - : IExecutionProvider{onnxruntime::kQnnExecutionProvider, true} { + : IExecutionProvider{onnxruntime::kQnnExecutionProvider} { if (session_options) { disable_cpu_ep_fallback_ = session_options->config_options.GetConfigOrDefault( kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; @@ -472,7 +472,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer const auto gen_metadef_name = [&]() { uint64_t model_hash; - int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); return MakeString(QNN, "_", model_hash, "_", metadef_id); }; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 3f75be0efebcd..09bcb24db4dc2 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -5,6 +5,7 @@ #include "core/framework/execution_provider.h" #include "core/framework/session_options.h" +#include "core/framework/model_metadef_id_generator.h" #include "core/graph/model.h" #include #include "core/providers/qnn/builder/qnn_backend_manager.h" @@ -71,6 +72,7 @@ class QNNExecutionProvider : public IExecutionProvider { bool qnn_context_embed_mode_ = true; int32_t vtcm_size_in_mb_ = 0; std::unique_ptr qnn_ep_context_model_; + ModelMetadefIdGenerator metadef_id_generator_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 53ba4874c643c..1e3a528d87721 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -142,7 +142,7 @@ struct KernelDefBuilder; struct KernelRegistry; struct Function; struct Graph; -struct GraphViewer; +class GraphViewer; enum class DataLayout; struct Model; struct Path; @@ -157,6 +157,7 @@ struct Tensor; struct SparseTensor; class TensorSeq; class SessionState; +class ModelMetadefIdGenerator; class If; class Loop; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index e1d0e310425c5..6dbe103791e43 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -329,10 +329,6 @@ common::Status IExecutionProvider::Compile(const std::vector& return g_host->IExecutionProvider__Compile(this, fused_nodes_and_graphs, node_compute_funcs); } -int IExecutionProvider::GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const { - return g_host->IExecutionProvider__GenerateMetaDefId(this, graph_viewer, model_hash); -} - #ifdef USE_TENSORRT std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) { return g_host->CreateCUDAAllocator(device_id, name); diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 21c14ce784a38..a216b2bfc6d04 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -229,8 +229,6 @@ struct ProviderHost { virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) = 0; - virtual int IExecutionProvider__GenerateMetaDefId(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) = 0; - // Status virtual std::string Status__ToString(const Status* p) = 0; @@ -972,6 +970,11 @@ struct ProviderHost { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) virtual Status LoadDynamicLibrary(onnxruntime::PathString library_name) = 0; #endif + + // ModelMetadefIdGenerator + virtual std::unique_ptr ModelMetadefIdGenerator__construct() = 0; + virtual void ModelMetadefIdGenerator__operator_delete(ModelMetadefIdGenerator* p) = 0; + virtual int ModelMetadefIdGenerator__GenerateId(const ModelMetadefIdGenerator* p, const GraphViewer& graph_viewer, HashValue& model_hash) = 0; }; #if defined(_MSC_VER) && !defined(__clang__) diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index eaf8ef459cf00..f46c76fd3421b 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -750,7 +750,8 @@ struct Graph final { PROVIDER_DISALLOW_ALL(Graph) }; -struct GraphViewer final { +class GraphViewer final { + public: static void operator delete(void* p) { g_host->GraphViewer__operator_delete(reinterpret_cast(p)); } std::unique_ptr CreateModel(const logging::Logger& logger) const { return g_host->GraphViewer__CreateModel(this, logger); } @@ -1152,6 +1153,13 @@ class TensorSeq final { void Reserve(size_t capacity) { g_host->TensorSeq__Reserve(this, capacity); } }; +class ModelMetadefIdGenerator { + public: + static std::unique_ptr Create() { return g_host->ModelMetadefIdGenerator__construct(); } + static void operator delete(void* p) { g_host->ModelMetadefIdGenerator__operator_delete(reinterpret_cast(p)); } + int GenerateId(const GraphViewer& graph_viewer, HashValue& model_hash) const { return g_host->ModelMetadefIdGenerator__GenerateId(this, graph_viewer, model_hash); } +}; + template <> inline gsl::span Tensor::DataAsSpan() const { return g_host->Tensor__DataAsSpan_int64(this); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 39e5f5be000e5..cdc28846bd12c 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1310,7 +1310,7 @@ TensorrtExecutionProvider::PerThreadContext& TensorrtExecutionProvider::GetPerTh } TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id), true}, info_(info), device_id_(info.device_id) { + : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info), device_id_(info.device_id) { InitProviderOrtApi(); CUDA_CALL_THROW(cudaSetDevice(device_id_)); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h index a8e3ae3ddf6ec..92cce0c203927 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h @@ -497,7 +497,15 @@ void RemoveCachesByType(const std::string& root, std::string file_extension) { } } -// Helper class to generate engine id via model name/model content/env metadata +/** + * + * Helper class to generate engine id via model name/model content/env metadata + * + * + * The TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches + * compiled kernels, so the name must be unique and deterministic across models and sessions. + * + */ HashValue TRTGenerateId(const GraphViewer& graph_viewer) { HashValue model_hash = 0; diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index df7871614b267..cfb96af557d35 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -19,7 +19,7 @@ namespace onnxruntime { WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_threads_number, const std::string& webnn_power_flags) - : IExecutionProvider{onnxruntime::kWebNNExecutionProvider, true} { + : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { // Create WebNN context and graph builder. const emscripten::val ml = emscripten::val::global("navigator")["ml"]; if (!ml.as()) { @@ -169,7 +169,7 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view // Assign inputs and outputs to subgraph's meta_def. uint64_t model_hash; - int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); meta_def->name = "WEBNN_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id); meta_def->domain = kMSDomain; diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h index 13a475327dc0c..d9cfa5f17c0d4 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -6,6 +6,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/execution_provider.h" +#include "core/framework/model_metadef_id_generator.h" #include "core/providers/webnn/builders/helper.h" #include @@ -48,5 +49,6 @@ class WebNNExecutionProvider : public IExecutionProvider { DataLayout preferred_layout_; webnn::WebnnDeviceType wnn_device_type_; InlinedHashMap> models_; + ModelMetadefIdGenerator metadef_id_generator_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index a2a776df439e4..eafbfae6f01e1 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -155,7 +155,7 @@ std::unique_ptr RegisterKernels() { using namespace xnnpack; XnnpackExecutionProvider::XnnpackExecutionProvider(const XnnpackExecutionProviderInfo& info) - : IExecutionProvider{kXnnpackExecutionProvider, true} { + : IExecutionProvider{kXnnpackExecutionProvider} { int xnn_thread_pool_size = info.xnn_thread_pool_size; int ort_thread_pool_size = info.session_options ? info.session_options->intra_op_param.thread_pool_size : 1; bool allow_intra_op_spinning = (info.session_options == nullptr) || diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index f48110aa7ee5b..2e445e4982d24 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -30,6 +30,7 @@ #include "core/framework/sparse_utils.h" #include "core/graph/graph_proto_serializer.h" #include "core/framework/murmurhash3.h" +#include "core/framework/model_metadef_id_generator.h" #include "core/session/onnxruntime_c_api.h" #include "core/common/string_helper.h" @@ -317,10 +318,6 @@ struct ProviderHostImpl : ProviderHost { return p->IExecutionProvider::Compile(fused_nodes_and_graphs, node_compute_funcs); } - int IExecutionProvider__GenerateMetaDefId(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) override { - return p->IExecutionProvider::GenerateMetaDefId(graph_viewer, model_hash); - } - // Status (direct) std::string Status__ToString(const Status* p) override { return p->Status::ToString(); } @@ -1083,6 +1080,11 @@ struct ProviderHostImpl : ProviderHost { void TensorSeq__Add(TensorSeq* p, Tensor&& tensor) override { p->Add(std::move(tensor)); } void TensorSeq__Reserve(TensorSeq* p, size_t capacity) override { p->Reserve(capacity); } + // ModelMetadefIdGenerator(wrapped) + std::unique_ptr ModelMetadefIdGenerator__construct() override { return std::make_unique(); } + void ModelMetadefIdGenerator__operator_delete(ModelMetadefIdGenerator* p) override { delete p; } + int ModelMetadefIdGenerator__GenerateId(const ModelMetadefIdGenerator* p, const GraphViewer& graph_viewer, HashValue& model_hash) override { return p->GenerateId(graph_viewer, model_hash); } + #if defined(ENABLE_TRAINING) && defined(ORT_USE_NCCL) training::DistributedRunContext& GetDistributedRunContextInstance() override { return training::DistributedRunContext::GetInstance(); } #endif diff --git a/onnxruntime/test/framework/execution_provider_test.cc b/onnxruntime/test/framework/execution_provider_test.cc index 5a7351a766fa3..390fda7bfc5ad 100644 --- a/onnxruntime/test/framework/execution_provider_test.cc +++ b/onnxruntime/test/framework/execution_provider_test.cc @@ -6,6 +6,7 @@ #include "test_utils.h" #include "test/test_environment.h" #include "test/util/include/asserts.h" +#include "core/framework/model_metadef_id_generator.h" #include "gtest/gtest.h" @@ -18,11 +19,14 @@ class TestEP : public IExecutionProvider { static constexpr const char* kEPType = "TestEP"; public: - TestEP() : IExecutionProvider{kEPType, true} {} + TestEP() : IExecutionProvider{kEPType} {} int GetId(const GraphViewer& viewer, HashValue& model_hash) { - return GenerateMetaDefId(viewer, model_hash); + return metadef_id_generator_.GenerateId(viewer, model_hash); } + + private: + ModelMetadefIdGenerator metadef_id_generator_; }; TEST(ExecutionProviderTest, MetadefIdGeneratorUsingModelPath) { diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index 19253e1a5bd2c..6fe0754db40d3 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -82,7 +82,7 @@ class TestEP : public IExecutionProvider { TestTuningContext tuning_ctx_{this}; public: - TestEP() : IExecutionProvider{kEPType, true} {} + TestEP() : IExecutionProvider{kEPType} {} ITuningContext* GetTuningContext() const override { return const_cast(&tuning_ctx_); diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index 957443c23e7c3..0167f7a7718b1 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -85,7 +85,7 @@ constexpr const char* INTERNAL_TESTING_EP = "InternalTestingEP"; InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::unordered_set& ops, const std::unordered_set& stop_ops, DataLayout preferred_layout) - : IExecutionProvider{utils::kInternalTestingExecutionProvider, true}, + : IExecutionProvider{utils::kInternalTestingExecutionProvider}, ep_name_{INTERNAL_TESTING_EP}, ops_{ops}, stop_ops_{stop_ops}, @@ -212,7 +212,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& // create functor to generate a guaranteed unique metadef id auto generate_metadef_name = [this, &graph_viewer]() { HashValue model_hash; - int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); return ep_name_ + "_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id); }; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h index 6103352627667..6615eb82f2b05 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h @@ -4,6 +4,7 @@ #pragma once #include #include "core/framework/execution_provider.h" +#include "core/framework/model_metadef_id_generator.h" namespace onnxruntime { namespace internal_testing_ep { @@ -82,6 +83,7 @@ class InternalTestingExecutionProvider : public IExecutionProvider { // per-instance kernel registry so tests using static kernels don't clash. // shared_ptr as required by IExecutionProvider::GetKernelRegistry std::shared_ptr kernel_registry_; + ModelMetadefIdGenerator metadef_id_generator_; }; } // namespace internal_testing_ep From d7ff81dfb77989a8ce975db29457e5cdfc00f9e3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 26 Jan 2024 10:34:43 -0800 Subject: [PATCH 45/45] [CUDA] support user_compute_stream in python API (#19229) ### Description It is an important feature to pass user cuda stream to avoid synchronization in python API. Here we allow user to pass cuda stream for CUDA provider. Note that TRT or ROCm provider need similar change, which are not included in this pull request. Note that we will set `has_user_compute_stream` automatically based on whether there is cuda stream passed, so setting `has_user_compute_stream` through python API has no effect. ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/19094 --- .../cuda/cuda_execution_provider_info.cc | 16 ++++++++++++++++ .../test/python/onnxruntime_test_python.py | 19 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index daa3b5ff3d72f..7b507296d5982 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -16,6 +16,7 @@ namespace cuda { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; +constexpr const char* kUserComputeStream = "user_compute_stream"; constexpr const char* kMemLimit = "gpu_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search"; @@ -51,6 +52,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P void* alloc = nullptr; void* free = nullptr; void* empty_cache = nullptr; + void* user_compute_stream = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -66,6 +68,14 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P return Status::OK(); }) .AddAssignmentToReference(cuda::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) + .AddValueParser( + cuda::provider_option_names::kUserComputeStream, + [&user_compute_stream](const std::string& value_str) -> Status { + size_t address; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + user_compute_stream = reinterpret_cast(address); + return Status::OK(); + }) .AddValueParser( cuda::provider_option_names::kGpuExternalAlloc, [&alloc](const std::string& value_str) -> Status { @@ -126,6 +136,10 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P CUDAExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; info.external_allocator_info = alloc_info; + + info.user_compute_stream = user_compute_stream; + info.has_user_compute_stream = (user_compute_stream != nullptr); + return info; } @@ -133,6 +147,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution const ProviderOptions options{ {cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, + {cuda::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, {cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {cuda::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, {cuda::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, @@ -160,6 +175,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid const ProviderOptions options{ {cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, + {cuda::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, {cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {cuda::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, {cuda::provider_option_names::kCudnnConvAlgoSearch, EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)}, diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 8c23286e45445..e210917e7ad9a 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -434,6 +434,25 @@ def test_get_and_set_option_with_values(option_name, option_values): self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_alloc"], "0") self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_free"], "0") self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_empty_cache"], "0") + + option["user_compute_stream"] = "0" + sess.set_providers(["CUDAExecutionProvider"], [option]) + options = sess.get_provider_options() + self.assertEqual(options["CUDAExecutionProvider"]["user_compute_stream"], "0") + + try: + import torch + + if torch.cuda.is_available(): + s = torch.cuda.Stream() + option["user_compute_stream"] = str(s.cuda_stream) + sess.set_providers(["CUDAExecutionProvider"], [option]) + options = sess.get_provider_options() + self.assertEqual(options["CUDAExecutionProvider"]["user_compute_stream"], str(s.cuda_stream)) + self.assertEqual(options["CUDAExecutionProvider"]["has_user_compute_stream"], "1") + except ImportError: + print("torch is not installed, skip testing setting user_compute_stream from torch cuda stream") + # # Note: Tests that throw an exception leave an empty session due to how set_providers currently works, # so run them last. Each set_providers call will attempt to re-create a session, so it's