diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index 6c3f2eb0fbbe1..725c40c2ded53 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -24,7 +24,7 @@ jobs: name: Download prebuilt ONNX Runtime archive from build.rs runs-on: ubuntu-latest env: - ORT_RUST_STRATEGY=download + ORT_RUST_STRATEGY: download steps: - uses: actions/checkout@v4 - uses: ./.github/actions/rust-toolchain-setup diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 5a016717f7d1e..137ea8a50c011 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -36,7 +36,7 @@ "component": { "type": "git", "git": { - "commitHash": "3abf3298b6b43acc8556b1342ffb6de4a85fb30f", + "commitHash": "dcd5bd5fd593e31465af3d9ef291d26c646b0a4f", "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" }, "comments": "abseil_cpp" @@ -126,7 +126,7 @@ "component": { "type": "git", "git": { - "commitHash": "b3a9ba2b8e975550799838332803d468797ae2e1", + "commitHash": "530d5c8c84abd2a46f38583ee817743c9b3a42b4", "repositoryUrl": "https://github.com/google/googletest.git" }, "comments": "googletest" diff --git a/cmake/deps.txt b/cmake/deps.txt index 8a9ccef6f8181..ff07803013071 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -12,7 +12,7 @@ # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI. # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # -abseil_cpp;https://github.com/abseil/abseil-cpp/archive/3abf3298b6b43acc8556b1342ffb6de4a85fb30f.zip;d6da50a47c1268b5d6d5405b7fc21258ccd84d31 +abseil_cpp;https://github.com/abseil/abseil-cpp/archive/dcd5bd5fd593e31465af3d9ef291d26c646b0a4f.zip;6cc204586014e189f5c0fe3274f83162fa7c700c cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 @@ -27,7 +27,7 @@ fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908 google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 -googletest;https://github.com/google/googletest/archive/b3a9ba2b8e975550799838332803d468797ae2e1.zip;0ac421f2ec11af38b0fff0f1992184032731a8bc +googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034 googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c 
microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 0fa5163dc06bf..78f63227c8392 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -47,8 +47,8 @@ if (onnxruntime_BUILD_UNIT_TESTS) FetchContent_Declare( googletest URL ${DEP_URL_googletest} - FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest URL_HASH SHA1=${DEP_SHA1_googletest} + FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest ) endif() @@ -124,7 +124,7 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) if(protoc_binary_SOURCE_DIR) message("Use prebuilt protoc") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") @@ -140,7 +140,7 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) if(protoc_binary_SOURCE_DIR) message("Use prebuilt protoc") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() elseif ((CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin") FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal}) @@ -281,7 +281,7 @@ if ((CPUINFO_SUPPORTED OR onnxruntime_USE_XNNPACK) AND NOT ANDROID) pytorch_clog URL ${DEP_URL_pytorch_cpuinfo} URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} - SOURCE_SUBDIR deps/clog + SOURCE_SUBDIR deps/clog ) set(ONNXRUNTIME_CLOG_PROJ pytorch_clog) set(ONNXRUNTIME_CLOG_TARGET_NAME clog) diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake index 7ac4a82c89a76..0951c2d02664d 100644 --- a/cmake/onnxruntime_providers_vitisai.cmake +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -15,16 +15,10 @@ "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" ) - list(REMOVE_ITEM onnxruntime_providers_vitisai_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc") source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs}) onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto) - onnxruntime_add_shared_library(onnxruntime_vitisai_ep ${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc) - onnxruntime_add_include_to_target(onnxruntime_vitisai_ep onnxruntime_common) - target_include_directories(onnxruntime_vitisai_ep PRIVATE "${ONNXRUNTIME_ROOT}" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include") - target_link_libraries(onnxruntime_providers_vitisai PUBLIC onnxruntime_vitisai_ep PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json ) - target_compile_definitions(onnxruntime_vitisai_ep - PRIVATE "-DONNXRUNTIME_VITISAI_EP_STUB=1" "-DONNXRUNTIME_VITISAI_EP_EXPORT_DLL=1") + target_link_libraries(onnxruntime_providers_vitisai PRIVATE onnx 
protobuf::libprotobuf nlohmann_json::nlohmann_json) if(NOT MSVC) target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) endif(NOT MSVC) @@ -49,4 +43,4 @@ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index df62199dc2b42..7c8c70f913dca 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1373,56 +1373,55 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(compare_two_sessions PRIVATE ${GETOPT_LIB_WIDE} tdh Advapi32) endif() - file(GLOB onnxruntime_mlas_test_src CONFIGURE_DEPENDS - "${TEST_SRC_DIR}/mlas/unittest/*.h" - "${TEST_SRC_DIR}/mlas/unittest/*.cpp" - ) - onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src}) - if(MSVC) - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26409>" - "$<$>:/wd26409>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" - "$<$>:/utf-8>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd6326>" - "$<$>:/wd6326>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" - "$<$>:/wd26426>") - endif() - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set_target_properties(onnxruntime_mlas_test PROPERTIES - XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" + if(NOT onnxruntime_target_platform STREQUAL "ARM64EC") + file(GLOB onnxruntime_mlas_test_src CONFIGURE_DEPENDS + "${TEST_SRC_DIR}/mlas/unittest/*.h" + "${TEST_SRC_DIR}/mlas/unittest/*.cpp" ) - endif() - target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT} - ${CMAKE_CURRENT_BINARY_DIR}) - target_link_libraries(onnxruntime_mlas_test PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) - if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo) - endif() - if(NOT WIN32) - target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS}) - endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") - target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs}) - endif() - - if(WIN32) - target_link_libraries(onnxruntime_mlas_test PRIVATE debug Dbghelp Advapi32) - endif() - if (onnxruntime_LINK_LIBATOMIC) - target_link_libraries(onnxruntime_mlas_test PRIVATE atomic) - endif() - target_link_libraries(onnxruntime_mlas_test PRIVATE Threads::Threads) - - set_target_properties(onnxruntime_mlas_test PROPERTIES FOLDER "ONNXRuntimeTest") - if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") - else() - set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") + onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src}) + if(MSVC) + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26409>" + "$<$>:/wd26409>") + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" + "$<$>:/utf-8>") + target_compile_options(onnxruntime_mlas_test PRIVATE 
"$<$:SHELL:--compiler-options /wd6326>" + "$<$>:/wd6326>") + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" + "$<$>:/wd26426>") endif() - endif() - + if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + set_target_properties(onnxruntime_mlas_test PROPERTIES + XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" + ) + endif() + target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT} + ${CMAKE_CURRENT_BINARY_DIR}) + target_link_libraries(onnxruntime_mlas_test PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) + if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo) + endif() + if(NOT WIN32) + target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS}) + endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") + target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs}) + endif() + if(WIN32) + target_link_libraries(onnxruntime_mlas_test PRIVATE debug Dbghelp Advapi32) + endif() + if (onnxruntime_LINK_LIBATOMIC) + target_link_libraries(onnxruntime_mlas_test PRIVATE atomic) + endif() + target_link_libraries(onnxruntime_mlas_test PRIVATE Threads::Threads) + set_target_properties(onnxruntime_mlas_test PROPERTIES FOLDER "ONNXRuntimeTest") + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") + else() + set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") + endif() + endif() +endif() # Training API Tests # Disabling training_api_test_trainer. CXXOPT generates a ton of warnings because of which nuget pipeline is failing. # TODO(askhade): Fix the warnings. diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index adf0b1b2964b5..ae5bf68483b46 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -115,7 +115,7 @@ export class ProgramManager { inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; }); let outputShapes = ''; - inputTensorViews.forEach((value, i) => { + outputTensorViews.forEach((value, i) => { outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; }); // eslint-disable-next-line no-console diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 59bdd43ec997e..b629c8eff9097 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -2,6 +2,10 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. 
#include "vaip/global_api.h" + +#include +#include + #include "./vai_assert.h" #include "core/common/exceptions.h" #include "core/common/logging/logging.h" @@ -10,10 +14,10 @@ #include "core/graph/model.h" #include "core/session/ort_env.h" +#include "core/session/onnxruntime_cxx_api.h" -#include +#include -#include "core/session/onnxruntime_cxx_api.h" #include "vaip/dll_safe.h" #include "vaip/vaip_ort_api.h" #include "vaip/graph.h" @@ -24,28 +28,107 @@ #include "./attr_proto.h" #include "./register_xir_ops.h" -#include "onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h" - #include "onnxruntime_config.h" #include "version_info.h" // version_info.hpp.in using namespace onnxruntime; +using json = nlohmann::json; + +// The filename extension for a shared library is different per platform +#ifdef _WIN32 +#define LIBRARY_PREFIX +#define LIBRARY_EXTENSION ORT_TSTR(".dll") +#elif defined(__APPLE__) +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".dylib" +#else +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".so" +#endif + vaip_core::OrtApiForVaip* create_org_api_hook(); +struct OrtVitisAIEpAPI { + void (*initialize_onnxruntime_vitisai_ep)(vaip_core::OrtApiForVaip* api, std::vector& ret_domain); + std::vector>* (*compile_onnx_model_3)(const std::string& model_path, + const onnxruntime::Graph& graph, + const char* json_config); + std::vector>* (*compile_onnx_model_with_options)( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); + void Ensure() { + if (handle_) return; + auto full_path = Env::Default().GetRuntimePath() + + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); + ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, true, &handle_)); + ORT_THROW_IF_ERROR(Env::Default().GetSymbolFromLibrary( + handle_, "initialize_onnxruntime_vitisai_ep", reinterpret_cast(&initialize_onnxruntime_vitisai_ep))); + auto status1 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", + reinterpret_cast(&compile_onnx_model_with_options)); + auto status2 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", + reinterpret_cast(&compile_onnx_model_3)); + if (!status1.IsOK() && !status2.IsOK()) { + ::onnxruntime::LogRuntimeError(0, status1, __FILE__, static_cast(__FUNCTION__), __LINE__); + ORT_THROW(status1); + } + } + + private: + void* handle_{}; +}; + +static OrtVitisAIEpAPI s_library_vitisaiep; +static std::string config_to_json_str(const onnxruntime::ProviderOptions& config) { + auto iter = config.find("config_file"); + if (iter == config.end()) { + std::cerr << "Error: Key 'config_file' not found in config" << std::endl; + return ""; + } + const auto& filename = config.at("config_file"); + std::ifstream f(filename); + if (!f.is_open()) { + std::cerr << "Error: Failed to open file: " << filename << std::endl; + return ""; + } + nlohmann::json data; + try { + data = nlohmann::json::parse(f); + } catch (const std::exception& e) { + std::cerr << "Error: Failed to parse JSON from file: " << filename << ", Reason: " << e.what() << std::endl; + return ""; + } + for (const auto& entry : config) { + data[entry.first] = entry.second; + } + try { + return data.dump(); + } catch (const std::exception& e) { + std::cerr << "Error: Failed to convert JSON data to string, Reason: " << e.what() << std::endl; + return ""; + } +} +vaip_core::DllSafe>> compile_onnx_model_with_options( + const std::string& model_path, const onnxruntime::Graph& graph, 
const onnxruntime::ProviderOptions& options) { + if (s_library_vitisaiep.compile_onnx_model_with_options) { + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph, options)); + } else { + auto json_str = config_to_json_str(options); + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_3(model_path, graph, json_str.c_str())); + } +} std::vector initialize_vitisai_ep() { + s_library_vitisaiep.Ensure(); Status status = Status::OK(); try { - OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, "onnxruntime-vitisai-ep"}; + OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, + "onnxruntime-vitisai-ep"}; std::ignore = OrtEnv::GetInstance(lm_info, status); } catch (onnxruntime::OnnxRuntimeException& /*e*/) { } auto domains = std::vector(); domains.reserve(100); - onnxruntime_vitisai_ep::initialize_onnxruntime_vitisai_ep(create_org_api_hook(), domains); - auto& domainToVersionRangeInstance = - ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); - if (domainToVersionRangeInstance.Map().find("com.xilinx") == - domainToVersionRangeInstance.Map().end()) { + s_library_vitisaiep.initialize_onnxruntime_vitisai_ep(create_org_api_hook(), domains); + auto& domainToVersionRangeInstance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); + if (domainToVersionRangeInstance.Map().find("com.xilinx") == domainToVersionRangeInstance.Map().end()) { vaip::register_xir_ops(domains); } @@ -68,17 +151,14 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.model_delete = [](Model* model) { delete model; }; the_global_api.model_clone = [](const Model& model) -> Model* { auto& logger = logging::LoggingManager::DefaultLogger(); - auto model_proto = - const_cast(model).ToProto(); + auto model_proto = const_cast(model).ToProto(); auto file_path = model.ModelPath().ToPathString(); auto ret = std::make_unique(std::move(model_proto), file_path, nullptr, logger); auto status = ret->MainGraph().Resolve(); vai_assert(status.IsOK(), status.ErrorMessage()); return ret.release(); }; - the_global_api.model_set_meta_data = [](Model& model, const std::string& key, - const std::string& value) - -> void { + the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) -> void { const_cast(model.MetaData())[key] = value; }; the_global_api.model_get_meta_data = [](const Model& model, @@ -97,14 +177,9 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return m.find(key) != m.end() ? 
1 : 0; }; - the_global_api.model_main_graph = [](Model& model) -> Graph& { - return model.MainGraph(); - }; - the_global_api.graph_get_model = [](const Graph& graph) -> const Model& { - return graph.GetModel(); - }; - the_global_api.graph_get_inputs_unsafe = - [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.model_main_graph = [](Model& model) -> Graph& { return model.MainGraph(); }; + the_global_api.graph_get_model = [](const Graph& graph) -> const Model& { return graph.GetModel(); }; + the_global_api.graph_get_inputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { auto ret = std::vector(); auto inputs = graph.GetInputs(); for (auto input : inputs) { @@ -113,47 +188,35 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { } return vaip_core::DllSafe(std::move(ret)); }; - the_global_api.graph_get_outputs_unsafe = - [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.graph_get_outputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { return vaip_core::DllSafe(graph.GetOutputs()); }; - the_global_api.graph_set_outputs = - [](Graph& graph, gsl::span outputs) -> void { + the_global_api.graph_set_outputs = [](Graph& graph, gsl::span outputs) -> void { return graph.SetOutputs(outputs); }; - the_global_api.graph_get_node_arg = - [](const Graph& graph, const std::string& name) -> const NodeArg* { + the_global_api.graph_get_node_arg = [](const Graph& graph, const std::string& name) -> const NodeArg* { return graph.GetNodeArg(name); }; the_global_api.graph_producer_node = [](const Graph& graph, const std::string& name) -> const Node* { return graph.GetProducerNode(name); }; - the_global_api.graph_get_node = [](const Graph& graph, - size_t index) -> const Node* { - return graph.GetNode(index); - }; + the_global_api.graph_get_node = [](const Graph& graph, size_t index) -> const Node* { return graph.GetNode(index); }; the_global_api.graph_save = vaip::graph_save; the_global_api.graph_fuse = vaip::graph_fuse; the_global_api.graph_remove_node = vaip::graph_remove_node; - the_global_api.graph_add_node = - [](Graph& graph, const std::string& name, const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - vaip_core::NodeAttributes& attributes, - const std::string& domain) -> Node& { - return vaip::graph_add_node( - graph, name, op_type, description, input_args, output_args, - std::move(reinterpret_cast(attributes)), - domain); - }; - - the_global_api.graph_get_all_initialized_tensors = - [](const Graph& graph) -> const InitializedTensorSet& { + the_global_api.graph_add_node = [](Graph& graph, const std::string& name, const std::string& op_type, + const std::string& description, const std::vector& input_args, + const std::vector& output_args, + vaip_core::NodeAttributes& attributes, const std::string& domain) -> Node& { + return vaip::graph_add_node(graph, name, op_type, description, input_args, output_args, + std::move(reinterpret_cast(attributes)), domain); + }; + + the_global_api.graph_get_all_initialized_tensors = [](const Graph& graph) -> const InitializedTensorSet& { return graph.GetAllInitializedTensors(); }; @@ -166,66 +229,46 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { }; the_global_api.graph_get_consumer_nodes_unsafe = - [](const Graph& graph, - const std::string& node_arg_name) -> vaip_core::DllSafe> { + [](const Graph& graph, const std::string& node_arg_name) -> vaip_core::DllSafe> { return vaip_core::DllSafe(graph.GetConsumerNodes(node_arg_name)); }; - 
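// For context on the OrtVitisAIEpAPI struct introduced above: instead of linking against the old
// onnxruntime_vitisai_ep stub, the provider now lazily opens the shared library
// (lib)onnxruntime_vitisai_ep(.dll/.so/.dylib) at first use, resolves the mandatory
// initialize_onnxruntime_vitisai_ep symbol, and prefers the newer
// compile_onnx_model_vitisai_ep_with_options entry point, falling back to
// compile_onnx_model_vitisai_ep. A minimal standalone sketch of that load-and-resolve-with-fallback
// pattern, using POSIX dlfcn as a stand-in for Env::Default().LoadDynamicLibrary /
// GetSymbolFromLibrary; the LazyPlugin type and plugin_* symbol names are illustrative only.
#include <dlfcn.h>
#include <stdexcept>
#include <string>

struct LazyPlugin {
  using InitFn = void (*)();
  using CompileV2Fn = int (*)(const char* model_path, const char* options_json);
  using CompileV1Fn = int (*)(const char* model_path, const char* json_config);

  void Ensure(const std::string& path) {
    if (handle_) return;  // already loaded, nothing to do
    handle_ = ::dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
    if (!handle_) throw std::runtime_error(::dlerror());
    init_ = reinterpret_cast<InitFn>(::dlsym(handle_, "plugin_initialize"));
    if (!init_) throw std::runtime_error("mandatory init symbol missing");
    // Optional entry points: prefer the newer one, keep the older one as a fallback.
    compile_v2_ = reinterpret_cast<CompileV2Fn>(::dlsym(handle_, "plugin_compile_with_options"));
    compile_v1_ = reinterpret_cast<CompileV1Fn>(::dlsym(handle_, "plugin_compile"));
    if (!compile_v2_ && !compile_v1_) throw std::runtime_error("no compile entry point found");
  }

  int Compile(const char* model, const char* options_json) const {
    return compile_v2_ ? compile_v2_(model, options_json) : compile_v1_(model, options_json);
  }

  void* handle_ = nullptr;
  InitFn init_ = nullptr;
  CompileV2Fn compile_v2_ = nullptr;
  CompileV1Fn compile_v1_ = nullptr;
};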
the_global_api.graph_nodes_unsafe = - [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { auto& node_refererence = graph.Nodes(); - std::vector nodes((size_t)graph.NumberOfNodes(), nullptr); - std::transform(node_refererence.begin(), node_refererence.end(), - nodes.begin(), [](const Node& n) { return &n; }); + std::vector nodes(static_cast(graph.NumberOfNodes()), nullptr); + std::transform(node_refererence.begin(), node_refererence.end(), nodes.begin(), [](const Node& n) { return &n; }); return vaip_core::DllSafe(std::move(nodes)); }; - the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { - return graph.Name(); + the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { return graph.Name(); }; + the_global_api.graph_reverse_dfs_from = [](const Graph& graph, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& stop) { + graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); }; - the_global_api.graph_reverse_dfs_from = - [](const Graph& graph, gsl::span from, - const std::function& enter, - const std::function& leave, - const std::function& stop) { - graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); - }; // node the_global_api.node_get_inputs_unsafe = vaip::node_get_inputs; the_global_api.node_get_output_node_args_unsafe = vaip::node_get_output_node_args; - the_global_api.node_op_type = [](const Node& node) -> const std::string& { - return node.OpType(); - }; - the_global_api.node_op_domain = [](const Node& node) -> const std::string& { - return node.Domain(); - }; - the_global_api.node_get_index = [](const Node& node) -> size_t { - return (size_t)node.Index(); - }; - the_global_api.node_get_name = [](const Node& node) -> const std::string& { - return node.Name(); - }; - the_global_api.node_description = [](const Node& node) -> const std::string& { - return node.Description(); - }; + the_global_api.node_op_type = [](const Node& node) -> const std::string& { return node.OpType(); }; + the_global_api.node_op_domain = [](const Node& node) -> const std::string& { return node.Domain(); }; + the_global_api.node_get_index = [](const Node& node) -> size_t { return static_cast(node.Index()); }; + the_global_api.node_get_name = [](const Node& node) -> const std::string& { return node.Name(); }; + the_global_api.node_description = [](const Node& node) -> const std::string& { return node.Description(); }; - the_global_api.node_get_attributes = - [](Node& node) -> vaip_core::NodeAttributes& { - return reinterpret_cast( - node.GetMutableAttributes()); + the_global_api.node_get_attributes = [](Node& node) -> vaip_core::NodeAttributes& { + return reinterpret_cast(node.GetMutableAttributes()); }; the_global_api.node_type_is_fused = [](const Node& node) { return node.NodeType() == onnxruntime::Node::Type::Fused; }; - the_global_api.node_get_function_body = - [](const Node& node) -> const onnxruntime::Graph& { + the_global_api.node_get_function_body = [](const Node& node) -> const onnxruntime::Graph& { assert(node.GetFunctionBody() != nullptr); return node.GetFunctionBody()->Body(); }; // node_arg - the_global_api.node_arg_get_name_unsafe = - [](const NodeArg& node_arg) -> const std::string& { + the_global_api.node_arg_get_name_unsafe = [](const NodeArg& node_arg) -> const std::string& { return node_arg.Name(); }; the_global_api.node_arg_clone = vaip::node_arg_clone; @@ -236,8 +279,7 @@ vaip_core::OrtApiForVaip* 
create_org_api_hook() { the_global_api.node_arg_set_shape_i64 = vaip::node_arg_set_shape_i64; the_global_api.node_arg_get_denotation_unsafe = vaip::node_arg_get_denotation; the_global_api.node_arg_set_denotation = vaip::node_arg_set_denotation; - the_global_api.node_arg_get_const_data_as_tensor = - vaip::node_arg_get_const_data_as_tensor; + the_global_api.node_arg_get_const_data_as_tensor = vaip::node_arg_get_const_data_as_tensor; the_global_api.node_arg_get_element_type = vaip::node_arg_get_element_type; the_global_api.node_arg_set_element_type = [](NodeArg& node_arg, int type) { @@ -299,16 +341,13 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { }; /// attr proto the_global_api.attr_proto_delete = [](onnx::AttributeProto* v) { delete v; }; - the_global_api.attr_proto_clone = - [](const onnx::AttributeProto& v) -> onnx::AttributeProto* { + the_global_api.attr_proto_clone = [](const onnx::AttributeProto& v) -> onnx::AttributeProto* { return new onnx::AttributeProto(v); }; - the_global_api.attr_proto_get_name = - [](const onnx::AttributeProto& attr_proto) -> const std::string& { + the_global_api.attr_proto_get_name = [](const onnx::AttributeProto& attr_proto) -> const std::string& { return attr_proto.name(); }; - the_global_api.attr_proto_set_name = [](onnx::AttributeProto* attr_proto, - const std::string& name) { + the_global_api.attr_proto_set_name = [](onnx::AttributeProto* attr_proto, const std::string& name) { attr_proto->set_name(name); }; the_global_api.attr_proto_new_int = vaip::attr_proto_new_int; @@ -325,17 +364,14 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.attr_proto_get_ints = vaip::attr_proto_get_ints; the_global_api.attr_proto_get_floats = vaip::attr_proto_get_floats; the_global_api.attr_proto_get_strings = vaip::attr_proto_get_strings; - the_global_api.attr_proto_get_type = - [](const onnx::AttributeProto& attr) -> int { return attr.type(); }; + the_global_api.attr_proto_get_type = [](const onnx::AttributeProto& attr) -> int { return attr.type(); }; /// node attributes the_global_api.node_attributes_new = []() { return reinterpret_cast(new NodeAttributes()); }; - the_global_api.node_attributes_add = [](vaip_core::NodeAttributes& p, - onnx::AttributeProto&& attr) { - reinterpret_cast(p).insert_or_assign(attr.name(), - std::move(attr)); + the_global_api.node_attributes_add = [](vaip_core::NodeAttributes& p, onnx::AttributeProto&& attr) { + reinterpret_cast(p).insert_or_assign(attr.name(), std::move(attr)); }; the_global_api.node_attributes_delete = [](vaip_core::NodeAttributes* p) { delete reinterpret_cast(p); @@ -349,7 +385,8 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { } return &it->second; }; - the_global_api.node_attributes_get_keys = [](vaip_core::NodeAttributes& p) -> vaip_core::DllSafe> { + the_global_api.node_attributes_get_keys = + [](vaip_core::NodeAttributes& p) -> vaip_core::DllSafe> { auto ret = std::vector(); auto& attr = reinterpret_cast(p); ret.reserve(attr.size()); @@ -359,34 +396,29 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return vaip_core::DllSafe(std::move(ret)); }; /// tensor proto - the_global_api.tensor_proto_get_shape_unsafe = [](const onnx::TensorProto& t) -> vaip_core::DllSafe> { + the_global_api.tensor_proto_get_shape_unsafe = + [](const onnx::TensorProto& t) -> vaip_core::DllSafe> { return vaip_core::DllSafe>(vaip::tensor_proto_get_shape(t)); }; - the_global_api.tensor_proto_data_type = - [](const onnx::TensorProto& t) -> int { return t.data_type(); }; + the_global_api.tensor_proto_data_type = [](const 
onnx::TensorProto& t) -> int { return t.data_type(); }; the_global_api.tensor_proto_delete = [](onnx::TensorProto* tp) { delete tp; }; - the_global_api.tensor_proto_new_floats = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{ - vaip::tensor_proto_new_floats(name, shape, data)}; + the_global_api.tensor_proto_new_floats = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { + return new onnx::TensorProto{vaip::tensor_proto_new_floats(name, shape, data)}; }; - the_global_api.tensor_proto_new_i32 = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { + the_global_api.tensor_proto_new_i32 = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { return new onnx::TensorProto{vaip::tensor_proto_new_i32(name, shape, data)}; }; - the_global_api.tensor_proto_new_i64 = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { + the_global_api.tensor_proto_new_i64 = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { return new onnx::TensorProto{vaip::tensor_proto_new_i64(name, shape, data)}; }; - the_global_api.tensor_proto_new_i8 = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { + the_global_api.tensor_proto_new_i8 = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { return new onnx::TensorProto{vaip::tensor_proto_new_i8(name, shape, data)}; }; the_global_api.tensor_proto_raw_data_size = vaip::tensor_proto_raw_data_size; diff --git a/onnxruntime/core/providers/vitisai/include/onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h b/onnxruntime/core/providers/vitisai/include/onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h deleted file mode 100644 index 82f665429c24c..0000000000000 --- a/onnxruntime/core/providers/vitisai/include/onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. 
-#pragma once -#include -#include -#if defined(_WIN32) -#if ONNXRUNTIME_VITISAI_EP_EXPORT_DLL == 1 -#define ONNXRUNTIME_VITISAI_EP_DLL_SPEC __declspec(dllexport) -#else -#define ONNXRUNTIME_VITISAI_EP_DLL_SPEC __declspec(dllimport) -#endif -#else -#define ONNXRUNTIME_VITISAI_EP_DLL_SPEC __attribute__((visibility("default"))) -#endif - -#ifndef USE_VITISAI -#define USE_VITISAI /* mimic VITISAI EP in ORT */ -#endif - -namespace vaip_core { -class ExecutionProvider; -struct OrtApiForVaip; -template -class DllSafe; -} // namespace vaip_core -namespace onnxruntime { -class Graph; -} -struct OrtCustomOpDomain; -namespace onnxruntime_vitisai_ep { - -ONNXRUNTIME_VITISAI_EP_DLL_SPEC void -initialize_onnxruntime_vitisai_ep(vaip_core::OrtApiForVaip* api, - std::vector& ret_domain); -ONNXRUNTIME_VITISAI_EP_DLL_SPEC -vaip_core::DllSafe>> -compile_onnx_model_3(const std::string& model_path, - const onnxruntime::Graph& graph, const char* json_config); -ONNXRUNTIME_VITISAI_EP_DLL_SPEC -int optimize_onnx_model(const std::filesystem::path& model_path_in, - const std::filesystem::path& model_path_out, - const char* json_config); -} // namespace onnxruntime_vitisai_ep - -extern "C" ONNXRUNTIME_VITISAI_EP_DLL_SPEC const vaip_core::OrtApiForVaip* -get_the_global_api(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 8da3882b5af99..c446ab3aefcc5 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -2,6 +2,16 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. #pragma once +#include +#include +#include + #include "core/session/onnxruntime_cxx_api.h" +#include "core/framework/provider_options.h" +#include "vaip/my_ort.h" +#include "vaip/dll_safe.h" +#include "vaip/custom_op.h" std::vector initialize_vitisai_ep(); +vaip_core::DllSafe>> compile_onnx_model_with_options( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); diff --git a/onnxruntime/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc b/onnxruntime/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc deleted file mode 100644 index 8244c36f822a4..0000000000000 --- a/onnxruntime/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. -#include "vaip/dll_safe.h" -#include "vaip/vaip_ort_api.h" -#include "vaip/custom_op.h" -#include "onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h" -#include -#include -using namespace std; - -namespace onnxruntime_vitisai_ep { -static void my_abort() { - cerr << "please install VitisAI package." 
<< endl; - abort(); -} -using namespace vaip_core; -void initialize_onnxruntime_vitisai_ep(OrtApiForVaip* /*api*/, std::vector& /*domain*/) { - my_abort(); - return; -} // namespace onnxruntime_vitisai_ep -DllSafe>> -compile_onnx_model_3(const std::string& /*model_path*/, const Graph& /*graph*/, - const char* /*json_config*/) { - if (1) { // suppress dead code warning - my_abort(); - } - return DllSafe>>(); -} - -} // namespace onnxruntime_vitisai_ep diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 32ee6ff652aac..5f20b32cd6dc4 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -15,8 +15,6 @@ #include "core/session/custom_ops.h" #include "core/session/inference_session.h" -#include "onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h" - using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -24,8 +22,7 @@ namespace onnxruntime { constexpr const char* VITISAI = "VITISAI"; static vaip_core::DllSafe>> compile_onnx_model( - const onnxruntime::GraphViewer& graph_viewer, - const logging::Logger& logger, const char* json_config) { + const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { #ifndef _WIN32 auto model_path = graph_viewer.ModelPath().ToPathString(); #else @@ -33,12 +30,13 @@ static vaip_core::DllSafe strconverter; auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); #endif - return onnxruntime_vitisai_ep::compile_onnx_model_3(model_path, graph_viewer.GetGraph(), json_config); + return compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options); } + struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { - op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), - reinterpret_cast(&info)); + op_kernel_ = + op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), reinterpret_cast(&info)); } ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } @@ -55,8 +53,7 @@ struct MyCustomOpKernel : OpKernel { void* op_kernel_; }; -VitisAIExecutionProvider::VitisAIExecutionProvider( - const VitisAIExecutionProviderInfo& info) +VitisAIExecutionProvider::VitisAIExecutionProvider(const ProviderOptions& info) : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) { custom_op_domains_ = initialize_vitisai_ep(); registry_ = std::make_shared(); @@ -77,7 +74,8 @@ void VitisAIExecutionProvider::CreateKernelRegistry() { } } def_builder.Provider(onnxruntime::kVitisAIExecutionProvider); - KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, + std::unique_ptr& out) -> Status { out = std::make_unique(info, *op); return Status::OK(); }; @@ -89,9 +87,8 @@ void VitisAIExecutionProvider::CreateKernelRegistry() { std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return registry_; } -std::vector> -VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const { +std::vector> VitisAIExecutionProvider::GetCapability( + const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { if (graph.IsSubgraph()) { // VITIS AI EP not support sungraph. 
Assigned to CPU. return {}; @@ -100,9 +97,7 @@ VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // Only compiling a model once is currently supported return {}; } - auto opt_str = info_.get_json_config_str(); // String - execution_providers_ = - std::make_unique(compile_onnx_model(graph, *GetLogger(), opt_str)); + execution_providers_ = std::make_unique(compile_onnx_model(graph, *GetLogger(), info_)); auto result = vaip::GetComputeCapabilityOps(graph, execution_providers_.get(), vitisai_optypes_); size_t index = 0u; for (auto& ep : **execution_providers_) { @@ -112,16 +107,14 @@ VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, return result; } -common::Status VitisAIExecutionProvider::Compile( - const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { +common::Status VitisAIExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) { for (const auto& fused_node_graph : fused_nodes_and_graphs) { NodeComputeInfo compute_info; const onnx::AttributeProto* attr = graph_utils::GetNodeAttribute(fused_node_graph.fused_node, "index"); assert(attr != nullptr); size_t index = (size_t)attr->i(); - compute_info.create_state_func = [this, index](ComputeContext* context, - FunctionState* state) { + compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; return 0; @@ -129,15 +122,11 @@ common::Status VitisAIExecutionProvider::Compile( compute_info.release_state_func = [](FunctionState state) { if (state) { - delete reinterpret_cast( - state); + delete reinterpret_cast(state); } }; - compute_info.compute_func = [](FunctionState state, const OrtApi* api, - OrtKernelContext* context) { - reinterpret_cast( - state) - ->Compute(api, context); + compute_info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + reinterpret_cast(state)->Compute(api, context); return Status::OK(); }; node_compute_funcs.push_back(compute_info); diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index 5bdfc8c18fb6d..e86b53339d4d2 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -4,6 +4,10 @@ #pragma once #include +#include +#include +#include +#include #include "core/framework/execution_provider.h" #include "core/framework/customregistry.h" @@ -18,34 +22,19 @@ class ExecutionProvider; } // namespace vaip_core namespace onnxruntime { -// Information needed to construct execution providers. -struct VitisAIExecutionProviderInfo { - VitisAIExecutionProviderInfo(const ProviderOptions& provider_options); - - const char* get_json_config_str() const { - return json_config_.c_str(); - } - - private: - ProviderOptions provider_options_; - const std::string json_config_; -}; - // Logical device representation. 
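// The older compile_onnx_model_vitisai_ep entry point still expects a JSON string, which is why
// global_api.cc above adds config_to_json_str: it reads the file named by the "config_file"
// provider option, overlays every provider option as a top-level key, and serializes the result.
// A minimal sketch of that merge, assuming nlohmann::json and a plain string-to-string options
// map (the function name is illustrative and error handling is trimmed):
#include <fstream>
#include <map>
#include <string>
#include <nlohmann/json.hpp>

std::string MergeOptionsIntoJson(const std::map<std::string, std::string>& options) {
  std::ifstream f(options.at("config_file"));      // path supplied via the provider options
  nlohmann::json data = nlohmann::json::parse(f);  // start from the on-disk config file
  for (const auto& [key, value] : options) {
    data[key] = value;                             // provider options override file contents
  }
  return data.dump();                              // compact JSON string for the legacy entry point
}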
class VitisAIExecutionProvider : public IExecutionProvider { public: - explicit VitisAIExecutionProvider(const VitisAIExecutionProviderInfo& info); + explicit VitisAIExecutionProvider(const ProviderOptions& info); ~VitisAIExecutionProvider() = default; - std::vector> - GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override; + std::vector> GetCapability(const onnxruntime::GraphViewer& graph, + const IKernelLookup& /*kernel_lookup*/) const override; int GetDeviceId() const { return 0; } - common::Status Compile( - const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) override; + common::Status Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) override; std::shared_ptr GetKernelRegistry() const override; private: @@ -54,7 +43,7 @@ class VitisAIExecutionProvider : public IExecutionProvider { using my_ep_uptr_t = std::shared_ptr; // we have to hide the implementation by forward declaration. mutable my_ep_uptr_t execution_providers_; - VitisAIExecutionProviderInfo info_; + ProviderOptions info_; std::vector custom_op_domains_; std::shared_ptr registry_; std::set vitisai_optypes_; diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 763a3efd1b35b..4c416124ca8f2 100755 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -3,56 +3,37 @@ #include "vitisai_provider_factory_creator.h" +#include +#include + #include "vaip/global_api.h" #include "./vitisai_execution_provider.h" #include "core/framework/execution_provider.h" #include "core/session/abi_session_options_impl.h" -#include "nlohmann/json.hpp" -#include -#include -#include +#include "core/providers/shared_library/provider_host_api.h" using namespace onnxruntime; -using json = nlohmann::json; namespace onnxruntime { -static std::string ConfigToJsonStr(const std::unordered_map& config) { - const auto& filename = config.at("config_file"); - std::ifstream f(filename); - json data = json::parse(f); - for (const auto& entry : config) { - data[entry.first] = entry.second; - } - return data.dump(); -} - -VitisAIExecutionProviderInfo::VitisAIExecutionProviderInfo(const ProviderOptions& provider_options) : provider_options_(provider_options), json_config_{ConfigToJsonStr(provider_options)} {} - struct VitisAIProviderFactory : IExecutionProviderFactory { - VitisAIProviderFactory(const VitisAIExecutionProviderInfo& info) : info_(info) {} + VitisAIProviderFactory(const ProviderOptions& info) : info_(info) {} ~VitisAIProviderFactory() = default; std::unique_ptr CreateProvider() override; private: - VitisAIExecutionProviderInfo info_; + ProviderOptions info_; }; std::unique_ptr VitisAIProviderFactory::CreateProvider() { return std::make_unique(info_); } -std::shared_ptr -CreateExecutionProviderFactory_VITISAI(const VitisAIExecutionProviderInfo& info) { - initialize_vitisai_ep(); - return std::make_shared(info); -} - -std::shared_ptr VitisAIProviderFactoryCreator::Create(const ProviderOptions& provider_options) { +std::shared_ptr VitisAIProviderFactoryCreator::Create( + const ProviderOptions& provider_options) { initialize_vitisai_ep(); - auto info = VitisAIExecutionProviderInfo{provider_options}; - return std::make_shared(info); + return std::make_shared(provider_options); } } // namespace onnxruntime diff --git 
a/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h b/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h index 9e0583275d1b6..9bb7cfa062a0f 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h @@ -9,9 +9,6 @@ #include "core/framework/provider_options.h" namespace onnxruntime { - -struct VitisAIExecutionProviderInfo; - struct VitisAIProviderFactoryCreator { static std::shared_ptr Create(const ProviderOptions& provider_options); }; diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index 57a37d92335aa..5f8defe8fcb6b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -41,9 +41,11 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto select_last_index = helper.Get("select_last_index", 0); axis = HandleNegativeAxis(axis, input_rank); + emscripten::val axes = emscripten::val::array(); + axes.call("push", static_cast(axis)); emscripten::val options = emscripten::val::object(); - options.set("axis", static_cast(axis)); + options.set("axes", axes); options.set("keepDimensions", keep_dims == 1); options.set("selectLastIndex", select_last_index == 1); emscripten::val output = emscripten::val::object(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc index 516ac7464345b..d147ffbbd181f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc @@ -19,9 +19,10 @@ common::Status ComputeConvPads(const std::vector input_shape, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, - std::vector& pads_out) { - const int64_t input_size_y = input_shape[2]; - const int64_t input_size_x = input_shape[3]; + std::vector& pads_out, + bool use_nchw) { + const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1]; + const int64_t input_size_x = use_nchw ? 
input_shape[3] : input_shape[2]; const int64_t stride_y = onnx_strides[0]; const int64_t stride_x = onnx_strides[1]; const int64_t dilation_y = onnx_dilations[0]; @@ -53,32 +54,17 @@ common::Status HandleAutoPad(const std::vector input_shape, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, - AutoPadType& auto_pad_type_out) { - auto_pad_type_out = auto_pad_type; - if (auto_pad_type == AutoPadType::NOTSET && onnx_dilations == std::vector{1, 1}) { - { - std::vector same_upper_pads; - ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, - onnx_pads, onnx_strides, onnx_dilations, - AutoPadType::SAME_UPPER, same_upper_pads)); - if (onnx_pads == same_upper_pads) { - auto_pad_type_out = AutoPadType::SAME_UPPER; - return Status::OK(); - } - } - - { - std::vector same_lower_pads; - ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, - onnx_pads, onnx_strides, onnx_dilations, - AutoPadType::SAME_LOWER, same_lower_pads)); - if (onnx_pads == same_lower_pads) { - auto_pad_type_out = AutoPadType::SAME_LOWER; - return Status::OK(); - } - } + std::vector& pads_out, + bool use_nchw) { + if (AutoPadType::SAME_UPPER == auto_pad_type) { + ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, + onnx_pads, onnx_strides, onnx_dilations, + AutoPadType::SAME_UPPER, pads_out, use_nchw)); + } else { + ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, + onnx_pads, onnx_strides, onnx_dilations, + AutoPadType::SAME_LOWER, pads_out, use_nchw)); } - return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h index 76acbca0536ea..cb7c3c6955664 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h @@ -21,7 +21,8 @@ common::Status HandleAutoPad(const std::vector input_shape, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, - AutoPadType& auto_pad_type_out) ORT_MUST_USE_RESULT; + std::vector& pads_out, + bool use_nchw) ORT_MUST_USE_RESULT; } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index b37340624f850..df0d54e3fd4b4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -44,7 +44,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, const Node& node, emscripten::val& options, const std::vector& strides, const std::vector& dilations, - const std::vector& pads, + std::vector& pads, const logging::Logger& logger) { NodeAttrHelper helper(node); const auto group = helper.Get("group", static_cast(1)); @@ -55,29 +55,85 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, options.set("dilations", emscripten::val::array(dilations)); options.set("groups", group); // Add Padding. - // Usually using autopadding is more efficient than using explicit padding. - // Try to see if we can map explicit padding to auto padding. 
std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - AutoPadType auto_pad_type; - ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], - helper.Get("pads", std::vector{0, 0, 0, 0}), - helper.Get("strides", std::vector{1, 1}), - helper.Get("dilations", std::vector{1, 1}), - StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), - auto_pad_type)); - if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { - if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER - options.set("autoPad", emscripten::val("same-lower")); + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); + if (node.OpType() == "Conv") { + // Calculate explicit padding for autoPad. + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + std::vector pads_out; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], + helper.Get("pads", std::vector{0, 0, 0, 0}), + helper.Get("strides", std::vector{1, 1}), + helper.Get("dilations", std::vector{1, 1}), + auto_pad_type, + pads_out, + model_builder.GetPreferredLayout() == DataLayout::NCHW)); + std::transform(pads_out.begin(), pads_out.end(), pads.begin(), + [](int64_t pad) -> int32_t { return static_cast(pad); }); + } + } else if (node.OpType() == "ConvTranspose") { + // When the 'output_shape' is specificed, the 'output_padding' values + // in options.outputPadding are ignored. + std::vector dim; + std::vector output_padding{0, 0}; + if (helper.HasAttr("output_shape")) { + // Default value of 'output_shape' will be ignore as we already check if + // it's existed. + dim = helper.Get("output_shape", std::vector{-1, -1}); + // Extract the height and width. + std::vector output_shape; + if (dim.size() == 2) { + output_shape = dim; + } else if (dim.size() == 4) { + output_shape = {dim[2], dim[3]}; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output shape"); + } + // Padding values are auto generated. + if (helper.HasAttr("kernel_shape")) { + std::vector kernel_shape = helper.Get("kernel_shape", std::vector{-1, -1}); + std::vector total_padding(2); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + for (size_t i = 0; i < 2; i++) { + // Get the dimensions of H and W. + // For NHWC layout, the dimensions of H and W correspond to index 1 and 2. + // For NCHW layout, the dimensions of H and W correspond to index 2 and 3. 
+ if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + total_padding[i] = strides[i] * (narrow(input_shape[i + 1]) - 1) + + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; + } else { + ORT_RETURN_IF_NOT(model_builder.GetPreferredLayout() == DataLayout::NCHW, + "WebNN GPU backend preferred layout should be NCHW."); + total_padding[i] = strides[i] * (narrow(input_shape[i + 2]) - 1) + + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; + } + } + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + pads[0] = total_padding[0] / 2; + pads[1] = total_padding[0] - pads[0]; + pads[2] = total_padding[1] / 2; + pads[3] = total_padding[1] - pads[2]; + if (AutoPadType::SAME_LOWER == auto_pad_type) { + std::swap(pads[0], pads[1]); + std::swap(pads[2], pads[3]); + } + } + } + options.set("outputSizes", emscripten::val::array(output_shape)); } else { - options.set("autoPad", emscripten::val("same-upper")); + output_padding = helper.Get("output_padding", std::vector{0, 0}); + options.set("outputPadding", emscripten::val::array(output_padding)); } } else { - // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], - // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. - const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; - options.set("padding", emscripten::val::array(padding)); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "conv_op_builder only supports Op Conv and ConvTranspose."); } + // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], + // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. + const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; + options.set("padding", emscripten::val::array(padding)); // Add bias if present. 
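// The ConvTranspose branch above replaces WebNN's autoPad with explicit pads derived from the
// requested output_shape. A self-contained sketch of that arithmetic for one spatial axis,
// mirroring the formula in the added code; the helper name and the worked numbers are
// illustrative, not part of the patch.
#include <cstdint>
#include <utility>

// Returns {begin_pad, end_pad} for one spatial axis of ConvTranspose.
std::pair<int64_t, int64_t> SameTransposePads(int64_t in, int64_t out, int64_t stride,
                                              int64_t kernel, int64_t dilation,
                                              int64_t output_padding, bool same_lower) {
  const int64_t effective_kernel = (kernel - 1) * dilation + 1;
  const int64_t total = stride * (in - 1) + output_padding + effective_kernel - out;
  int64_t begin = total / 2;              // SAME_UPPER: the extra unit of padding goes to the end
  int64_t end = total - begin;
  if (same_lower) std::swap(begin, end);  // SAME_LOWER: the extra unit goes to the beginning
  return {begin, end};
}

// Example: in=5, out=10, stride=2, kernel=3, dilation=1, output_padding=0
// -> total = 2*(5-1) + 0 + 3 - 10 = 1, so SAME_UPPER yields {0, 1} and SAME_LOWER yields {1, 0}.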
if (input_defs.size() > 2) { @@ -198,17 +254,17 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const auto strides = helper.Get("strides", std::vector{1, 1}); const auto dilations = helper.Get("dilations", std::vector{1, 1}); auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - const auto& weight = input_defs[1]->Name(); + const auto& weight_name = input_defs[1]->Name(); + emscripten::val options = emscripten::val::object(); + ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); if (op_type == "Conv") { - emscripten::val options = emscripten::val::object(); - ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); int groups = options["groups"].as(); std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { bool depthwise = (groups == input_shape[3] && groups != 1); options.set("inputLayout", emscripten::val("nhwc")); - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight, !depthwise)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise)); if (!depthwise) { options.set("filterLayout", emscripten::val("ohwi")); } else { @@ -219,61 +275,10 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N output = model_builder.GetBuilder().call("conv2d", input, filter, options); } else { - emscripten::val options = emscripten::val::object(); - ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { options.set("inputLayout", emscripten::val("nhwc")); options.set("filterLayout", emscripten::val("ohwi")); - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight, false)); - } - - // When the 'output_shape' is specificed, the 'output_padding' values - // in options.outputPadding are ignored. - std::vector dim; - std::vector output_padding{0, 0}; - if (helper.HasAttr("output_shape")) { - // Default value of 'output_shape' will be ignore as we already check if - // it's existed. - dim = helper.Get("output_shape", std::vector{-1, -1}); - // Extract the height and width. - std::vector output_shape; - if (dim.size() == 2) { - output_shape = dim; - } else if (dim.size() == 4) { - output_shape = {dim[2], dim[3]}; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output shape"); - } - // Padding values are auto generated. - if (helper.HasAttr("kernel_shape")) { - std::vector kernel_shape = helper.Get("kernel_shape", std::vector{-1, -1}); - std::vector total_padding(2); - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - for (size_t i = 0; i < 2; i++) { - // Get the dimensions of H and W. - // For NHWC layout, the dimensions of H and W correspond to index 1 and 2. - // For NCHW layout, the dimensions of H and W correspond to index 2 and 3. 
- if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { - total_padding[i] = strides[i] * (narrow(input_shape[i + 1]) - 1) + - output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; - } else { - ORT_RETURN_IF_NOT(model_builder.GetPreferredLayout() == DataLayout::NCHW, - "WebNN GPU backend preferred layout should be NCHW."); - total_padding[i] = strides[i] * (narrow(input_shape[i + 2]) - 1) + - output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; - } - } - pads[0] = total_padding[0] - (total_padding[0] / 2); - pads[1] = total_padding[0] / 2; - pads[2] = total_padding[1] - (total_padding[1] / 2); - pads[3] = total_padding[1] / 2; - options.set("padding", emscripten::val::array(pads)); - } - options.set("outputSizes", emscripten::val::array(output_shape)); - } else { - output_padding = helper.Get("output_padding", std::vector{0, 0}); - options.set("outputPadding", emscripten::val::array(output_padding)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false)); } emscripten::val filter = model_builder.GetOperand(input_defs[1]->Name()); output = model_builder.GetBuilder().call("convTranspose2d", input, filter, options); @@ -293,22 +298,39 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - const auto& weight_name = input_defs[1]->Name(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + LOGS(logger, VERBOSE) << "Cannot get input's shape."; + return false; + } + + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s input dimension: " << input_size + << ". Only conv 2d is supported."; + return false; + } + + std::vector weight_shape; + if (!GetShape(*input_defs[1], weight_shape, logger)) { + LOGS(logger, VERBOSE) << "Cannot get weight's shape."; + return false; + } + + const auto weight_size = weight_shape.size(); + if (weight_size != 4) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s weight dimension: " << weight_size + << ". Only conv 2d is supported."; + return false; + } + // WebNN CPU backend (XNNPACK) requires the filter operand to be a constant. 
// https://github.com/google/XNNPACK/blob/master/src/subgraph/convolution-2d.c#L739 - if (device_type == WebnnDeviceType::CPU) { - if (Contains(initializers, weight_name)) { - const auto& tensor = *initializers.at(weight_name); - if (tensor.dims().size() != 4) { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] dimension: " << tensor.dims().size() - << " Only conv 2d is supported."; - return false; - } - } else { - LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; - return false; - } + if (device_type == WebnnDeviceType::CPU && !Contains(initializers, input_defs[1]->Name())) { + LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; + return false; } + return true; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index ae7c111c1fe78..739c3b3f38def 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -81,28 +81,26 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto onnx_kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); const auto onnx_strides = helper.Get("strides", std::vector{1, 1}); const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - + auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - AutoPadType auto_pad_type; - ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, onnx_kernel_shape[0], onnx_kernel_shape[1], - onnx_pads, onnx_strides, {1, 1} /* dilations */, - StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), - auto_pad_type)); - + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { - if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER - options.set("autoPad", "same-lower"); - } else { - options.set("autoPad", "same-upper"); - } - } else { - const std::vector pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], - // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. - const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; - options.set("padding", emscripten::val::array(padding)); + std::vector pads_out; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, onnx_kernel_shape[0], onnx_kernel_shape[1], + onnx_pads, + helper.Get("strides", std::vector{1, 1}), + helper.Get("dilations", std::vector{1, 1}), + auto_pad_type, + pads_out, + model_builder.GetPreferredLayout() == DataLayout::NCHW)); + std::transform(pads_out.begin(), pads_out.end(), pads.begin(), + [](int64_t pad) -> int32_t { return static_cast(pad); }); } + // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], + // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. + const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; + options.set("padding", emscripten::val::array(padding)); const auto ceil_mode = helper.Get("ceil_mode", 0); options.set("roundingType", ceil_mode == 0 ? 
emscripten::val("floor") diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index a5bcbce89bac6..6827f2c9dfd91 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -85,13 +85,6 @@ struct OrtStatus { #define BACKEND_TVM "" #endif -#if USE_VITISAI -#define BACKEND_VITISAI "-VITISAI" -#include "core/providers/vitisai/vitisai_execution_provider.h" -#else -#define BACKEND_VITISAI "" -#endif - #if USE_OPENBLAS #define BACKEND_OPENBLAS "-OPENBLAS" #else @@ -451,9 +444,6 @@ std::shared_ptr CreateExecutionProviderFactory_Dnnl(c std::shared_ptr CreateExecutionProviderFactory_Tvm(const tvm::TvmEPOptions& info); std::shared_ptr CreateExecutionProviderFactory_Tvm(const char* params); #endif -std::shared_ptr CreateExecutionProviderFactory_VITISAI(const char* backend_type, int device_id, - const char* export_runtime_module, - const char* load_runtime_module); std::shared_ptr CreateExecutionProviderFactory_ACL(int use_arena); std::shared_ptr CreateExecutionProviderFactory_ArmNN(int use_arena); std::shared_ptr CreateExecutionProviderFactory_DML(int device_id); diff --git a/orttraining/orttraining/core/framework/torch/custom_function_register.cc b/orttraining/orttraining/core/framework/torch/custom_function_register.cc index 1a51da3daa27f..9ab3fdb0b7c0a 100644 --- a/orttraining/orttraining/core/framework/torch/custom_function_register.cc +++ b/orttraining/orttraining/core/framework/torch/custom_function_register.cc @@ -88,11 +88,14 @@ void OrtTorchFunctionPool::RegisterTorchAutogradFunction( PythonObjectPtr forward(PyObject_GetAttrString(obj, "apply"), PythonObjectDeleter); PythonObjectPtr backward(PyObject_GetAttrString(obj, "backward"), PythonObjectDeleter); + PythonObjectPtr unsafe_forward(PyObject_GetAttrString(obj, "forward"), PythonObjectDeleter); ORT_ENFORCE(forward.get(), "apply attribute not found when registering ", key); ORT_ENFORCE(backward.get(), "backward attribute not found when registering ", key); + ORT_ENFORCE(unsafe_forward.get(), "forward attribute not found when registering ", key); RegisterEntry(mutex_, key, forward.get(), forward_core_pool_); RegisterEntry(mutex_, key, backward.get(), backward_core_pool_); + RegisterEntry(mutex_, key, unsafe_forward.get(), unsafe_forward_core_pool_); } void OrtTorchFunctionPool::RegisterShapeInferenceFunction(const std::string& key, @@ -105,46 +108,27 @@ void OrtTorchFunctionPool::RegisterInputAliasFunction(const std::string& key, RegisterEntry(mutex_, key, obj, input_alias_function_pool_); } -static void RegisterEntry( - std::mutex& mutex, - PyObject* obj, - PythonObjectPtr& storage) { - std::lock_guard lock(mutex); - // Basic checks. - ORT_ENFORCE(obj, "Cannot register NULL PyObject*."); - - // Skip registration if storage already stores a Python object. - if (storage.get() != nullptr) { - return; - } - - // Own the Python object. - Py_INCREF(obj); - PythonObjectPtr ptr(obj, PythonObjectDeleter); - - // If an obj has been registered, this old ownership is automatically released - // after this move-assignment. Then, the "storage" owns the new object. 
- storage = std::move(ptr); +void OrtTorchFunctionPool::RegisterForwardRunner(size_t function_address) { + void* p_forward_runner_func = reinterpret_cast(function_address); + forward_runner_ = reinterpret_cast(p_forward_runner_func); } -void OrtTorchFunctionPool::RegisterForwardRunner(PyObject* obj) { - RegisterEntry(mutex_, obj, forward_runner_); +void OrtTorchFunctionPool::RegisterBackwardRunner(size_t function_address) { + void* p_backward_runner_func = reinterpret_cast(function_address); + backward_runner_ = reinterpret_cast(p_backward_runner_func); } -void OrtTorchFunctionPool::RegisterBackwardRunner(PyObject* obj) { - RegisterEntry(mutex_, obj, backward_runner_); -} +CustomFunctionRunnerType OrtTorchFunctionPool::GetForwardRunner() { + ORT_ENFORCE(forward_runner_, + "Forward runner cannot be NULL. Did you forget to register it by calling RegisterForwardRunner(...)?"); -PyObject* OrtTorchFunctionPool::GetForwardRunner() { - std::lock_guard lock(mutex_); - ORT_ENFORCE(forward_runner_.get(), "Forward runner cannot be NULL. Do you forget register it by calling RegisterForwardRunner(...)?"); - return forward_runner_.get(); + return forward_runner_; } -PyObject* OrtTorchFunctionPool::GetBackwardRunner() { - std::lock_guard lock(mutex_); - ORT_ENFORCE(backward_runner_.get(), "backward runner cannot be NULL. Do you forget register it by calling RegisterBackwardRunner(...)?"); - return backward_runner_.get(); +CustomFunctionRunnerType OrtTorchFunctionPool::GetBackwardRunner() { + ORT_ENFORCE(backward_runner_, + "backward runner cannot be NULL. Did you forget to register it by calling RegisterBackwardRunner(...)?"); + return backward_runner_; } PyObject* OrtTorchFunctionPool::GetForwardCore(const std::string& key) { @@ -163,6 +147,14 @@ PyObject* OrtTorchFunctionPool::GetBackwardCore(const std::string& key) { return iter->second.get(); } +PyObject* OrtTorchFunctionPool::GetUnsafeForwardCore(const std::string& key) { + ORT_ENFORCE(!key.empty(), "Cannot be empty string."); + std::lock_guard lock(mutex_); + auto iter = unsafe_forward_core_pool_.find(key); + ORT_ENFORCE(iter != unsafe_forward_core_pool_.end(), "No unsafe forward registered for ", key); + return iter->second.get(); +} + std::optional OrtTorchFunctionPool::TryGettingShapeInferenceFunction(const std::string& key) { ORT_ENFORCE(!key.empty(), "Cannot be empty string."); std::lock_guard lock(mutex_); @@ -201,10 +193,9 @@ int64_t OrtTorchFunctionPool::RegisterContext(PyObject* autograd_context) { autograd_context, "autograd_context_register"); ORT_ENFORCE(autograd_context, "Cannot register NULL autograd context."); - Py_INCREF(autograd_context); func_context_pool_.insert({index_, PythonObjectPtr(autograd_context, PythonObjectDeleter)}); - // We don't need increase the context refcnt because PyTorch already did it during .apply(). 
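Above, the forward and backward runners stop being Python callables and become plain C++ function pointers that arrive from the binding layer as an integer address. Below is a minimal sketch of that address round trip; the runner signature is deliberately simplified (the real CustomFunctionRunnerType also takes the callback, flag vectors, inplace map, and tensor arguments), and all names are illustrative.

// Illustrative only: register a runner by its integer address and cast it back to a
// typed function pointer, as RegisterForwardRunner/RegisterBackwardRunner do above.
// Assumes size_t is wide enough to hold a function pointer, as on common platforms.
#include <cstddef>
#include <iostream>
#include <stdexcept>

typedef int (*RunnerFn)(const char* func_name);  // simplified stand-in signature

static RunnerFn g_forward_runner = nullptr;

void RegisterForwardRunnerSketch(size_t function_address) {
  g_forward_runner = reinterpret_cast<RunnerFn>(function_address);
}

RunnerFn GetForwardRunnerSketch() {
  if (!g_forward_runner) {
    // Mirrors the ORT_ENFORCE check above.
    throw std::runtime_error("Forward runner not registered.");
  }
  return g_forward_runner;
}

// A runner implementation that would normally live in the interop extension.
int ExampleRunner(const char* func_name) {
  std::cout << "running " << func_name << '\n';
  return 0;
}

int main() {
  // The binding layer would obtain this integer much like the py::cast on the
  // registration path in orttraining_pybind_state.cc above.
  RegisterForwardRunnerSketch(reinterpret_cast<size_t>(&ExampleRunner));
  return GetForwardRunnerSketch()("ExampleFunction");
}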
+ return index_; } @@ -227,14 +218,13 @@ PyObject* OrtTorchFunctionPool::GetContext(int64_t context_index) { } void OrtTorchFunctionPool::UnRegisterGlobalFunctions() { - forward_runner_.reset(); - backward_runner_.reset(); func_context_pool_.clear(); } void OrtTorchFunctionPool::UnRegisterModelSpecificFunctions() { forward_core_pool_.clear(); backward_core_pool_.clear(); + unsafe_forward_core_pool_.clear(); shape_inference_function_pool_.clear(); input_alias_function_pool_.clear(); miscellaneous_const_input_pool_.clear(); diff --git a/orttraining/orttraining/core/framework/torch/custom_function_register.h b/orttraining/orttraining/core/framework/torch/custom_function_register.h index d51cc7dadc1af..67a991ea2cce3 100644 --- a/orttraining/orttraining/core/framework/torch/custom_function_register.h +++ b/orttraining/orttraining/core/framework/torch/custom_function_register.h @@ -13,6 +13,16 @@ namespace onnxruntime { namespace language_interop_ops { namespace torch { +typedef std::vector (*CustomFunctionRunnerType)(const char* func_name_char, + void* callback, + const std::vector& requires_grads, + const std::vector& tensor_type_flags, + const bool is_training_mode, + const std::vector& inplace_map, + const char* kernel_invoke_id_char, + const bool safe_run_mode_enabled, + const std::vector& tensor_args); + class OrtTorchFunctionPool final { public: static OrtTorchFunctionPool& GetInstance() { @@ -34,6 +44,9 @@ class OrtTorchFunctionPool final { // 2. Caller of GetBackwardCore should not decrease the reference count of the returned object. PyObject* GetBackwardCore(const std::string& key); // The "key" is the "name" attribute in PythonOpGrad. + // Return a borrowed reference to the stored Python function running in safe mode. + PyObject* GetUnsafeForwardCore(const std::string& key); // The "key" is the "name" attribute in PythonOp. + // Shape inference function is used to infer output shape of a PythonOp. void RegisterShapeInferenceFunction(const std::string& key, PyObject* obj); // Return a borrowed reference to the stored Python function, if it exists; otherwise, return nullptr. @@ -67,15 +80,15 @@ class OrtTorchFunctionPool final { // ForwardRunner/BackwardRunner are "glue" codes written in Python that interacting // with C++ kernels during Python function invoking. // This function creates new ownership to "obj". - void RegisterForwardRunner(PyObject* obj); + void RegisterForwardRunner(size_t function_address); // This function creates new ownership to "obj". - void RegisterBackwardRunner(PyObject* obj); - // Return a borrowed reference to a Python function, which + void RegisterBackwardRunner(size_t function_address); + // Return a borrowed reference to a c++ function, which // is responsible for executing autograd.Function.apply. - PyObject* GetForwardRunner(); - // Return a borrowed reference to a Python function, which + CustomFunctionRunnerType GetForwardRunner(); + // Return a borrowed reference to a c++ function, which // is responsible for executing autograd.Function.apply. 
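The pools above (forward_core_pool_, backward_core_pool_, and the new unsafe_forward_core_pool_) all follow the ownership convention spelled out in the comments: registration takes its own strong reference to the Python callable, and the Get* accessors hand back a borrowed reference. The following is a compilable sketch of that convention only, not ORT's actual PythonObjectPtr or RegisterEntry code.

// Illustrative only: an owning pool with borrowed-reference getters, the shape of the
// pattern behind unsafe_forward_core_pool_ / GetUnsafeForwardCore.
#include <Python.h>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>

using PyObjectPtr = std::unique_ptr<PyObject, void (*)(PyObject*)>;

class FunctionPoolSketch {
 public:
  // Takes its own strong reference to `obj`; the caller keeps its reference.
  void Register(const std::string& key, PyObject* obj) {
    std::lock_guard<std::mutex> lock(mutex_);
    Py_INCREF(obj);
    // The deleter here assumes the GIL is held when entries are destroyed; the PR's
    // torch_proxy.cc uses a GIL-guarded deleter instead (sketched further below).
    pool_.insert_or_assign(key, PyObjectPtr(obj, [](PyObject* p) { Py_XDECREF(p); }));
  }

  // Returns a borrowed reference: valid while the pool entry is alive; callers must
  // not decrement its reference count.
  PyObject* Get(const std::string& key) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = pool_.find(key);
    if (it == pool_.end()) throw std::runtime_error("No function registered for " + key);
    return it->second.get();
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, PyObjectPtr> pool_;
};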
- PyObject* GetBackwardRunner(); + CustomFunctionRunnerType GetBackwardRunner(); // The reason we provide this unregister api is: // A static OrtTorchFunctionPool instance will be destructed after @@ -97,11 +110,12 @@ class OrtTorchFunctionPool final { void UnRegisterGlobalFunctions(); void UnRegisterModelSpecificFunctions(); - PythonObjectPtr forward_runner_; - PythonObjectPtr backward_runner_; + CustomFunctionRunnerType forward_runner_; + CustomFunctionRunnerType backward_runner_; std::unordered_map forward_core_pool_; std::unordered_map backward_core_pool_; + std::unordered_map unsafe_forward_core_pool_; std::unordered_map shape_inference_function_pool_; std::unordered_map input_alias_function_pool_; diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.cc b/orttraining/orttraining/core/framework/torch/torch_proxy.cc index f36f913366a37..1cd01ae16deea 100644 --- a/orttraining/orttraining/core/framework/torch/torch_proxy.cc +++ b/orttraining/orttraining/core/framework/torch/torch_proxy.cc @@ -12,7 +12,10 @@ namespace onnxruntime::language_interop_ops::torch { -void PythonObjectDeleter(PyObject* ptr) { Py_XDECREF(ptr); }; +void PythonObjectDeleter(PyObject* ptr) { + GilGuard gil; + Py_XDECREF(ptr); +} PyObject* Ort_PyTuple_New(const size_t len, const std::string& log_tag) { PyObject* item = PyTuple_New(len); @@ -20,34 +23,11 @@ PyObject* Ort_PyTuple_New(const size_t len, const std::string& log_tag) { return item; } -void Ort_PyTuple_SetItem_Incref(PyObject* py_tuple, size_t index, PyObject* item, const std::string& log_tag) { - RefCountTracker::GetInstance().TrackPyObject(RefCountTracker::ObjCategory::PythonCallArgs, item, log_tag); - Py_INCREF(item); - PyTuple_SetItem(py_tuple, index, item); -} - void Ort_PyTuple_SetItem_NoIncref(PyObject* py_tuple, size_t index, PyObject* item, const std::string& log_tag) { RefCountTracker::GetInstance().TrackPyObject(RefCountTracker::ObjCategory::PythonCallArgs, item, log_tag); PyTuple_SetItem(py_tuple, index, item); } -PyObject* Ort_PyList_New(const size_t len, const std::string& log_tag) { - PyObject* item = PyList_New(len); - RefCountTracker::GetInstance().TrackPyObject(RefCountTracker::ObjCategory::PythonCallArgs, item, log_tag); - return item; -} - -void Ort_PyList_SetItem_Incref(PyObject* py_list, size_t index, PyObject* item, const std::string& log_tag) { - RefCountTracker::GetInstance().TrackPyObject(RefCountTracker::ObjCategory::PythonCallArgs, item, log_tag); - Py_INCREF(item); - PyList_SetItem(py_list, index, item); -} - -void Ort_PyList_SetItem_NoIncref(PyObject* py_list, size_t index, PyObject* item, const std::string& log_tag) { - RefCountTracker::GetInstance().TrackPyObject(RefCountTracker::ObjCategory::PythonCallArgs, item, log_tag); - PyList_SetItem(py_list, index, item); -} - void CheckArguments( const size_t len, const std::vector& requires_grads, @@ -92,87 +72,51 @@ void CheckArguments( // len: the number of input arguments. // tensor_indices: if tensor_indices[i] is j, // then the j-th input argument should be a tensor. -PyObject* CreateTensorFlags( - const size_t len, - const std::vector& tensor_indices) { - PyObject* flags = Ort_PyList_New(len, "tensor_flags_list"); - - // First we fill the list with 0. Later we will - // assign 1's to tensors' corresponding positions. 
- for (size_t i = 0; i < len; ++i) { - PyObject* zero = PyLong_FromLong(0); - Ort_PyList_SetItem_NoIncref(flags, i, zero, std::to_string(__LINE__)); - } - +std::vector CreateTensorFlags(const size_t len, const std::vector& tensor_indices) { + std::vector flags(len, 0); for (const auto i : tensor_indices) { - PyObject* one = PyLong_FromLong(1); - Ort_PyList_SetItem_NoIncref(flags, i, one, std::to_string(__LINE__)); + flags[i] = 1; } return flags; } -// flags[i] corresponds to the i-th input of apply/backward. -PyObject* CreateRequiresGradFlags( - const std::vector& requires_grads) { - PyObject* flags = Ort_PyList_New(requires_grads.size(), "require_grads_list"); - for (size_t i = 0; i < requires_grads.size(); ++i) { - PyObject* value; - if (requires_grads.at(i) != 0) { - value = Py_True; - } else { - value = Py_False; - } - Ort_PyList_SetItem_Incref(flags, i, value, std::to_string(__LINE__)); - } - return flags; -} - -PyObject* CreateInplaceMap( - const std::vector& inplace_map) { - PyObject* inplace_map_obj = Ort_PyList_New(inplace_map.size(), "inplace_map"); - - for (size_t output_index = 0; output_index < inplace_map.size(); ++output_index) { - PyObject* input_index = PyLong_FromLong(inplace_map[output_index]); - Ort_PyList_SetItem_NoIncref(inplace_map_obj, output_index, input_index, std::to_string(__LINE__)); - } - - return inplace_map_obj; -} - -void InvokeRunner( - PyObject* callback_runner, - PyObject* args, - bool is_training_mode, - void** diff_ctx, - std::vector& returned_ortvalues) { - PythonObjectPtr result_ptr(PyObject_CallObject(callback_runner, args), PythonObjectDeleter); - - if (PyErr_Occurred()) { - PyErr_Print(); - ORT_THROW("Python function execution fails with the above information."); - } - - ORT_ENFORCE(PyTuple_Check(result_ptr.get()), "Python function must return a tuple."); - +void ProcessReturnValues(std::vector& results, + bool is_training_mode, + bool safe_run_mode_enabled, + void** diff_ctx, + std::vector& returned_ortvalues) { size_t i = 0; if (diff_ctx) { // Assume that the first input element in the returned tuple is autograd context // from Pytorch. - PyObject* py_obj = PyTuple_GetItem(result_ptr.get(), 0); + ORT_ENFORCE(results.size() > 0, "The returned tuple should have at least one element."); + PyObject* py_obj = results[0]; if (is_training_mode) { if (py_obj == Py_None) { LOGS_DEFAULT(VERBOSE) << "Under training mode, autograd context found to be Py_None."; } else { + GilGuard guard; + const auto refcnt = Py_REFCNT(py_obj); - // We don't need do ref increase here because, python returns tensor.grad_fn as part of - // tuple, who increased the refcnt already (and tensor persist until the backward kernels completed). - // Pytorch also increases refcnt before apply() return, so we should expect refcount >= 2. - // We say "at least" 2 because user could increase the context refcnt as well in their autograd forward() - // and backward() functions. - ORT_ENFORCE(refcnt >= 2, "Ref count of context should be 2, but actually it's ", refcnt, "."); - if (refcnt > 2) { - LOGS_DEFAULT(VERBOSE) << "Autograd context refcnt > 2, refcnt: " << refcnt; + if (safe_run_mode_enabled) { + // For safe_run_mode_enabled, we expect refcnt >= 2. + // 1. shared_ptr is maintained in torch_interop_utils::PyNodeSharedPointerPool. PyNode is owning + // the context, e.g. THPFunction*. + // 2. results own another reference to the context, while the ownership will be ended after `Invoke` completed. 
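The refactored torch_proxy.cc above now touches CPython objects from code paths that no longer hold the GIL for the whole call, so the deleter and the refcount checks each take a GilGuard first. Below is a minimal sketch of that pattern, assuming GilGuard wraps PyGILState_Ensure/PyGILState_Release; the real guard class lives elsewhere in the ORT training code and may differ.

// Illustrative only: an RAII GIL guard plus a GIL-safe deleter for PyObject*, the
// pattern behind the GilGuard + PythonObjectDeleter changes above.
#include <Python.h>
#include <memory>

class GilGuardSketch {
 public:
  GilGuardSketch() : state_(PyGILState_Ensure()) {}
  ~GilGuardSketch() { PyGILState_Release(state_); }
  GilGuardSketch(const GilGuardSketch&) = delete;
  GilGuardSketch& operator=(const GilGuardSketch&) = delete;

 private:
  PyGILState_STATE state_;
};

// Safe to call from a thread that does not currently hold the GIL.
void GilSafeDecref(PyObject* ptr) {
  GilGuardSketch gil;
  Py_XDECREF(ptr);
}

using PyObjectPtr = std::unique_ptr<PyObject, void (*)(PyObject*)>;

PyObjectPtr WrapOwned(PyObject* obj) {
  // The caller transfers ownership of one strong reference to the returned pointer,
  // mirroring how raii_args/raii_results wrap raw PyObject* in the Invoke() body above.
  return PyObjectPtr(obj, &GilSafeDecref);
}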
+ ORT_ENFORCE(refcnt >= 2, "Ref count of context should be 2, but actually it's ", refcnt, "."); + + // Own one reference!!! + Py_INCREF(py_obj); + + if (refcnt > 2) { + LOGS_DEFAULT(VERBOSE) << "Autograd context refcnt > 2, refcnt: " << refcnt; + } + } else { + ORT_ENFORCE(refcnt == 1, "Ref count of context should be 1, but actually it's ", refcnt, "."); + + // Own one reference!!! + Py_INCREF(py_obj); } } } else { @@ -184,12 +128,13 @@ void InvokeRunner( // i is 1 if the first element is autograd context. Otherwise, i is 0, so we read from the // first element. - for (; i < static_cast(PyTuple_Size(result_ptr.get())); ++i) { - PyObject* dl_tensor_pointer = PyTuple_GetItem(result_ptr.get(), i); + for (; i < results.size(); ++i) { + PyObject* dl_tensor_pointer = results[i]; if (dl_tensor_pointer == Py_None) { OrtValue empty_ort_value; returned_ortvalues.push_back(empty_ort_value); } else { + GilGuard guard; ORT_ENFORCE(Py_REFCNT(dl_tensor_pointer) == 1, "Ref count of dl_tensor_pointer should be 1."); // Todo (pengwa): be noted we did not pass whether tensor is bool or not. // Currently we assume we don't pass boolean data. @@ -198,73 +143,44 @@ void InvokeRunner( } } -PythonObjectPtr CreatePythonCallArguments( - PyObject* callback, - const size_t len, - const std::vector& requires_grads, - const std::vector>& tensor_args, - const std::vector& tensor_indices, - const std::vector& obj_args, - const std::vector& obj_indices, - const bool is_training_mode, - const std::vector& inplace_map, - const std::string& invoke_id, - const std::string& func_name) { - ORT_ENFORCE(PyCallable_Check(callback), "Forward callback is not callable."); - // The number of variables before those of - // autograd.Function.apply and autograd.Function.backward. - // The extra variables are used to configure the launch - // forward and backward runners. - constexpr int64_t num_control_args = 7; - - // All arguments created for Python call will be destroyed along with PythonObjectPtr. - PythonObjectPtr args(Ort_PyTuple_New(num_control_args + len, "forward_arguments_tuple"), PythonObjectDeleter); - PyObject* tensor_flags = CreateTensorFlags(len, tensor_indices); - PyObject* requires_grad_flags = CreateRequiresGradFlags(requires_grads); - - Ort_PyTuple_SetItem_Incref(args.get(), 0, callback, "callback_function"); - Ort_PyTuple_SetItem_NoIncref(args.get(), 1, requires_grad_flags, "requires_grad_flags"); - Ort_PyTuple_SetItem_NoIncref(args.get(), 2, tensor_flags, "tensor_flags"); - PyObject* is_training_mode_arg = is_training_mode ? 
Py_True : Py_False; - Ort_PyTuple_SetItem_Incref(args.get(), 3, is_training_mode_arg, "is_training_mode"); - - PyObject* inplace_map_arg = CreateInplaceMap(inplace_map); - Ort_PyTuple_SetItem_NoIncref(args.get(), 4, inplace_map_arg, "inplace_map"); - - PyObject* kernel_invoke_id_arg = PyBytes_FromStringAndSize(invoke_id.c_str(), invoke_id.size()); - Ort_PyTuple_SetItem_NoIncref(args.get(), 5, kernel_invoke_id_arg, "kernel_invoke_id_arg"); - - PyObject* func_name_arg = PyBytes_FromStringAndSize(func_name.c_str(), func_name.size()); - Ort_PyTuple_SetItem_NoIncref(args.get(), 6, func_name_arg, "func_name_arg"); +void PrepareCallArguments(const std::vector>& tensor_args, + const std::vector& tensor_indices, + const std::vector& obj_args, + const std::vector& obj_indices, + std::vector& args, + std::vector& tensor_flags) { + const size_t len = tensor_args.size() + obj_args.size(); + tensor_flags = CreateTensorFlags(len, tensor_indices); + args.resize(len, nullptr); // Tensor inputs to call autograd.Function.apply or autograd.Function.backward. - for (size_t i = 0; i < tensor_args.size(); ++i) { - if (!tensor_args[i].has_value()) { - Ort_PyTuple_SetItem_Incref(args.get(), num_control_args + tensor_indices[i], Py_None, - "non_tensor_args"); - continue; - } + { + GilGuard guard; + for (size_t i = 0; i < tensor_args.size(); ++i) { + if (!tensor_args[i].has_value()) { + Py_INCREF(Py_None); + args[tensor_indices[i]] = Py_None; + continue; + } - // Wrap with DLPack, then transfer to Python for its release. - PyObject* dl_tensor = training::framework::torch::ToDlpack(tensor_args[i].value()); - Ort_PyTuple_SetItem_NoIncref(args.get(), num_control_args + tensor_indices[i], dl_tensor, - "dltensor"); - } + // Wrap with DLPack, then transfer to Python for its release. + PyObject* dl_tensor = training::framework::torch::ToDlpack(tensor_args[i].value()); + args[tensor_indices[i]] = dl_tensor; + } - // Non-tensor inputs to call autograd.Function.apply or autograd.Function.backward. - for (size_t i = 0; i < obj_args.size(); ++i) { - PyObject* pyobj = reinterpret_cast(obj_args[i]); - Ort_PyTuple_SetItem_Incref(args.get(), num_control_args + obj_indices[i], pyobj, - "const_args"); + // Non-tensor inputs to call autograd.Function.apply or autograd.Function.backward. 
+ for (size_t i = 0; i < obj_args.size(); ++i) { + PyObject* pyobj = reinterpret_cast(obj_args[i]); + Py_INCREF(pyobj); + args[obj_indices[i]] = pyobj; + } } - - return args; } void Invoke( const std::string& func_name, - PyObject* runner, - PyObject* callback, + const CustomFunctionRunnerType& runner, + void* callback, const std::vector& requires_grads, const std::vector>& tensor_args, const std::vector& tensor_indices, @@ -273,30 +189,40 @@ void Invoke( const bool is_training_mode, const std::vector& inplace_map, const std::string& invoke_id, + bool safe_run_mode_enabled, void** diff_ctx, std::vector& returned_ortvalues) { const auto len = tensor_args.size() + obj_args.size(); CheckArguments(len, requires_grads, tensor_args, tensor_indices, obj_args, obj_indices); - RefCountTracker::GetInstance().Reset(); - { - PythonObjectPtr args = CreatePythonCallArguments( - callback, - len, - requires_grads, - tensor_args, - tensor_indices, - obj_args, - obj_indices, - is_training_mode, - inplace_map, - invoke_id, - func_name); - - RefCountTracker::GetInstance().DumpDetails("Before Invoke Python Call"); - InvokeRunner(runner, args.get(), is_training_mode, diff_ctx, returned_ortvalues); + std::vector args; + std::vector tensor_flags; + PrepareCallArguments(tensor_args, tensor_indices, obj_args, obj_indices, args, tensor_flags); + + std::vector results; + + std::vector raii_args; + raii_args.reserve(args.size()); + for (auto& arg : args) { + raii_args.emplace_back(arg, PythonObjectDeleter); + } + + results = runner(func_name.c_str(), + callback, + requires_grads, + tensor_flags, + is_training_mode, + inplace_map, + invoke_id.c_str(), + safe_run_mode_enabled, + args); + + std::vector raii_results; + raii_results.reserve(results.size()); + for (auto& arg : results) { + raii_results.emplace_back(arg, PythonObjectDeleter); } - RefCountTracker::GetInstance().DumpDetails("After Python Call Completed"); + ProcessReturnValues(results, is_training_mode, safe_run_mode_enabled, diff_ctx, returned_ortvalues); } void TorchProxy::Forward( @@ -310,6 +236,7 @@ void TorchProxy::Forward( const bool is_training_mode, const std::vector& inplace_map, const std::string& invoke_id, + bool safe_run_mode_enabled, void** diff_ctx, std::vector& returned_ortvalues) { // Semantically, this lock uniquely takes the ownership of TorchProxy @@ -317,12 +244,12 @@ void TorchProxy::Forward( // can be run at one time. std::lock_guard lock(mutex_); // Python-related calls should happen only if guard is alive. - GilGuard guard; - auto runner = OrtTorchFunctionPool::GetInstance().GetForwardRunner(); + CustomFunctionRunnerType runner = OrtTorchFunctionPool::GetInstance().GetForwardRunner(); + Invoke( func_name, runner, - reinterpret_cast(callback), + callback, requires_grads, tensor_args, tensor_indices, @@ -331,6 +258,7 @@ void TorchProxy::Forward( is_training_mode, inplace_map, invoke_id, + safe_run_mode_enabled, diff_ctx, returned_ortvalues); } @@ -344,30 +272,30 @@ void TorchProxy::Backward( const std::vector& obj_indices, const std::vector& inplace_map, const std::string& invoke_id, + bool safe_run_mode_enabled, std::vector& returned_ortvalues) { // Semantically, this lock uniquely takes the ownership of TorchProxy // so that there will be only one of TorchProxy::Forward TorchProxy::Backward // can be run at one time. std::lock_guard lock(mutex_); - // Python-related calls should happen only if guard is alive. 
- GilGuard guard; - auto runner = OrtTorchFunctionPool::GetInstance().GetBackwardRunner(); - + CustomFunctionRunnerType runner = OrtTorchFunctionPool::GetInstance().GetBackwardRunner(); // Pass all zero since backward inputs don't require gradients. const auto all_input_count = tensor_args.size() + obj_args.size(); const std::vector requires_grads(all_input_count, 0); + Invoke( func_name, runner, - reinterpret_cast(callback), + callback, requires_grads, tensor_args, tensor_indices, obj_args, obj_indices, - true /* is_training_mode */, + false /* is_training_mode */, inplace_map, invoke_id, + safe_run_mode_enabled, nullptr /* context to store */, returned_ortvalues); } @@ -377,6 +305,9 @@ void TorchProxy::RunInputAliasFunction( const std::string& node_proto_str, std::vector& fw_output_to_input_alias_map, std::vector& bw_output_to_input_alias_map) { + // Python-related calls should happen only if guard is alive. + GilGuard guard; + PyObject* input_alias_func = reinterpret_cast(input_alias_function); ORT_ENFORCE(PyCallable_Check(input_alias_func), "input_alias_func is not callable."); diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.h b/orttraining/orttraining/core/framework/torch/torch_proxy.h index 1d5cc1dd69095..450a5048aea44 100644 --- a/orttraining/orttraining/core/framework/torch/torch_proxy.h +++ b/orttraining/orttraining/core/framework/torch/torch_proxy.h @@ -50,6 +50,7 @@ class TorchProxy { const bool is_training_mode, const std::vector& inplace_map, const std::string& invoke_id, + bool safe_run_mode_enabled, void** diff_ctx, std::vector& returned_ortvalues); @@ -62,7 +63,8 @@ class TorchProxy { const std::vector& obj_indices, const std::vector& inplace_map, const std::string& invoke_id, - std::vector& return_args); + bool safe_run_mode_enabled, + std::vector& returned_ortvalues); /** * @brief Run given function to get output to input reuse map. diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 755a8e49d9d12..e675b55c8af8f 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1804,6 +1804,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetPythonOpGradient) { ORT_ENFORCE(utils::HasString(src_attrs.at("func_name"))); attrs.push_back(MakeAttribute("func_name", src_attrs.at("func_name").s())); attrs.push_back(MakeAttribute("output_convention", src_attrs.at("input_convention").s())); + attrs.push_back(MakeAttribute("safe_run_mode", src_attrs.at("safe_run_mode").i())); // input_tensor_types[i] store the type of autograd.Function.apply's ith output. // Note that PythonOpGrad's 0-th input is the Python context generated by PythonOp. diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 8d3f76be20c65..a62ca611b8e7e 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -3938,6 +3938,15 @@ Return true if all elements are true and false otherwise. "comment", "comment only for debugging purposes.", AttributeProto::STRING, false) + .Attr( + "safe_run_mode", + "Indicate if the function is running in safe mode or not. " + "Safe mode support common use cases of PyTorch ctx for example, save for backward, mark as dirty," + "or materialize gradient. In this mode, inplace operation is detected on the fly. 
" + "Unsafe mode is used to run the function faster not considering the above ctx usage." + "Additional requirement running in this mode: provide correct input alias map.", + AttributeProto::INT, + static_cast(1)) .TypeConstraint( "T", OpSchema::all_tensor_types(), @@ -4096,6 +4105,15 @@ Return true if all elements are true and false otherwise. "comment only for debugging purposes.", AttributeProto::STRING, false) + .Attr( + "safe_run_mode", + "Indicate if the function is running in safe mode or not. " + "Safe mode support common use cases of PyTorch ctx for example, save for backward, mark as dirty," + "or materialize gradient. In this mode, inplace operation is detected on the fly. " + "Unsafe mode is used to run the function faster not considering the above ctx usage." + "Additional requirement running in this mode: provide correct input alias map.", + AttributeProto::INT, + static_cast(1)) .TypeConstraint( "T", OpSchema::all_tensor_types(), diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index a5f46d88e4e8b..0c2bfa19e1671 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -316,16 +316,18 @@ void addObjectMethodsForTraining(py::module& m) { m.def("register_forward_runner", [](py::object obj) -> void { #ifdef ENABLE_TRAINING_TORCH_INTEROP + size_t function_address = py::cast(obj); auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance(); - pool.RegisterForwardRunner(obj.ptr()); + pool.RegisterForwardRunner(function_address); #else ORT_UNUSED_PARAMETER(obj); #endif }); m.def("register_backward_runner", [](py::object obj) -> void { #ifdef ENABLE_TRAINING_TORCH_INTEROP + size_t function_address = py::cast(obj); auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance(); - pool.RegisterBackwardRunner(obj.ptr()); + pool.RegisterBackwardRunner(function_address); #else ORT_UNUSED_PARAMETER(obj); #endif diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py index fece1be20c96a..d9d1c467a10c1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py @@ -52,10 +52,9 @@ def enable_custom_autograd_support(to_enable=True): if to_enable is True and custom_autograd_function_enabler.state is False: if custom_autograd_function_enabler.already_enabled is False: # Initialize static objects needed to run custom autograd.Function's. - from ._custom_autograd_function_runner import call_python_backward_function, call_python_forward_function - register_forward_runner(call_python_forward_function) - register_backward_runner(call_python_backward_function) + register_forward_runner(torch_interop_utils.get_custom_function_forward_runner()) + register_backward_runner(torch_interop_utils.get_custom_function_backward_runner()) # Unregister all python functions automatically upon normal interpreter termination. 
atexit.register(unregister_python_functions) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 8efbe16d7d61d..f10416a9bb0f4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -71,10 +71,10 @@ def symbolic_wrapper(fn): def register_custom_function_schema_supplementary(kclass: torch.autograd.Function) -> None: - """Register a shape inference function for a torch.autograd.Function if there is staticmethod - "infer_shape" defined. + """Register schema summplementaries, for example custom shape inference function and + alias input function for a custom autograd.Function. - The signature of the shape inference function should be: + 1. The signature of the shape inference function should be: @staticmethod def infer_shape( node: onnx.NodeProto, @@ -91,7 +91,7 @@ def infer_shape( Be noted: we only pass in tensor inputs, and return tensor outputs, non-tensor inputs/outputs are ignored. - The signature of the alias input function should be: + 2. The signature of the alias input function should be: @staticmethod def alias_input(node_proto_str: str) -> Tuple[List[int], List[int]]: fw_alias_map = [1, -1, -1] diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py deleted file mode 100644 index dd32e2aced561..0000000000000 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ /dev/null @@ -1,707 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - - -import sys -import warnings -from collections import OrderedDict -from typing import Callable, Dict, List, Optional, Tuple, Union - -import torch -from torch.utils.dlpack import from_dlpack, to_dlpack - -from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils - -from ._fallback import ORTModuleFallbackException, ORTModuleIOError, _FallbackManager, wrap_exception # noqa: F401 -from ._utils import get_rank - - -def _log_warning(message: str): - """Configure the logger for PythonOp runner according to following rules. - 1. If multiple processes are used, the rank will be appended - to the logger name. - 2. The logger will be disabled for non-zero ranks. - """ - if get_rank() == 0: - warnings.warn(f"[rank-{get_rank()}] {message}") - - -class CustomFuncOpKernelInfo: - """Store the kernel-specific information retrieved with the first-time run.""" - - def __init__(self, kernel_invoke_id: str): - # kernel_invoke_id is a string contains session thread id, op kernel creation time stamp in ms, a random int, - # and address of op_kernel pointer. This can guarantee the uniqueness of the key in case of multiple - # instances of a same named PythonOp/PythonOpGrad in one session, or multiple sessions. - self.kernel_invoke_id = kernel_invoke_id - - # For the tensors generated from ORT backend, there is special handling here: - # 1. 
For the first time run for the kernel (the uniqueness of the kernel is defined by kernel_invoke_id), - # all such tensors will be cloned in case they are saved in context (but ORT backend is not aware of the - # reference, may release the content of the tensor before it is needed in backward). Once - # `autograd.Function.apply` completes, by checking the existence of the tensor in the saved_tensors, - # `_GlobalOpKernelInfoMap` is updated to save the input indices that are saved in context. - # 2. For the subsequent runs, if the input index is in `tensor_input_indices_to_save_in_ctx`, the tensor - # will be cloned before fed into `autograd.Function.apply` as input. - self.tensor_input_indices_to_save_in_ctx: Optional[List[int]] = None - - # To align with PyTorch `ctx.set_materialize_grads(False|True)`` - # materialize_grads_config is a map from output index to (device, dtype, shape) of the output tensor, used - # for materializing the gradient of the output tensor in backward. - self.materialize_grads: bool = False - self.materialize_grads_config: Optional[Dict[int, Tuple[torch.device, torch.dtype, torch.shape]]] = None - - # For the tensors generated from ORT backend, there is special handling here: - # 1. For the first time run for the kernel (the uniqueness of the kernel is defined by kernel_invoke_id), - # all such tensors will be cloned (with gradient) in case they are marked as dirty (if not cloned, but marked - # as dirty, PyTorch will complain the tensor is a leaf, should not be used for inplace update). Once - # `autograd.Function.apply` completes, by checking the existence of the tensor in the dirty_tensors, - # `_GlobalOpKernelInfoMap` is updated to save the input indices that are marked as dirty. - # 2. For the subsequent runs, if the input index is in `tensor_input_indices_for_mark_dirty`, the tensor - # will be cloned (with gradient) before fed into `autograd.Function.apply` as input. - self.tensor_input_indices_for_mark_dirty: Optional[List[int]] = None - - # A list of output indices that needs to be clone before returned, due to inplace update analysis. - self.output_indices_for_clone: Optional[List[int]] = None - - -# Store the kernel-specific information that cannot be retrieved and saved by PyTorch exporter. -# For the infos that can only be retrieved with real run, we try to collect them in the first time run. -# key: kernel_invoke_id, value: CustomFuncOpKernelInfo. -_GlobalOpKernelInfoMap: Dict[str, CustomFuncOpKernelInfo] = {} - - -def _process_inplace_outputs( - kernel_info: CustomFuncOpKernelInfo, - func_name: str, - input_tensors_of_kernel_run: Dict[int, Union[torch.Tensor, None]], - all_outputs_of_kernel_run: List[Union[torch.Tensor, any]], - all_outputs_to_tensor_inputs_reuse_map: List[int], - raw_input_tensors_used_inplace: Dict[int, Union[torch.Tensor, None]], - is_backward=False, -): - """Special handling for in-place reusing in forward or backward. - - Args: - kernel_info: kernel-specific information. - func_name: name of the autograd.Function. - input_tensors_of_kernel_run: all tensor input tensors used to run the autograd.Function forward/backward. - all_outputs_of_kernel_run: all outputs of the autograd.Function forward/backward. - all_outputs_to_tensor_inputs_reuse_map: a list of the same length of kernel outputs, each element representing - which input index it is reusing. If there is no reuse, the value is -1. 
- raw_input_tensors_used_inplace: a dict of raw input tensors marked as inplace in - `all_outputs_to_tensor_inputs_reuse_map`, the key is the tensor input index, value is the raw input tensor. - is_backward: indicates if this is backward or forward. - - Procedures: - 1. Detect all outputs to tensor inputs reuse mapping. - 2. Validate the detected inplace_map with the registered inplace_map in ORT. For the output tensor, - 2.0 If the reuse mapping value is the same in both inplace_map and detected inplace_map: - 2.0.1 Most likely, we don't need to do anything, except 2.0.2. - 2.0.2 Conditions: - > During forward run, - > The output tensor is reusing one of input tensors, - > The raw input tensor to be reused given from ORT is copied to run the forward kernels - (for two possible reasons: - a. the first time forward run, all inputs will be copied to detect - `tensor_input_indices_to_save_in_ctx`; - b. for every iteration, the input needs to be cloned because it is in - `tensor_input_indices_to_save_in_ctx`). - - In this case, need to copy the output tensor back to the raw input tensor, to make it compatible with - ORT statistically planned buffer reuse. - 2.1 If the reuse mapping value is NOT equal in both inplace_map and detected inplace_map: - 2.1.1 If the detected reuse input index is -1 (e.g. there is NO buffer reuse for this output), - while user specified reuse input index is NOT -1 (ORT planned the reuse), we raise an error. - 2.1.2 If the detected reuse input index is NOT -1 (e.g. there is buffer reuse for this output), - while user specified reuse input index is -1 (ORT did not plan the reuse). We will try to clone the - output tensor before returning to ORT, to align with ORT's NO Buffer reuse plan; otherwise, once the - input buffer is released by ORT memory planner, the output tensor read/write will be corrupted. - Raise a warning to notify users to update inplace_map explicitly for performance consideration. - 2.1.3 Other cases (for example user gives a wrong mapping index compared with detected ones), raise an - error. - 3. Do copies for 2.1.2 cases. - 4. Do copies for 2.0.2 cases. - """ - - log_prefix = f"{func_name}->{'Backward' if is_backward else 'Forward'}: " - input_tensor_address_list = [ - t.data_ptr() if isinstance(t, torch.Tensor) else -1 for t in input_tensors_of_kernel_run.values() - ] - if is_backward: - input_tensor_address_list = [-1, *input_tensor_address_list] # skip the context input - - is_first_time_init = kernel_info.output_indices_for_clone is None - # If this is the first time run, collect runtime tensor reuse mapping. - if is_first_time_init: - # Procedure 1: Detect all outputs to tensor inputs reuse mapping, according to `all_outputs_of_kernel_run` and - # `input_tensors_of_kernel_run`. - assert len(all_outputs_to_tensor_inputs_reuse_map) == len(all_outputs_of_kernel_run), ( - f"{log_prefix}all_outputs_to_tensor_inputs_reuse_map and kernel run outputs should have the same length." - f"all_outputs_to_tensor_inputs_reuse_map: {all_outputs_to_tensor_inputs_reuse_map}, " - f"kernel run outputs: {all_outputs_of_kernel_run}" - ) - - # Detect all outputs to tensor inputs reuse mapping. 
- detected_reuse_map = [-1] * (len(all_outputs_of_kernel_run)) - for output_index, arg in enumerate(all_outputs_of_kernel_run): - if not isinstance(arg, torch.Tensor): - continue - if arg.data_ptr() in input_tensor_address_list: - input_index = input_tensor_address_list.index(arg.data_ptr()) - detected_reuse_map[output_index] = input_index - - # Procedure 2: Validate the detected inplace_map with the registered inplace_map in ORT. - output_indices_for_clone = ( - [] - ) # collect the output indices that need to be cloned before returned in case 2.1.2. - for output_index, (detected_inplace_index, inplace_index) in enumerate( - zip(detected_reuse_map, all_outputs_to_tensor_inputs_reuse_map) - ): - if inplace_index == detected_inplace_index: - continue - - if ( - inplace_index in raw_input_tensors_used_inplace - and raw_input_tensors_used_inplace[inplace_index] is None - ): - # Use specified inplace input index, but the input tensor is None, which means the input is not - # a tensor, so we don't do further checks. - continue - - # If users register inplace_map (alloc planner will do buffer reuse), - # but detected inplace_map indicates it is NO inplace reusing, we raise an error. - if inplace_index != -1 and detected_inplace_index == -1: - raise RuntimeError( - f"{log_prefix}Fatal: " - f"ONNX Op attribute 'tensor_reuse_map' indicates {output_index}-th output is reusing input " - f"{inplace_index}, but detected inplace_map indicates it is NOT reusing any input. " - "Please update inplace_map explicitly to make it consistent " - f"to avoid undefined behavior due to ORT's memory reuse plan. " - f"inplace_map: {all_outputs_to_tensor_inputs_reuse_map}, " - f"detected inplace_map: {detected_reuse_map}" - ) - - if inplace_index == -1 and detected_inplace_index != -1: - output_indices_for_clone.append(output_index) - continue - - raise RuntimeError( - f"{log_prefix}Fatal: " - f"ONNX Op attribute 'inplace_map' indicates {inplace_index}-th output is reusing " - f"input index {detected_inplace_index}, but detected inplace_map indicates it is reusing " - f"input index {inplace_index}. Please update inplace_map explicitly to avoid undefined behavior " - f"due to memory reuse. inplace_map: {all_outputs_to_tensor_inputs_reuse_map}, " - f"detected inplace_map: {detected_reuse_map}" - ) - - kernel_info.output_indices_for_clone = output_indices_for_clone - - assert kernel_info.output_indices_for_clone is not None - - # Procedure 3: Do copies for 2.1.2 cases. - for output_index in kernel_info.output_indices_for_clone: - _log_warning( - f"{log_prefix}ONNX Op attribute " - f"'tensor_reuse_map' doesn't indicate {output_index}-th output is reusing any input, " - f"but detected inplace_map indicates it is reusing some input index. " - "A clone will be done before returning to ORT, to align with ORT's NO Buffer reuse plan. " - "Please update inplace_map explicitly to avoid such a copy." - ) - all_outputs_of_kernel_run[output_index] = all_outputs_of_kernel_run[output_index].detach().clone() - - # Procedure 4: Do copies for 2.0.2 cases. - if is_backward is False and ( - is_first_time_init - or kernel_info.tensor_input_indices_to_save_in_ctx - or kernel_info.tensor_input_indices_for_mark_dirty - ): - for raw_tensor_input_index, raw_input_tensor in raw_input_tensors_used_inplace.items(): - # raw_input_tensor can be None for backward run, but backward won't go here. 
- if not isinstance(raw_input_tensor, torch.Tensor): - continue - - # We did not do the check with tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty - # because even for those tensor indices not in - # tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty, we still need to do the - # copy for the first-time run. - if raw_input_tensor.data_ptr() == input_tensor_address_list[raw_tensor_input_index]: - # If the raw input tensor is not copied, we don't need this handling. - continue - - copied = False # for each tensor, we don't do the copy once. - output_indices_reusing_current_raw_input = [ - output_index - for output_index, input_index in enumerate(all_outputs_to_tensor_inputs_reuse_map) - if input_index == raw_tensor_input_index - ] - output_tensor_address = all_outputs_of_kernel_run[output_indices_reusing_current_raw_input[0]].data_ptr() - for output_index in output_indices_reusing_current_raw_input: - assert ( - output_tensor_address == all_outputs_of_kernel_run[output_index].data_ptr() - ), "Outputs reusing the same input tensor should have the same address." - - if not copied: - # Only need a copy once. - # Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False. - raw_input_tensor.requires_grad = False - raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index]) - _log_warning( - f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}. " - f"{'Provide output to input reuse mapping to avoid the copy overhead.' if not is_first_time_init else ''}" - ) - copied = True - - all_outputs_of_kernel_run[output_index] = raw_input_tensor - - -def _get_context(forward_tensor_outputs: List[torch.Tensor]) -> Tuple[any, Optional[torch.Tensor]]: - """Search for context among all outputs. - - Note 1: All forward outputs of torch.autograd.Function shared the same gradient function pointer, - so here we just get the first tensor having grad_fn attribute. - (https://github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/custom_function.cpp#L267) - - Note 2: Context can be None because NOT all torch.autograd.Function's are differentiable. The function - https://github.com/PyTorch/PyTorch/blob/d701357d921ef167d42c125e65b6f7da6be3ad0f/torch/csrc/autograd/custom_function.cpp#L209? - means if all output of the forward function is not differentiable, then grad_fn will be None (not be set). - - For example, - class Bar(torch.autograd.Function): - # A non-differentiable autograd Function whose forward output - # doesn't have grad_fn attribute. - @staticmethod - def forward(ctx, x): - y = torch.ones_like(x) - return y - - @staticmethod - def backward(ctx, dy): - dx = torch.zeros_like(dy) - return dx - - Returns: - ctx: context of the autograd.Function. - tensor: a tensor that owns the context. - - """ - ctx = None - first_tensor_output = None - for arg in forward_tensor_outputs: - if not isinstance(arg, torch.Tensor) or not hasattr(arg, "grad_fn"): - continue - - if arg.grad_fn is None: - # For the following case, it is possible grad_fn exists, but its value is None, - # so we need to continue to search for the first tensor having a non-None grad_fn. - # - # >>> w = torch.randn(5, 6) - # >>> hasattr(w, "grad_fn") - # True - # >>> w.grad_fn is None - # True - # >>> w, ... = CustomFunc.apply(w) # where CustomFunc forward just return w and other tensors. - # - # Then hasattr(w, "grad_fn") is True, but w.grad_fn is None. 
- continue - # Use the first context we see because all of arg's share the same one. - ctx = arg.grad_fn - first_tensor_output = arg - break - if first_tensor_output is not None: - assert ctx is not None, "ctx should not be None if first_tensor_output is not None." - return (ctx, first_tensor_output) - - -def _finalize_training_mode_forward( - kernel_invoke_id: str, - func_name: str, - input_tensors_used_for_fw_run: Dict[int, torch.Tensor], - forward_output_tensors: List[Union[torch.Tensor, None]], -): - """Complete the epilogue of forward runner for training mode. - - Args: - kernel_invoke_id: kernel_invoke_id of the PythonOp kernel unique id. - input_tensors_from_ort: input tensors generated from ORT backend. - forward_output_tensors: output tensors of the autograd.Function. - - Things to do: - 1. Try to get context from forward output tensors. - 2. Remove the gradient functions between the current autograd.Function and its input's gradient function, because - in ORT we don't depend on PyTorch's autograd engine. - 3. Register the current autograd.Function's gradient function into our PyNodeSharedPointerPool. - 4. Save kernel-specific information into _GlobalOpKernelInfoMap in the first-time kernel run. - """ - - ctx, tensor_owning_ctx = _get_context(forward_output_tensors) - - kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id] - - # ctx being None in training mode means the forward function is not differentiable, so backward is not needed. - if ctx is None: - # If this is the first time run, collect kernel-specific information. - if kernel_info.tensor_input_indices_to_save_in_ctx is None: - kernel_info.tensor_input_indices_to_save_in_ctx = [] - - if kernel_info.tensor_input_indices_for_mark_dirty is None: - kernel_info.tensor_input_indices_for_mark_dirty = [] - - return None - - # Filter out the None in the saved_tensors. - saved_tensors = [t for t in ctx.saved_tensors if t is not None] - - ctx.fw_kernel_invoke_id = kernel_invoke_id - - # If this is the first time run, collect kernel-specific information. - if kernel_info.tensor_input_indices_to_save_in_ctx is None: - kernel_info.tensor_input_indices_to_save_in_ctx = [] - if len(saved_tensors): - # Check tensors generated by ORT are in the saved_tensors or not. - # If yes, save the input index of the tensor in the _GlobalOpKernelInfoMap. - kernel_info.tensor_input_indices_to_save_in_ctx = [ - tensor_input_index - for tensor_input_index, tensor in input_tensors_used_for_fw_run.items() - if any(tensor is saved_tensor for saved_tensor in saved_tensors) - ] - _log_warning( - f"{func_name}: Add input index to _GlobalOpKernelInfoMap, to avoid extra copy in every iteration." - ) - kernel_info.materialize_grads = torch_interop_utils.get_materialize_grads(tensor_owning_ctx) - kernel_info.materialize_grads_config = OrderedDict() - if kernel_info.materialize_grads: - for output_index, tensor in enumerate(forward_output_tensors): - if isinstance(tensor, torch.Tensor): - kernel_info.materialize_grads_config[output_index] = ( - tensor.device, - tensor.dtype, - tensor.shape, - ) - - if kernel_info.tensor_input_indices_for_mark_dirty is None: - kernel_info.tensor_input_indices_for_mark_dirty = [] - # Check tensors generated by ORT are marked as dirty(for inplace update) or not. - # If yes, save the input index of the tensor in the _GlobalOpKernelInfoMap. 
- are_tensors_marked_as_dirty = torch_interop_utils.are_tensors_marked_as_dirty( - tensor_owning_ctx, [t for t in input_tensors_used_for_fw_run.values()] - ) - kernel_info.tensor_input_indices_for_mark_dirty = [ - tensor_input_index - for is_dirty, (tensor_input_index, tensor) in zip( - are_tensors_marked_as_dirty, input_tensors_used_for_fw_run.items() - ) - if is_dirty is True - ] - _log_warning(f"{func_name}: Add input index to _GlobalOpKernelInfoMap, to support leaf node do inplace update.") - - # FORWARD BACKWARD FUNCTION CONNECTIONS - # input_1 (leaf, constructed by from_dlpack) <----reference---- AccumulateGrad gradient function - # ↓ ↑ - # autograd.Function apply() ------------> autograd.Function backward() - # ↓ | ↑ - # output_1, output_2 --- shared_ptr --- ↑ - # ↓ previous gradient function - - # We remove the edges starting between current autograd.Function's gradient function and - # it's input's gradient function (e.g. AccumulateGrad gradient function), then - # AccumulateGrad gradient function will be destroyed, releasing the reference to input_1 - # (https://github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/functions/accumulate_grad.cpp#L21). - # The next edges are stored in Node, with which we can get next gradient function. - # https://github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L527 - torch_interop_utils.clear_grad_fns_for_next_edges(tensor_owning_ctx, saved_tensors) - - # This is mainly to hold grad_fn references by registering it into our PyNodeSharedPointerPool. - torch_interop_utils.register_grad_fn_and_remove_from_autograd(id(ctx), tensor_owning_ctx) - - return ctx - - -def call_python_forward_function( - forward_function: Callable, - requires_grad_flags: List[bool], - tensor_type_flags: List[int], - is_training_mode: bool, - inplace_map: List[int], - kernel_invoke_id: str, - func_name: Union[bytes, str], - *args, -): - """ - This function bridges the gap between ORT variables and autograd.Function.apply. - It conducts basic casting from ORT to PyTorch (before calling "forward_function") and from PyTorch to ORT - (after calling "forward_function"). It also enable autograd in PyTorch. It formats returned outputs, - for example, dropping None's from forward_function's output list. - - The major difference between call_python_forward_function and call_python_backward_function is that - in the forward one, we have extra code to process autograd context from PyTorch. - - Args: - forward_function: pointer to autograd.Function.apply (e.g., MyReLU.apply). - requires_grad_flags: requires_grad_flags[i] indicates if the i-th arg needs gradient. - tensor_type_flags: tensor_type_flags[i] indicates the type of the i-th arg, 0 - non-tensor, 1 - tensor. - is_training_mode: indicates if this model is running under training mode. - inplace_map: a list of the same length of kernel outputs, each element represents which input index - it is reusing. If there is no reuse, the value is -1. - args: inputs to "backward_function". - """ - - try: - func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name - # If this is the first time run, collect runtime tensor reuse mapping. 
- is_first_time_run = kernel_invoke_id not in _GlobalOpKernelInfoMap - if is_first_time_run: - kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id) - _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info - - kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id] - - tensor_input_indices_to_save_in_ctx = kernel_info.tensor_input_indices_to_save_in_ctx - tensor_input_indices_for_mark_dirty = kernel_info.tensor_input_indices_for_mark_dirty - - # Collect the tensor address for all inputs used for run forward, used for reuse detection. - tensor_input_index = 0 - # If the input is reused, we need to save the raw input tensor for special handling. - raw_input_tensors_used_inplace = OrderedDict() # Orders matter here. - input_tensors_used_for_fw_run = OrderedDict() # Orders matter here. - - wrapped_args = [] - for _, (grad_flag, tensor_flag, arg) in enumerate(zip(requires_grad_flags, tensor_type_flags, args)): - if tensor_flag: - # Assume it's a DLPack tensor and convert it to PyTorch tensor. - wrapped_arg = from_dlpack(arg) - - if tensor_input_index in inplace_map: - raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg - - # Only requires gradient when running under training mode - # and the associated tensor has grad_flag=True (i.e., - # "requires_grad=True" in the original PyTorch script). - wrapped_arg.requires_grad = is_training_mode and grad_flag - - # Note1: - # If it's first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, we do the - # copy for all tensors. Otherwise, we only copy the tensors whose indices are in - # tensor_input_indices_to_save_in_ctx. - # Note2: - # For inference mode, we don't need to do the copy because ctx will be None, - # so nothing will be saved for ctx. - # Note3: - # To fix this issue: - # "a leaf Variable that requires grad has been used in an in-place operation." - # If it's first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the - # copy for all tensors to generate grad for it. Otherwise, we only clone (to generate grad) for - # the tensors whose indices are in tensor_input_indices_for_mark_dirty. - if is_training_mode: - if is_first_time_run: - with torch.set_grad_enabled(True): - wrapped_arg = wrapped_arg.clone() - else: - is_input_index_saved_in_ctx = ( - tensor_input_indices_to_save_in_ctx is None - or tensor_input_index in tensor_input_indices_to_save_in_ctx - ) - is_input_index_marked_dirty = ( - tensor_input_indices_for_mark_dirty is None - or tensor_input_index in tensor_input_indices_for_mark_dirty - ) - if is_input_index_saved_in_ctx or is_input_index_marked_dirty: - # when with grad, the leaf tensor after clone will not be leaf. - with torch.set_grad_enabled(is_input_index_marked_dirty): - wrapped_arg = wrapped_arg.clone() - wrapped_arg.requires_grad = is_training_mode and grad_flag - - wrapped_args.append(wrapped_arg) - input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg - - tensor_input_index += 1 - else: - # Use non-tensor as is. It's a PyObject*. - wrapped_args.append(arg) - - with torch.set_grad_enabled(is_training_mode): - # Run autograd.Function.apply(...). - # TODO(pengwa): looks like we are assuming all outputs will be either Tensor or None. - # We should revisit if it is possible to support other types of output, for example int, or, etc. - # But that might also require some work in backend. 
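
The `from_dlpack`/`to_dlpack` calls used throughout this runner are the standard `torch.utils.dlpack` bridge. A self-contained sketch of the round trip performed on each tensor argument (plain PyTorch, no ORT involved; `roundtrip` is a hypothetical stand-in for the runner):

import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

def roundtrip(capsule):  # hypothetical stand-in for the runner, not ORT code
    t = from_dlpack(capsule)  # zero-copy view over the producer-owned buffer
    out = t * 2  # stand-in for forward_function(*wrapped_args)
    return to_dlpack(out)  # hand the result back as a DLPack capsule

src = torch.arange(4, dtype=torch.float32)
dst = from_dlpack(roundtrip(to_dlpack(src)))
assert torch.equal(dst, src * 2)
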
- result = forward_function(*wrapped_args) - - results = [] - if isinstance(result, torch.Tensor): - results = [result] - elif isinstance(result, (tuple, list)): - results = [r for r in result] - else: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule does not support the following model output type {type(result)}."), - ) - - ctx = None - if is_training_mode: - ctx = _finalize_training_mode_forward( - kernel_invoke_id, func_name, input_tensors_used_for_fw_run, results - ) - - final_rets = [ctx] - final_rets.extend(results) - - _process_inplace_outputs( - kernel_info, - func_name, - input_tensors_used_for_fw_run, - final_rets, - inplace_map, - raw_input_tensors_used_inplace, - ) - - dlpacks = [final_rets[0]] - dlpacks.extend(list(to_dlpack(value) if value is not None else None for value in final_rets[1:])) - - # Inside the returned list, the first element is context and the rest - # are DLPack tensors. - return tuple(dlpacks) - except Exception as e: - # Flush buffers. Otherwise, calling this from C++ may lose them. - print("Exception happens when running ", forward_function) - sys.stdout.flush() - sys.stderr.flush() - raise wrap_exception(ORTModuleFallbackException, e) # noqa: B904 - - -def call_python_backward_function( - backward_function: Callable, - requires_grad_flags: List[bool], - tensor_type_flags: List[int], - is_training_mode: bool, - inplace_map: List[int], - kernel_invoke_id: str, - func_name: Union[bytes, str], - *args, -): - """ - This function bridges the gap between ORT variables and autograd.Function.backward. - It conducts basic casting from ORT to PyTorch (before calling "backward_function") - and from PyTorch to ORT (after calling "backward_function"). It formats returned - outputs, example, dropping None's from backward_function's output list. - - Args: - backward_function: pointer to autograd.Function.backward (e.g., MyReLU.backward). - requires_grad_flags: requires_grad_flags[i] indicates if the i-th arg needs gradient. - tensor_type_flags: tensor_type_flags[i] indicates the type of the i-th arg. - is_training_mode: indicates if this model is running under training mode. - inplace_map: a list of the same length of kernel outputs, each element represents which input index - it is reusing. If there is no reuse, the value is -1. - args: inputs to "backward_function". - """ - func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name - with torch.no_grad(): - - def wrap_all_outputs(result): - if isinstance(result, torch.Tensor): - return [to_dlpack(result)] - elif isinstance(result, (tuple, list)): - return [to_dlpack(value) if value is not None else None for value in result] - else: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule does not support the following model output type {type(result)}."), - ) - - try: - # If this is the first time run, collect runtime tensor reuse mapping. - if kernel_invoke_id not in _GlobalOpKernelInfoMap: - kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id) - _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info - - kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id] - - # Backward inputs should not require gradients. - assert all(grad_flag == 0 for grad_flag in requires_grad_flags) - - # Prepare inputs for calling Python function. - ctx = args[0] - fw_kernel_invoke_id = ctx.fw_kernel_invoke_id - wrapped_args = [] - - # Collect the tensor address for all inputs used for run backward, used for reuse detection. 
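
The materialize_grads handling that follows mirrors PyTorch's `ctx.set_materialize_grads`: left at its default (True), autograd turns missing output gradients into zero tensors, which is what `materialize_grads_config` lets the backward runner reproduce when ORT passes a None gradient. A standalone sketch with a hypothetical `TwoOutputs` function (not ORT code):

import torch

class TwoOutputs(torch.autograd.Function):  # hypothetical example, not ORT code
    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)  # allow None instead of zero-filled grads
        return x * 2, x * 3

    @staticmethod
    def backward(ctx, g1, g2):
        # Only the first output feeds the loss below, so g2 arrives as None here;
        # with materialize_grads left at True it would arrive zero-filled instead.
        assert g2 is None
        return 2 * g1

x = torch.randn(3, requires_grad=True)
a, _ = TwoOutputs.apply(x)
a.sum().backward()
print(torch.allclose(x.grad, torch.full_like(x, 2.0)))  # True
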
- tensor_input_index = 1 # skip the context input - # If input is reused, we need to save the raw input tensor for special handling. - raw_input_tensors_used_inplace = OrderedDict() # Orders matter here. - input_tensors_used_for_bw_run = OrderedDict() # Orders matter here. - for grad_input_index, (grad_flag, tensor_flag, arg) in enumerate( - zip(requires_grad_flags, tensor_type_flags, args) - ): - # If an input is a tensor, it is possible we get a None also when it is optional as grad input. - if tensor_flag: - if arg is None: - if _GlobalOpKernelInfoMap[fw_kernel_invoke_id].materialize_grads: - config = _GlobalOpKernelInfoMap[fw_kernel_invoke_id].materialize_grads_config - # ignore the first input, which is the ctx. - device, dtype, shape = config[grad_input_index - 1] - wrapped_arg = torch.zeros(shape, device=device, dtype=dtype) - else: - wrapped_arg = arg - - if grad_input_index in inplace_map: - raw_input_tensors_used_inplace[tensor_input_index] = arg - - else: - # Assume it's a DLPack tensor# and convert it to PyTorch tensor. - wrapped_arg = from_dlpack(arg) - - if grad_input_index in inplace_map: - raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg - - # This may include None values. - input_tensors_used_for_bw_run[tensor_input_index] = wrapped_arg - - if wrapped_arg is not None: - # Only requires gradient when running under training mode - # and the associated tensor has grad_flag=True (i.e., - # "requires_grad=True" in the original PyTorch script). - wrapped_arg.requires_grad = is_training_mode and grad_flag - - wrapped_args.append(wrapped_arg) - tensor_input_index += 1 - else: - # Use non-tensor as is. It's a PyObject*. - wrapped_args.append(arg) - - # Call Python function. - result = backward_function(*wrapped_args) - - # Extract results as DLPack tensor list. - if isinstance(result, torch.Tensor): - result = [result] - elif isinstance(result, (tuple, list)): - result = list(result) - else: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule does not support the following model output type {type(result)}."), - ) - - _process_inplace_outputs( - kernel_info, - func_name, - input_tensors_used_for_bw_run, - result, - inplace_map, - raw_input_tensors_used_inplace, - is_backward=True, - ) - - wrapped_returned_args = wrap_all_outputs(result) - - torch_interop_utils.unregister_grad_fn(id(ctx)) - - return tuple(wrapped_returned_args) - except Exception as e: - # Flush buffers. Otherwise, calling this from C++ may lose them. 
- print("Exception happens when running ", backward_function) - sys.stdout.flush() - sys.stderr.flush() - raise wrap_exception(ORTModuleFallbackException, e) # noqa: B904 diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index d076ecacd6ba5..ff110c431d300 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -24,6 +24,10 @@ STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE = TensorProto.FLOAT STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE = [1] +DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME = "deepspeed.runtime.zero.parameter_offload.PreBackwardFunction" +DEEPSPEED_POST_BACKWARD_FUNCTION_NAME = "deepspeed.runtime.zero.parameter_offload.PostBackwardFunction" +DEEPSPEED_LINEAR_FUNCTION_NAME = "deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3" + def post_processing_enable_zero_stage3_compat( exported_model: ModelProto, @@ -74,7 +78,10 @@ def _get_func_name(node: NodeProto) -> Optional[str]: STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, ) - from onnxruntime.training.utils.hooks._zero_offload_subscriber import ORTZeROOffloadPreForwardFunction + from onnxruntime.training.utils.hooks._zero_offload_subscriber import ( + ORTZeROOffloadPostForwardFunction, + ORTZeROOffloadPreForwardFunction, + ) pre_forward_function_name = get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction) @@ -111,9 +118,10 @@ def _get_func_name(node: NodeProto) -> Optional[str]: if input_name == graph_input.name: index_offset_on_python_op_input.append(i) - assert ( - len(index_offset_on_python_op_input) == 1 - ), f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input} for node {pre_forward_pythonop_node.name}, input {graph_input.name}, {pre_forward_pythonop_node.input}" + assert len(index_offset_on_python_op_input) == 1, ( + f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input} for " + f"node {pre_forward_pythonop_node.name}, input {graph_input.name}, {pre_forward_pythonop_node.input}" + ) reverse_index_among_inputs = index_offset_on_python_op_input[0] - len(pre_forward_pythonop_node.input) @@ -170,6 +178,34 @@ def _get_func_name(node: NodeProto) -> Optional[str]: exported_model.graph.input.insert(offset, new_input) exported_model.graph.node.insert(0, weight_pull_node) + # Update safe_run_mode attribute for PythonOp. 
+ from onnxruntime.training.utils.hooks._subscriber_manager import _IncrementStep + + _allowed_unsafe_run_python_op_names = [ + get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction), + get_fully_qualified_class_name(ORTZeROOffloadPostForwardFunction), + func_full_qual_name, + DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME, + DEEPSPEED_POST_BACKWARD_FUNCTION_NAME, + DEEPSPEED_LINEAR_FUNCTION_NAME, + get_fully_qualified_class_name(_IncrementStep), + ] + + for node in exported_model.graph.node: + if node.op_type == "PythonOp": + func_name = None + safe_run_mode_attr = None + for attr in node.attribute: + if attr.name == "func_name": + func_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s + if attr.name == "safe_run_mode": + safe_run_mode_attr = attr + + if func_name in _allowed_unsafe_run_python_op_names: + if safe_run_mode_attr: + node.attribute.remove(safe_run_mode_attr) + node.attribute.append(helper.make_attribute("safe_run_mode", 0)) + return exported_model @@ -227,12 +263,8 @@ def _simple_pass_through_infer_shape( ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes - register_shape_inference_function( - "deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _simple_pass_through_infer_shape - ) - register_shape_inference_function( - "deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _simple_pass_through_infer_shape - ) + register_shape_inference_function(DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME, _simple_pass_through_infer_shape) + register_shape_inference_function(DEEPSPEED_POST_BACKWARD_FUNCTION_NAME, _simple_pass_through_infer_shape) def _linear_infer_shape( node: NodeProto, @@ -246,7 +278,7 @@ def _linear_infer_shape( output_shape[-1] = shape2[-2] return [output_shape], [tensor_input_dtypes[0]] - register_shape_inference_function("deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3", _linear_infer_shape) + register_shape_inference_function(DEEPSPEED_LINEAR_FUNCTION_NAME, _linear_infer_shape) def _register_alias_input_functions(): @@ -274,8 +306,8 @@ def _alias_input(node_proto_str: str): return fw_alias_map, bw_alias_map - register_input_alias_function("deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _alias_input) - register_input_alias_function("deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _alias_input) + register_input_alias_function(DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME, _alias_input) + register_input_alias_function(DEEPSPEED_POST_BACKWARD_FUNCTION_NAME, _alias_input) def _create_weight_retrieval_pythonop( diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.cc new file mode 100644 index 0000000000000..fa54b4929c784 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "ctx_pool.h" +#include + +void register_grad_fn_and_remove_from_autograd(py::object ctx, at::Tensor target) { + uint32_t y = reinterpret_cast(ctx.ptr()); + size_t ctx_address = static_cast(y); + + torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); + PyNodeSharedPointerPool::GetInstance().RegisterGradFuncAndRemoveFromAutoGrad(ctx_address, autograd_meta); +} + +void unregister_grad_fn(py::object ctx) { + uint32_t y = reinterpret_cast(ctx.ptr()); + size_t ctx_address = static_cast(y); + PyNodeSharedPointerPool::GetInstance().UnRegisterGradFunc(ctx_address); +} + +void clear_all_grad_fns() { + PyNodeSharedPointerPool::GetInstance().ClearAll(); +} diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.h b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.h new file mode 100644 index 0000000000000..e7b101d987d7a --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.h @@ -0,0 +1,96 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +// In PyTorch forward run (e.g. THPFunction_apply), ctx of type THPFunction* (which is also a PyObject*) +// is created (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L673). +// The ctx is used to run user-defined forward function and backward function as the first +// parameter. The same time, a cdata of type std::shared_ptr is created +// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L677), +// cdata is owned by: +// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor owns +// shared_pointer; TensorImpl owns std::unique_ptr; AutogradMeta +// manages grad_/grad_fn_/grad_accumulator_. Among them, grad_fn_ is std::shared_ptr, +// e.g, the so called gradient function.) +// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/variable.h#L194 +// b). the consumer operator of forward run outputs, will let its own PyNode/Node (gradient function) +// owns the grad_fn_ (of type std::shared_ptr) of all inputs that require grad. +// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L263 +// BUT, if we run torch computation within PythonOp, b) is lost. So for some cases, where forward outputs +// are not used and freed before backward function runs, the grad_fn_ (std::shared_ptr) references +// in a) will be released. Without b)'s reference, grad_fn_ release PyNode as reference count reach 0; +// Then when PythonOpGrad runs, segment fault. +// +// So we add b)'s reference in this Pool when forward run returns; dereference from this Pool when backward +// completes, then ~PyNode() is called, which subsequently calls ~THPFunction() destroying ctx. 
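
The ownership chain described in the comment above can be observed from Python with any custom autograd.Function; a minimal sketch (hypothetical `Square` function, unrelated to this patch):

import torch

class Square(torch.autograd.Function):  # hypothetical example, unrelated to this patch
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.randn(3, requires_grad=True)
y = Square.apply(x)

# The output tensor keeps the gradient function (the PyNode wrapping the ctx/THPFunction)
# alive; its next edge points at the AccumulateGrad node of the leaf input.
print(type(y.grad_fn).__name__)   # SquareBackward
print(y.grad_fn.next_functions)   # ((<AccumulateGrad object ...>, 0),)
# Once `y` is dropped and nothing else holds y.grad_fn, the whole chain can be freed,
# which is the reference the pool below keeps alive on ORT's behalf.
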
+class PyNodeSharedPointerPool { + public: + static PyNodeSharedPointerPool& GetInstance() { + static PyNodeSharedPointerPool pool; + return pool; + } + + void RegisterGradFuncAndRemoveFromAutoGrad(const size_t& ctx_address, + torch::autograd::AutogradMeta* autograd_meta) { + auto it = grad_fns_.find(ctx_address); + TORCH_CHECK(it == grad_fns_.end(), "should not register grad_fn twice for ctx ", ctx_address); + + // Add new entry if key hasn't been registered. + // After this, the grad_fn_ is removed from torch autograd. + grad_fns_.emplace(ctx_address, std::move(autograd_meta->grad_fn_)); + TORCH_CHECK(autograd_meta->grad_fn_ == nullptr, "fail to remove grad_fn_ from torch autograd for ctx ", + ctx_address); + } + + void UnRegisterGradFunc(const size_t& ctx_address) { + auto it = grad_fns_.find(ctx_address); + TORCH_CHECK(it != grad_fns_.end(), "fail to find grad_fn for ctx ", ctx_address); + + grad_fns_.erase(ctx_address); + } + + void ClearAll() { + grad_fns_.clear(); + } + + private: + PyNodeSharedPointerPool(){}; + ~PyNodeSharedPointerPool(){}; + + PyNodeSharedPointerPool(const PyNodeSharedPointerPool&) = delete; + PyNodeSharedPointerPool& operator=(const PyNodeSharedPointerPool&) = delete; + PyNodeSharedPointerPool(PyNodeSharedPointerPool&&) = delete; + PyNodeSharedPointerPool& operator=(PyNodeSharedPointerPool&&) = delete; + + std::unordered_map> grad_fns_; +}; + +void register_grad_fn_and_remove_from_autograd(py::object ctx, at::Tensor target); + +void unregister_grad_fn(py::object ctx); + +// Supposed to be cleared on python program exit to resolve the following issue: +// When training program exits, PyNodeSharedPointerPool destructor is called, if grad_fns_ is not empty, +// PyNode::release_variables() will be called. +// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L168) +// On The other hand, there is a known issue when acquiring GIL in pybind11 destructors, there will be +// probably a deadlock issue. (https://github.com/pybind/pybind11/issues/1446) +// The resolution here, we remove all maintained states before the program exits. + +// A known existing issue: when forward functions are called repeatedly without corresponding backward calls, +// grad functions keep accumulating without releasing, there might be memory (bound to those gradient functions) leaks. +// Ideally this usually won't happen in real training cases, so it should be fine. + +// We CANNOT explicitly clear grad functions before each forward pass to mitigate the known issue above. +// For example: +// loss1 = forward_run(inputs1) +// loss2 = forward_run(inputs2) +// loss = loss1 + loss2 +// loss.backward() +// If we clear grad functions at the beginning of the second `forward_run`, when `loss.backward()` runs, +// the backward path of `loss1` will fail to run PythonOpGrad ops (if there is any). +void clear_all_grad_fns(); diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc new file mode 100644 index 0000000000000..88e93b26e0e22 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
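
A PyTorch-only sketch of the call pattern that the `clear_all_grad_fns` note above is protecting: two forward passes whose losses are combined before a single backward, so per-ctx state recorded for the first forward must survive the second (hypothetical `Scale` function, not ORT code):

import torch

class Scale(torch.autograd.Function):  # hypothetical example, not ORT code
    @staticmethod
    def forward(ctx, x):
        ctx.scale = 3.0
        return x * 3.0

    @staticmethod
    def backward(ctx, grad_out):
        return ctx.scale * grad_out

x1 = torch.randn(2, requires_grad=True)
x2 = torch.randn(2, requires_grad=True)

loss1 = Scale.apply(x1).sum()
# Clearing per-ctx state here, between the two forwards, would break the combined
# backward below, because loss1's backward path still needs its ctx and grad_fn.
loss2 = Scale.apply(x2).sum()
(loss1 + loss2).backward()
print(x1.grad, x2.grad)  # both tensor([3., 3.])
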
+ +#include "ctx_pool.h" +#include "custom_function_shared.h" +#include "custom_function_bw.h" + +#include +#include +#include + +#ifdef NVTX3_ENABLED +#include +#endif + +std::vector custom_function_backward_runner(const char* func_name_char, + void* callback, + const std::vector& requires_grad_flags, + const std::vector& tensor_type_flags, + const bool is_training_mode, + const std::vector& inplace_map, + const char* kernel_invoke_id_char, + const bool safe_run_mode_enabled, + const std::vector& args) { + pybind11::gil_scoped_acquire gil; + + try { + std::string func_name(func_name_char); + std::string kernel_invoke_id(kernel_invoke_id_char); + bool is_backward = true; + std::string log_prefix = func_name + " -> " + (is_backward ? "Backward " : "Forward "); + + at::AutoGradMode enable_grad(false); + auto it = KernelInfoStore::GetInstance().GetKernelInfoMap().find(kernel_invoke_id); + if (it == KernelInfoStore::GetInstance().GetKernelInfoMap().end()) { + KernelInfoStore::GetInstance().GetKernelInfoMap().emplace( + kernel_invoke_id, + CustomFuncOpKernelInfo(kernel_invoke_id, safe_run_mode_enabled)); + } + + CustomFuncOpKernelInfo& kernel_info = KernelInfoStore::GetInstance().GetKernelInfoMap().at(kernel_invoke_id); + + std::unordered_map raw_input_tensors_used_inplace; + std::unordered_map input_tensors_used_for_bw_run; + + int tensor_input_index = 0; + std::vector raii_call_args; + raii_call_args.reserve(args.size()); + py::object ctx = py::reinterpret_borrow(args[0]); + raii_call_args.push_back(ctx); + for (size_t arg_index = 1; arg_index < args.size(); ++arg_index) { + if (tensor_type_flags[arg_index] != 1) { + raii_call_args.push_back(py::reinterpret_borrow(args[arg_index])); + continue; + } + + at::Tensor tensor; + bool is_dlpack = PyCapsule_IsValid(args[arg_index], "dltensor") != 0; + if (is_dlpack) { + tensor = torch::utils::tensor_fromDLPack(args[arg_index]); + } else { + TORCH_CHECK(args[arg_index] == Py_None, "Only None is supported for non-tensor input."); + PyObject* fw_kernel_invoke_id = PyObject_GetAttrString(ctx.ptr(), "fw_kernel_invoke_id"); + std::string fw_kernel_invoke_id_str = + py::cast(py::reinterpret_borrow(fw_kernel_invoke_id)); + CustomFuncOpKernelInfo& fw_kernel_info = + KernelInfoStore::GetInstance().GetKernelInfoMap().at(fw_kernel_invoke_id_str); + if (fw_kernel_info.materialize_grads) { + auto& config = fw_kernel_info.materialize_grads_config.at(arg_index - 1); + tensor = at::zeros(std::get<0>(config), std::get<1>(config)); // shift by 1 to skip context input. 
+ } + } + + if (kernel_info.safe_run_enabled) { + bool is_input_used_inplace = std::find(inplace_map.begin(), inplace_map.end(), arg_index) != + inplace_map.end(); + if (is_input_used_inplace) { + raw_input_tensors_used_inplace[tensor_input_index] = tensor; + } + input_tensors_used_for_bw_run[tensor_input_index] = tensor; + } + + if (tensor.defined()) { + raii_call_args.push_back(py::reinterpret_steal(THPVariable_Wrap(tensor))); + } else { + raii_call_args.push_back(py::none()); + } + + tensor_input_index++; + } + + py::tuple call_args = py::cast(raii_call_args); + PyObject* result_pyobj; + { + at::AutoGradMode enable_grad(false); + result_pyobj = PyObject_CallObject(reinterpret_cast(callback), call_args.ptr()); + } + + if (PyErr_Occurred()) { + PyErr_Print(); + throw std::runtime_error("Python function execution fails with the above information."); + } + + if (!result_pyobj) { + throw std::runtime_error("Get null result"); + } + + py::object ret = py::reinterpret_steal(result_pyobj); + + std::vector all_outputs_of_kernel_run; + if (THPVariable_Check(ret.ptr())) { + all_outputs_of_kernel_run.push_back(ret); + } else { + TORCH_CHECK(PyTuple_Check(ret.ptr()), "Python function must return a tuple."); + all_outputs_of_kernel_run = ret.cast>(); + } + + if (kernel_info.safe_run_enabled) { + if (kernel_info.is_first_run) { + // key: tensor data address; + // value: if the tensor is defined it records the tensor input index, otherwise, -1. + std::unordered_map input_tensor_address_to_tensor_input_index_map; + input_tensor_address_to_tensor_input_index_map.reserve(input_tensors_used_for_bw_run.size()); + for (auto& input : input_tensors_used_for_bw_run) { + if (input.second.defined()) { + input_tensor_address_to_tensor_input_index_map.insert( + {{static_cast(reinterpret_cast(input.second.data_ptr())), + input.first + 1}}); /* skip the ctx input*/ + } + } + + detect_memory_reuse_once(kernel_info, + input_tensor_address_to_tensor_input_index_map, + all_outputs_of_kernel_run /*all_outputs_of_kernel_run*/, + inplace_map /*all_outputs_to_tensor_inputs_reuse_map*/, + raw_input_tensors_used_inplace, + log_prefix); + } + + process_inplace_outputs(kernel_info, + func_name, + input_tensors_used_for_bw_run, + inplace_map /*all_outputs_to_tensor_inputs_reuse_map*/, + raw_input_tensors_used_inplace, + is_backward /*is_backward*/, + log_prefix, + all_outputs_of_kernel_run /*all_outputs_of_kernel_run*/); + + unregister_grad_fn(ctx); + } + + std::vector rets; + for (auto& py_obj : all_outputs_of_kernel_run) { + PyObject* obj = py_obj.ptr(); + + if (!THPVariable_Check(obj)) { + Py_INCREF(obj); + rets.push_back(obj); + continue; + } + + DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(obj)); + rets.push_back(PyCapsule_New(dlMTensor, "dltensor", dlpack_capsule_destructor)); + } + + if (kernel_info.is_first_run) { + kernel_info.is_first_run = false; + } + return rets; + } catch (const std::exception& e) { + std::cerr << "custom_function_backward_runner failed with " << e.what() << std::endl; + throw; + } +} diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.h b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.h new file mode 100644 index 0000000000000..415f7cc1e5295 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +std::vector custom_function_backward_runner(const char* func_name_char, + void* callback, + const std::vector& requires_grad_flags, + const std::vector& tensor_type_flags, + const bool is_training_mode, + const std::vector& inplace_map, + const char* kernel_invoke_id_char, + const bool safe_run_mode_enabled, + const std::vector& args); diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc new file mode 100644 index 0000000000000..9e24022b8448d --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc @@ -0,0 +1,516 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ctx_pool.h" +#include "custom_function_shared.h" +#include "custom_function_fw.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef NVTX3_ENABLED +#include +#endif + +static void clear_grad_fns_for_next_edges(at::Tensor& target, + std::vector& saved_tensors) { + // For leaf tensor, there will be a AccumulateGrad (gradient function) created, which owns a + // reference to the tensor. + // For any user saved tensors (with save_for_backward), if the tensor is leaf, we put the map + // {AccumulateGrad*, Tensor*} into grad_fn_to_tensor_map. + std::unordered_map grad_fn_to_tensor_map; + for (auto& t : saved_tensors) { + auto grad_fn = t.grad_fn(); + if (!grad_fn) { + grad_fn = torch::autograd::impl::try_get_grad_accumulator(t); + if (grad_fn) { + TORCH_CHECK(grad_fn_to_tensor_map.find(grad_fn.get()) == grad_fn_to_tensor_map.end(), + "found AccumulateGrad* is used by more than one tensors."); + grad_fn_to_tensor_map.insert({grad_fn.get(), &t}); + } + } + } + + const auto& gradient_func_sptr = target.grad_fn(); + for (auto& edge : gradient_func_sptr->next_edges()) { + torch::autograd::Node* node_func = edge.function.get(); + // If we find the next gradient function is AccumulateGrad, we will check whether its owned + // tensors is in ctx.save_tensors or not. If yes, we skip it; otherwise, we clean the edge, which + // will release the AccumulateGrad function. + if (dynamic_cast(node_func)) { + if (grad_fn_to_tensor_map.find(node_func) != grad_fn_to_tensor_map.end()) { + // skip the edges that connect to saved_tensors. Because when unpack ctx.saved_tensors using + // following code in backward: + // input, = ctx.saved_tensors + // there is such a check: if the saved tensor is a leaf and requires grad, it should have grad accumulator. + // If we clean the edge, then an exception "RuntimeError: No grad accumulator for a saved leaf!" 
will be thrown + continue; + } else { + edge.function.reset(); + } + } + } +} + +static std::vector are_tensors_marked_as_dirty(at::Tensor& target, + std::vector& tensors_to_check) { + torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); + const auto& grad_fn = autograd_meta->grad_fn_; + auto py_node_fn = dynamic_cast(grad_fn.get()); + TORCH_CHECK(py_node_fn != nullptr, "grad_fn is not PyNode type."); + THPFunction* py_fn = (THPFunction*)py_node_fn->obj; + std::vector are_tensors_marked_dirty(tensors_to_check.size(), false); + if (!py_fn->dirty_tensors) + return are_tensors_marked_dirty; + + Py_ssize_t num_dirty = PyTuple_GET_SIZE(py_fn->dirty_tensors); + for (const auto j : c10::irange(tensors_to_check.size())) { + bool is_tensor_marked_dirty = false; + for (const auto i : c10::irange(num_dirty)) { + PyObject* obj = PyTuple_GET_ITEM(py_fn->dirty_tensors, i); + const auto& tensor = THPVariable_Unpack(obj); + if (tensor.is_same(tensors_to_check[j])) { + is_tensor_marked_dirty = true; + break; + } + } + + are_tensors_marked_dirty[j] = is_tensor_marked_dirty; + } + + return are_tensors_marked_dirty; +} + +std::optional try_to_get_tensor_owning_context(const py::tuple& forward_output_tensors) { + py::object ctx = py::none(); + std::optional first_tensor_output; + + for (size_t i = 0; i < forward_output_tensors.size(); ++i) { + PyObject* obj = forward_output_tensors[i].ptr(); + if (!THPVariable_Check(obj)) { + continue; + } + + at::Tensor t = THPVariable_Unpack(obj); + if (!t.grad_fn()) { + continue; + } + + // Be noted, in Python, we need additional check as below. + // For the following case, it is possible grad_fn exists, but its value is None, + // so we need to continue to search for the first tensor having a non-None grad_fn. + // + // >>> w = torch.randn(5, 6) + // >>> hasattr(w, "grad_fn") + // True + // >>> w.grad_fn is None + // True + // >>> w, ... = CustomFunc.apply(w) # where CustomFunc forward just return w and other tensors. + // + // Then hasattr(w, "grad_fn") is True, but w.grad_fn is None. + + first_tensor_output = t; + break; + } + + return first_tensor_output; +} + +void get_materialize_grads_once(const py::tuple& forward_output_tensors, + bool need_materialize_grads, + CustomFuncOpKernelInfo& kernel_info) { + kernel_info.materialize_grads = need_materialize_grads; + if (need_materialize_grads) { + for (size_t i = 0; i < forward_output_tensors.size(); ++i) { + PyObject* obj = forward_output_tensors[i].ptr(); + if (!THPVariable_Check(obj)) { + continue; + } + at::Tensor t = THPVariable_Unpack(obj); + kernel_info.materialize_grads_config.insert({i, {t.sizes().vec(), t.options()}}); + } + + static std::once_flag log_warning; + std::call_once(log_warning, []() { + std::cerr << "First-time run initialize kernel info including materialize_grads and materialize_grads_config." + << std::endl; + }); + } +} + +py::object finalize_training_mode_forward( + const std::unordered_map& input_tensors_used_for_fw_run, + const py::tuple& forward_output_tensors, + CustomFuncOpKernelInfo& kernel_info) { + std::optional tensor_owning_ctx = try_to_get_tensor_owning_context(forward_output_tensors); + + if (!tensor_owning_ctx.has_value()) { + // ctx being None in training mode means the forward function is not differentiable, so backward is not needed. 
+ return py::none(); + } + + const std::shared_ptr& cdata = tensor_owning_ctx.value().grad_fn(); + auto py_node_fn = dynamic_cast(cdata.get()); + TORCH_CHECK(py_node_fn != nullptr, "cdata is not PyNode type."); + + // ret is THPFunction + THPFunction* py_fn = (THPFunction*)py_node_fn->obj; + py::object ret = py::reinterpret_steal(torch::autograd::functionToPyObject(cdata)); + + TORCH_CHECK(py_fn != nullptr, "cdata is not THPFunction type."); + + // The way we find saved tensor is aligned with + // "THPFunction_saved_tensors" and "unpack_saved_variables" in PyTorch. + std::vector saved_tensors; + int num_saved = py_fn->saved_variables.size(); + auto saved_for = py_fn->cdata.lock(); + TORCH_INTERNAL_ASSERT(saved_for); + + for (const auto i : c10::irange(num_saved)) { + auto unpacked_var = py_fn->saved_variables[i].unpack(saved_for); + if (unpacked_var.defined()) { + // TODO(pengwa): is it possible we do the copy on demand here instead of do blind + // copy and do detection at the first iteration. + saved_tensors.push_back(unpacked_var); + } + } + + if (kernel_info.is_first_run) { + std::cout << "666666666666666666666666. py_fn->materialize_grads:" << py_fn->materialize_grads << std::endl; + get_materialize_grads_once(forward_output_tensors, py_fn->materialize_grads, kernel_info); + + if (kernel_info.safe_run_enabled) { + for (auto& pair : input_tensors_used_for_fw_run) { + auto& tensor = pair.second; + bool found = false; + for (auto& t : saved_tensors) { + if (t.is_same(tensor)) { + found = true; + break; + } + } + kernel_info.tensor_input_indices_to_save_in_ctx[pair.first] = found; + } + + // Check tensors generated by ORT are marked as dirty(for inplace update) or not . + // If yes, save the input index of the tensor in the KernelInfoStore::GetInstance().GetKernelInfoMap(). + std::vector tensors_to_check; + tensors_to_check.reserve(input_tensors_used_for_fw_run.size()); + for (auto& pair : input_tensors_used_for_fw_run) { + tensors_to_check.push_back(pair.second); + } + + std::vector are_dirty = are_tensors_marked_as_dirty(tensor_owning_ctx.value(), tensors_to_check); + size_t index = 0; + for (auto& pair : input_tensors_used_for_fw_run) { + kernel_info.tensor_input_indices_for_mark_dirty[pair.first] = are_dirty[index]; + + index += 1; + } + + static std::once_flag log_warning; + std::call_once(log_warning, []() { + std::cerr << "First time run initialize kernel info including saved_for_forward, and mark_dirty infos." << std::endl; + }); + } + } + + // #FORWARD BACKWARD FUNCTION CONNECTIONS + // #input_1(leaf, constructed by from_dlpack) < -- --reference-- --AccumulateGrad gradient function + // # ↓ ↑ + // #autograd.Function apply()-- -- -- -- -- --> autograd.Function backward() + // # ↓ | ↑ + // #output_1, output_2-- - shared_ptr < PyNode> -- - ↑ + // # ↓ previous gradient function + + // #We remove the edges starting between current autograd.Function's gradient function and + // #it 's input' s gradient function(e.g.AccumulateGrad gradient function), then + // #AccumulateGrad gradient function will be destroyed, releasing the reference to input_1 + // #(https: //github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/functions/accumulate_grad.cpp#L21). + // #The next edges are stored in Node, with which we can get next gradient function. 
+ // #https: // github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L527 + + clear_grad_fns_for_next_edges(tensor_owning_ctx.value(), saved_tensors); + + // This is mainly to hold grad_fn references by registering it into our PyNodeSharedPointerPool. + register_grad_fn_and_remove_from_autograd(ret, tensor_owning_ctx.value()); + + return ret; +} + +static py::object get_mockup_context_class() { + static py::object kclass_obj; + + if (!kclass_obj.ptr()) { + // Load the module object + auto module = + py::reinterpret_steal( + PyImport_ImportModule("onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils.fake_ctx")); + if (!module.ptr()) { + PyErr_Print(); + throw std::runtime_error("Fails to import the module."); + } + + auto python_class = py::reinterpret_steal(PyObject_GetAttrString(module.ptr(), "FakeContext")); + if (!PyCallable_Check(python_class.ptr())) { + throw std::runtime_error("Cannot instantiate the Python class"); + } + + kclass_obj = py::reinterpret_borrow(python_class.ptr()); + } + + return kclass_obj; +} + +std::vector custom_function_forward_runner(const char* func_name_char, + void* callback, + const std::vector& requires_grad_flags, + const std::vector& tensor_type_flags, + const bool is_training_mode, + const std::vector& inplace_map, + const char* kernel_invoke_id_char, + const bool safe_run_mode_enabled, + const std::vector& args) { + try { + pybind11::gil_scoped_acquire gil; + + std::string func_name(func_name_char); + std::string kernel_invoke_id(kernel_invoke_id_char); + bool is_backward = false; + std::string log_prefix = func_name + " -> " + (is_backward ? "Backward " : "Forward "); + +#ifdef NVTX3_ENABLED + nvtxRangePushA(std::string(func_name + ".fw").c_str()); +#endif + + auto it = KernelInfoStore::GetInstance().GetKernelInfoMap().find(kernel_invoke_id); + if (it == KernelInfoStore::GetInstance().GetKernelInfoMap().end()) { + KernelInfoStore::GetInstance().GetKernelInfoMap().emplace( + kernel_invoke_id, + CustomFuncOpKernelInfo(kernel_invoke_id, safe_run_mode_enabled)); + } + + CustomFuncOpKernelInfo& kernel_info = KernelInfoStore::GetInstance().GetKernelInfoMap().at(kernel_invoke_id); + + std::unordered_map raw_input_tensors_used_inplace; + std::unordered_map input_tensors_used_for_fw_run; + + int tensor_input_index = 0; + std::vector raii_call_args; + if (kernel_info.safe_run_enabled) { + raii_call_args.reserve(args.size()); + } else { + auto python_class = get_mockup_context_class(); + // Creates an instance of the class + PyObject* object = PyObject_CallObject(python_class.ptr(), nullptr); + raii_call_args.reserve(args.size() + 1); + raii_call_args.push_back(py::reinterpret_steal(object)); + } + + for (size_t arg_index = 0; arg_index < args.size(); ++arg_index) { + bool is_tensor = (tensor_type_flags[arg_index] == 1); + if (!is_tensor) { + raii_call_args.push_back(py::reinterpret_borrow(args[arg_index])); + continue; + } + + // Assume it's a DLPack tensor and convert it to PyTorch tensor. 
+ TORCH_CHECK(PyCapsule_IsValid(args[arg_index], "dltensor") != 0, "found invalid pycapsule"); + at::Tensor tensor = torch::utils::tensor_fromDLPack(args[arg_index]); + bool requires_grad = requires_grad_flags[arg_index] && is_training_mode; + tensor.requires_grad_(requires_grad); + + if (kernel_info.safe_run_enabled) { + bool is_input_used_inplace = (std::find(inplace_map.begin(), inplace_map.end(), tensor_input_index) != + inplace_map.end()); + if (is_input_used_inplace) { + raw_input_tensors_used_inplace[tensor_input_index] = tensor; + } + + if (kernel_info.is_first_run) { + at::Tensor tensor_clone; + if (is_training_mode) { + at::AutoGradMode enable_grad(true); + tensor_clone = tensor.clone(); + tensor_clone.requires_grad_(requires_grad); + } else { + tensor_clone = tensor; + } + + raii_call_args.push_back(py::reinterpret_steal(THPVariable_Wrap(tensor_clone))); + input_tensors_used_for_fw_run[tensor_input_index] = tensor_clone; + } else { + // Saving tensor for backward only affect the training. + bool is_input_index_saved_in_ctx = + is_training_mode && kernel_info.tensor_input_indices_to_save_in_ctx.at(tensor_input_index); + + bool is_input_index_marked_dirty = + kernel_info.tensor_input_indices_for_mark_dirty.at(tensor_input_index); + + if (is_input_index_saved_in_ctx || is_input_index_marked_dirty) { + at::AutoGradMode enable_grad(is_input_index_marked_dirty); + auto wrapped_arg = tensor.clone(); + wrapped_arg.requires_grad_(requires_grad); + raii_call_args.push_back(py::reinterpret_steal(THPVariable_Wrap(wrapped_arg))); + input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg; + } else { + raii_call_args.push_back(py::reinterpret_steal(THPVariable_Wrap(tensor))); + input_tensors_used_for_fw_run[tensor_input_index] = tensor; + } + } + } else { + raii_call_args.push_back(py::reinterpret_steal(THPVariable_Wrap(tensor))); + } + + tensor_input_index++; + } + + if (kernel_info.safe_run_enabled && kernel_info.is_first_run) { + // Initialize some kernel info for the first run. + for (const auto i : c10::irange(input_tensors_used_for_fw_run.size())) { + kernel_info.tensor_input_indices_to_save_in_ctx.insert({{i, false}}); + kernel_info.tensor_input_indices_for_mark_dirty.insert({{i, false}}); + } + } + +#ifdef NVTX3_ENABLED + nvtxRangePushA(std::string(func_name + ".call_func").c_str()); +#endif + + py::tuple call_args = py::cast(raii_call_args); + PyObject* result_pyobj; + { + at::AutoGradMode enable_grad(is_training_mode && kernel_info.safe_run_enabled); + result_pyobj = PyObject_CallObject(reinterpret_cast(callback), call_args.ptr()); + } + +#ifdef NVTX3_ENABLED + nvtxRangePop(); +#endif + + if (PyErr_Occurred()) { + PyErr_Print(); + } + + if (!result_pyobj) { + throw std::runtime_error("Get null result"); + } + + py::object ret = py::reinterpret_steal(result_pyobj); + + py::tuple forward_outputs; + if (THPVariable_Check(ret.ptr())) { // Don't check be tensor? 
+ forward_outputs = py::make_tuple(ret); + } else { + TORCH_CHECK(PyTuple_Check(ret.ptr()), "Python function must return a tuple."); + forward_outputs = ret.cast(); + } + + py::object ctx; + if (is_training_mode) { +#ifdef NVTX3_ENABLED + std::string tag3 = func_name + ".ctx"; + nvtxRangePushA(tag3.c_str()); +#endif + if (kernel_info.safe_run_enabled) { + ctx = finalize_training_mode_forward(input_tensors_used_for_fw_run, forward_outputs, kernel_info); + if (!ctx.is_none()) { + PyObject_SetAttrString(ctx.ptr(), "fw_kernel_invoke_id", py::cast(kernel_invoke_id).ptr()); + } + } else { + if (kernel_info.is_first_run) { + bool need_materialize_grads = true; + get_materialize_grads_once(forward_outputs, need_materialize_grads, kernel_info); + } + + ctx = call_args[0]; + PyObject_SetAttrString(ctx.ptr(), "fw_kernel_invoke_id", py::cast(kernel_invoke_id).ptr()); + } + +#ifdef NVTX3_ENABLED + nvtxRangePop(); +#endif + } else { + ctx = py::none(); + } + + std::vector all_outputs_of_kernel_run; + all_outputs_of_kernel_run.reserve(forward_outputs.size() + 1); + all_outputs_of_kernel_run.push_back(ctx); + for (size_t i = 0; i < forward_outputs.size(); ++i) { + all_outputs_of_kernel_run.push_back(forward_outputs[i]); + } + + if (kernel_info.safe_run_enabled) { + if (kernel_info.is_first_run) { + // key: tensor data address; + // value: if the tensor is defined it records the tensor input index, otherwise, -1. + std::unordered_map input_tensor_address_to_tensor_input_index_map; + input_tensor_address_to_tensor_input_index_map.reserve(input_tensors_used_for_fw_run.size()); + for (auto& input : input_tensors_used_for_fw_run) { + if (input.second.defined()) { + input_tensor_address_to_tensor_input_index_map.insert( + {{static_cast(reinterpret_cast(input.second.data_ptr())), input.first}}); + } + } + + detect_memory_reuse_once(kernel_info, + input_tensor_address_to_tensor_input_index_map, + all_outputs_of_kernel_run /*all_outputs_of_kernel_run*/, + inplace_map /*all_outputs_to_tensor_inputs_reuse_map*/, + raw_input_tensors_used_inplace, + log_prefix); + } + + process_inplace_outputs(kernel_info, + func_name, + input_tensors_used_for_fw_run, + inplace_map /*all_outputs_to_tensor_inputs_reuse_map*/, + raw_input_tensors_used_inplace, + false /*is_backward*/, + log_prefix, + all_outputs_of_kernel_run /*all_outputs_of_kernel_run*/); + } + +#ifdef NVTX3_ENABLED + nvtxRangePushA(std::string(func_name + ".final").c_str()); +#endif + + std::vector rets; + rets.reserve(all_outputs_of_kernel_run.size()); + for (auto& py_obj : all_outputs_of_kernel_run) { + PyObject* obj = py_obj.ptr(); + + if (!THPVariable_Check(obj)) { + Py_INCREF(obj); + rets.push_back(obj); + continue; + } + + DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(obj)); + rets.push_back(PyCapsule_New(dlMTensor, "dltensor", dlpack_capsule_destructor)); + } + +#ifdef NVTX3_ENABLED + nvtxRangePop(); +#endif + + if (kernel_info.is_first_run) { + kernel_info.is_first_run = false; + } + +#ifdef NVTX3_ENABLED + nvtxRangePop(); +#endif + + return rets; + } catch (const std::exception& e) { + std::cerr << "custom_function_forward_runner failed with " << e.what() << std::endl; + throw; + } +} diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.h b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.h new file mode 100644 index 0000000000000..5a908e4cd4e7f --- /dev/null +++ 
b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +std::vector custom_function_forward_runner(const char* func_name_char, + void* callback, + const std::vector& requires_grad_flags, + const std::vector& tensor_type_flags, + const bool is_training_mode, + const std::vector& inplace_map, + const char* kernel_invoke_id_char, + const bool safe_run_mode_enabled, + const std::vector& tensor_args); diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.cc new file mode 100644 index 0000000000000..f7698b74ab462 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.cc @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ctx_pool.h" +#include "custom_function_shared.h" +#include +#include + +/** + * @brief Special handling for in-place reusing in forward or backward. + * @param kernel_info kernel-specific information. + * @param input_tensor_address_to_tensor_input_index_map + * @param all_outputs_of_kernel_run all outputs of the MSDomain::PythonOp/PythonOpGrad. + * @param all_outputs_to_tensor_inputs_reuse_map + * @param raw_input_tensors_used_inplace a dict of raw input tensors marked as inplace in + `all_outputs_to_tensor_inputs_reuse_map`, the key is the tensor input index, value is the raw input tensor. + * @param log_prefix + * + * Detection procedures: + * 1. Detect all outputs to tensor inputs reuse mapping. + * 2. Validate the detected inplace_map with the registered inplace_map in ORT. For the output tensor, + * 2.0 If the reuse mapping value is the same in both inplace_map and detected inplace_map: + * 2.0.1 Most likely, we don't need to do anything, except 2.0.2. + * 2.0.2 Conditions: + * > During forward run, + * > The output tensor is reusing one of input tensors, + * > The raw input tensor to be reused given from ORT is copied to run the forward kernels + * (for two possible reasons: + * a. the first time forward run, all inputs will be copied to detect + * `tensor_input_indices_to_save_in_ctx`; + * b. for every iteration, the input needs to be cloned because it is in + * `tensor_input_indices_to_save_in_ctx`). + * + * In this case, need to copy the output tensor back to the raw input tensor, to make it compatible with + * ORT statistically planned buffer reuse. + * 2.1 If the reuse mapping value is NOT equal in both inplace_map and detected inplace_map: + * 2.1.1 If the detected reuse input index is -1 (e.g. there is NO buffer reuse for this output), + * while user specified reuse input index is NOT -1 (ORT planned the reuse), we raise an error. + * 2.1.2 If the detected reuse input index is NOT -1 (e.g. there is buffer reuse for this output), + * while user specified reuse input index is -1 (ORT did not plan the reuse). We will try to clone the + * output tensor before returning to ORT, to align with ORT's NO Buffer reuse plan; otherwise, once the + * input buffer is released by ORT memory planner, the output tensor read/write will be corrupted. 
+ * Raise a warning to notify users to update inplace_map explicitly for performance consideration. + * 2.1.3 Other cases (for example user gives a wrong mapping index compared with detected ones), raise an + * error. + * 3. Do copies for 2.1.2 cases. + * 4. Do copies for 2.0.2 cases. + */ +void detect_memory_reuse_once( + CustomFuncOpKernelInfo& kernel_info, + const std::unordered_map& input_tensor_address_to_tensor_input_index_map, + const std::vector& all_outputs_of_kernel_run, + const std::vector& all_outputs_to_tensor_inputs_reuse_map, + const std::unordered_map& raw_input_tensors_used_inplace, + const std::string& log_prefix) { + // Procedure 1: Detect all outputs to tensor inputs reuse mapping, according to `all_outputs_of_kernel_run` and + // `input_tensors_of_kernel_run`. + + TORCH_CHECK(all_outputs_to_tensor_inputs_reuse_map.size() == all_outputs_of_kernel_run.size(), + log_prefix + + "all_outputs_to_tensor_inputs_reuse_map and kernel run outputs sizes not expected:" + + std::to_string(all_outputs_to_tensor_inputs_reuse_map.size()) + " vs " + + std::to_string(all_outputs_of_kernel_run.size())); + + // Detect all outputs to tensor inputs reuse mapping. + std::vector detected_reuse_map(all_outputs_of_kernel_run.size(), -1); + for (size_t output_index = 0; output_index < all_outputs_of_kernel_run.size(); ++output_index) { + py::object arg = all_outputs_of_kernel_run[output_index]; + if (!THPVariable_Check(arg.ptr())) { + continue; + } + at::Tensor t = THPVariable_Unpack(arg.ptr()); + size_t t_data_address = static_cast(reinterpret_cast(t.data_ptr())); + if (input_tensor_address_to_tensor_input_index_map.find(t_data_address) != input_tensor_address_to_tensor_input_index_map.end()) { + int tensor_input_index = input_tensor_address_to_tensor_input_index_map.at(t_data_address); + TORCH_CHECK(tensor_input_index != -1, "Reused tensor input index should not be -1"); + detected_reuse_map[output_index] = tensor_input_index; + } + } + + // Procedure 2: Validate the detected inplace_map with the registered inplace_map in ORT. + // collect the output indices that need to be cloned before returned in case 2.1.2. + for (size_t output_index = 0; output_index < all_outputs_of_kernel_run.size(); ++output_index) { + int detected_inplace_index = detected_reuse_map[output_index]; + int inplace_index = all_outputs_to_tensor_inputs_reuse_map[output_index]; + + if (inplace_index == detected_inplace_index) { + continue; + } + + if (raw_input_tensors_used_inplace.count(inplace_index) && + !raw_input_tensors_used_inplace.at(inplace_index).defined()) { + // Use specified inplace input index, but the input tensor is None, which means the input is not + // a tensor, so we don't do further checks. + continue; + } + + // If users register inplace_map (alloc planner will do buffer reuse), + // but detected inplace_map indicates it is NO inplace reusing, we raise an error. + if (inplace_index != -1 && detected_inplace_index == -1) { + throw std::runtime_error( + log_prefix + "Fatal: ONNX Op attribute 'tensor_reuse_map' indicates " + + std::to_string(output_index) + "-th output is reusing input " + + std::to_string(inplace_index) + ", but detected inplace_map indicates it is NOT reusing any input. " + + "Please update inplace_map explicitly to make it consistent " + + "to avoid undefined behavior due to ORT's memory reuse plan. 
" + + +"detected reused input index: " + std::to_string(detected_inplace_index)); + } + + if (inplace_index == -1 && detected_inplace_index != -1) { + std::cout << log_prefix << "ONNX Op attribute " + << "'tensor_reuse_map' doesn't indicate " << std::to_string(output_index) + << "-th output is reusing any input, " + << "but detected inplace_map indicates it is reusing input index " + << std::to_string(detected_inplace_index) + << ". A clone will be done before returning to ORT, to align with ORT's NO Buffer reuse plan. " + << "Please update inplace_map explicitly to avoid such a copy." << std::endl; + + kernel_info.output_indices_for_clone.push_back(output_index); + continue; + } + + throw std::runtime_error( + log_prefix + "Fatal: ONNX Op attribute 'tensor_reuse_map' indicates " + + std::to_string(output_index) + "-th output is reusing input " + std::to_string(inplace_index) + + " but detected inplace_map indicates it is reusing input index " + + std::to_string(detected_inplace_index) + + ". Please update inplace_map explicitly to avoid undefined behavior due to memory reuse."); + } +} + +void process_inplace_outputs( + const CustomFuncOpKernelInfo& kernel_info, + const std::string& func_name, + const std::unordered_map& input_tensors_used_for_fw_run, + const std::vector& all_outputs_to_tensor_inputs_reuse_map, + const std::unordered_map& raw_input_tensors_used_inplace, + bool is_backward, + const std::string& log_prefix, + std::vector& all_outputs_of_kernel_run) { + // Procedure 3: Do copies for 2.1.2 cases. + for (const size_t& output_index : kernel_info.output_indices_for_clone) { + at::Tensor t = THPVariable_Unpack(all_outputs_of_kernel_run[output_index].ptr()); + auto pp = py::reinterpret_steal(THPVariable_Wrap(t.detach().clone())); + all_outputs_of_kernel_run[output_index] = pp; + } + + // Procedure 4: Do copies for 2.0.2 cases. + if (!is_backward && kernel_info.safe_run_enabled) { + for (auto& pair : raw_input_tensors_used_inplace) { + auto raw_tensor_input_index = pair.first; + auto raw_input_tensor = pair.second; + // raw_input_tensor can be None for backward run, but backward won't go here. + if (!raw_input_tensor.defined()) { + continue; + } + + // We did not do the check with tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty + // because even for those tensor indices not in + // tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty, we still need to do the + // copy for the first-time run. + if (raw_input_tensor.data_ptr() == input_tensors_used_for_fw_run.at(raw_tensor_input_index).data_ptr()) { + // If the raw input tensor is not copied, we don't need this handling. + continue; + } + + // for each tensor, we don't do the copy once. 
+ bool copied = false; + std::vector output_indices_reusing_current_raw_input; + for (size_t output_index = 0; output_index < all_outputs_to_tensor_inputs_reuse_map.size(); ++output_index) { + if (all_outputs_to_tensor_inputs_reuse_map[output_index] == raw_tensor_input_index) { + output_indices_reusing_current_raw_input.push_back(output_index); + } + } + + auto output_tensor_address = + THPVariable_Unpack(all_outputs_of_kernel_run[output_indices_reusing_current_raw_input[0]].ptr()).data_ptr(); + for (size_t& output_index : output_indices_reusing_current_raw_input) { + auto t = THPVariable_Unpack(all_outputs_of_kernel_run[output_index].ptr()); + TORCH_CHECK(output_tensor_address == t.data_ptr(), + "Outputs reusing the same input tensor should have the same address."); + + if (!copied) { + // Only need a copy once. + // Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False. + raw_input_tensor.requires_grad_(false); + raw_input_tensor.copy_(t); + + // Comment below for debugging. + // std::cout << "Copy output tensor " << output_index << " to raw input tensor " << raw_tensor_input_index << "." + // << (!kernel_info.is_first_run + // ? "Provide output to input reuse mapping to avoid the copy overhead." + // : "") + // << std::endl; + copied = true; + } + + all_outputs_of_kernel_run[output_index] = py::reinterpret_steal(THPVariable_Wrap(raw_input_tensor)); + } + } + } +} + +void dlpack_capsule_destructor(PyObject* data) { + if (!PyCapsule_IsValid(data, "dltensor")) { + // early out, see DLPack spec: if a consuming library sets the capsule + // name to something else, they own it and we don't need to do anything + return; + } + DLManagedTensor* dlMTensor = + (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); + dlMTensor->deleter(const_cast(dlMTensor)); +} diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.h b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.h new file mode 100644 index 0000000000000..c1c1930aac4cd --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.h @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include + +// Uncomment this line to enable NVTX profiling +// #define NVTX3_ENABLED 1 + +class CustomFuncOpKernelInfo { + public: + CustomFuncOpKernelInfo(const std::string& invoke_id, bool safe_run) { + kernel_invoke_id = invoke_id; + safe_run_enabled = safe_run; + } + + // kernel_invoke_id is a string contains session thread id, op kernel creation time stamp in ms, a random int, + // and address of op_kernel pointer. This can guarantee the uniqueness of the key in case of multiple + // instances of a same named PythonOp/PythonOpGrad in one session, or multiple sessions. + std::string kernel_invoke_id; + + // For the tensors generated from ORT backend, there is special handling here: + // 1. For the first time run for the kernel (the uniqueness of the kernel is defined by kernel_invoke_id), + // all such tensors will be cloned in case they are saved in context (but ORT backend is not aware of the + // reference, may release the content of the tensor before it is needed in backward). 
Once
+  //      `autograd.Function.apply` completes, by checking for the existence of the tensor in saved_tensors,
+  //      `_GlobalOpKernelInfoMap` is updated to record the input indices that are saved in the context.
+  //   2. For subsequent runs, if the input index is in `tensor_input_indices_to_save_in_ctx`, the tensor
+  //      will be cloned before being fed into `autograd.Function.apply` as input.
+  std::unordered_map tensor_input_indices_to_save_in_ctx;
+
+  // To align with PyTorch `ctx.set_materialize_grads(False|True)`, this defaults to true.
+  // materialize_grads_config is a map from output index to (device, dtype, shape) of the output tensor, used
+  // for materializing the gradient of the output tensor in backward.
+  bool materialize_grads{true};
+  // key: output index, value: (shape, tensor options including device, layout, data type, etc.)
+  std::unordered_map, c10::TensorOptions>> materialize_grads_config;
+
+  // For the tensors generated from the ORT backend, there is special handling here:
+  //   1. For the first run of the kernel (the uniqueness of the kernel is defined by kernel_invoke_id),
+  //      all such tensors will be cloned (with gradient) in case they are marked as dirty (if not cloned but marked
+  //      as dirty, PyTorch will complain that the tensor is a leaf and should not be used for inplace update). Once
+  //      `autograd.Function.apply` completes, by checking for the existence of the tensor in dirty_tensors,
+  //      `_GlobalOpKernelInfoMap` is updated to record the input indices that are marked as dirty.
+  //   2. For subsequent runs, if the input index is in `tensor_input_indices_for_mark_dirty`, the tensor
+  //      will be cloned (with gradient) before being fed into `autograd.Function.apply` as input.
+  std::unordered_map tensor_input_indices_for_mark_dirty;
+
+  // A list of output indices that need to be cloned before being returned, based on the inplace-update analysis.
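+  // Populated once by detect_memory_reuse_once() on the first run of the kernel, and consumed by
+  // process_inplace_outputs(), which clones these outputs before handing them back to ORT.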
+  std::vector output_indices_for_clone;
+
+  bool is_first_run{true};
+  bool safe_run_enabled{false};
+};
+
+void detect_memory_reuse_once(
+    CustomFuncOpKernelInfo& kernel_info,
+    const std::unordered_map& input_tensor_address_to_tensor_input_index_map,
+    const std::vector& all_outputs_of_kernel_run,
+    const std::vector& all_outputs_to_tensor_inputs_reuse_map,
+    const std::unordered_map& raw_input_tensors_used_inplace,
+    const std::string& log_prefix);
+
+void process_inplace_outputs(
+    const CustomFuncOpKernelInfo& kernel_info,
+    const std::string& func_name,
+    const std::unordered_map& input_tensors_used_for_fw_run,
+    const std::vector& all_outputs_to_tensor_inputs_reuse_map,
+    const std::unordered_map& raw_input_tensors_used_inplace,
+    bool is_backward,
+    const std::string& log_prefix,
+    std::vector& all_outputs_of_kernel_run);
+
+void dlpack_capsule_destructor(PyObject* data);
+
+class KernelInfoStore {
+ public:
+  static KernelInfoStore& GetInstance() {
+    static KernelInfoStore instance;
+    return instance;
+  }
+
+  std::unordered_map& GetKernelInfoMap() {
+    return kernel_info_map_;
+  }
+
+ private:
+  std::unordered_map kernel_info_map_;
+};
diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/fake_ctx.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/fake_ctx.py
new file mode 100644
index 0000000000000..d295c68c2a155
--- /dev/null
+++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/fake_ctx.py
@@ -0,0 +1,13 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+
+class FakeContext:
+    """A mock class used to represent ctx in unsafe-mode runs.
+    The reason ctx needs to be a Python class is that users can assign arbitrary attributes to it.
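+    An instance starts empty; the user's autograd function attaches whatever attributes it needs at run time.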
+ """ + + def __init__(self): + pass diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py index 3b6d6050c4c17..fa72f3b134917 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py @@ -8,13 +8,30 @@ from setuptools import Extension, setup # noqa: F401 from torch.utils import cpp_extension -filename = os.path.join(os.path.dirname(__file__), "torch_interop_utils.cc") +source_filenames = [ + "torch_interop_utils.cc", + "ctx_pool.cc", + "custom_function_bw.cc", + "custom_function_fw.cc", + "custom_function_shared.cc", +] + +cur_file_dir = os.path.dirname(__file__) + +header_filenames = [ + # "/usr/local/cuda/include/", # uncomment this line to build nvtx support, + cur_file_dir, +] + extra_compile_args = {"cxx": ["-O3"]} setup( name="torch_interop_utils", ext_modules=[ cpp_extension.CppExtension( - name="torch_interop_utils", sources=[filename], extra_compile_args=extra_compile_args + name="torch_interop_utils", + sources=[os.path.join(cur_file_dir, filename) for filename in source_filenames], + extra_compile_args=extra_compile_args, + include_dirs=header_filenames, ) ], cmdclass={"build_ext": cpp_extension.BuildExtension}, diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc index d36720100e57a..979c409f08074 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc @@ -1,190 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include -#include -#include -// In PyTorch forward run (e.g. THPFunction_apply), ctx of type THPFunction* (which is also a PyObject*) -// is created (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L673). -// The ctx is used to run user-defined forward function and backward function as the first -// parameter. The same time, a cdata of type std::shared_ptr is created -// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L677), -// cdata is owned by: -// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor owns -// shared_pointer; TensorImpl owns std::unique_ptr; AutogradMeta -// manages grad_/grad_fn_/grad_accumulator_. Among them, grad_fn_ is std::shared_ptr, -// e.g, the so called gradient function.) -// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/variable.h#L194 -// b). the consumer operator of forward run outputs, will let its own PyNode/Node (gradient function) -// owns the grad_fn_ (of type std::shared_ptr) of all inputs that require grad. -// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L263 -// BUT, if we run torch computation within PythonOp, b) is lost. 
So for some cases, where forward outputs -// are not used and freed before backward function runs, the grad_fn_ (std::shared_ptr) references -// in a) will be released. Without b)'s reference, grad_fn_ release PyNode as reference count reach 0; -// Then when PythonOpGrad runs, segment fault. -// -// So we add b)'s reference in this Pool when forward run returns; dereference from this Pool when backward -// completes, then ~PyNode() is called, which subsequently calls ~THPFunction() destroying ctx. -class PyNodeSharedPointerPool { - public: - static PyNodeSharedPointerPool& GetInstance() { - static PyNodeSharedPointerPool pool; - return pool; - }; +#include "ctx_pool.h" +#include "custom_function_fw.h" +#include "custom_function_bw.h" - void RegisterGradFuncAndRemoveFromAutoGrad(const size_t& ctx_address, - torch::autograd::AutogradMeta* autograd_meta) { - auto it = grad_fns_.find(ctx_address); - TORCH_CHECK(it == grad_fns_.end(), "should not register grad_fn twice for ctx ", ctx_address); - - // Add new entry if key hasn't been registered. - // After this, the grad_fn_ is removed from torch autograd. - grad_fns_.emplace(ctx_address, std::move(autograd_meta->grad_fn_)); - TORCH_CHECK(autograd_meta->grad_fn_ == nullptr, "fail to remove grad_fn_ from torch autograd for ctx ", - ctx_address); - }; - - void UnRegisterGradFunc(const size_t& ctx_address) { - auto it = grad_fns_.find(ctx_address); - TORCH_CHECK(it != grad_fns_.end(), "fail to find grad_fn for ctx ", ctx_address); - - grad_fns_.erase(ctx_address); - }; - - void ClearAll() { - grad_fns_.clear(); - } - - private: - PyNodeSharedPointerPool(){}; - ~PyNodeSharedPointerPool(){}; - - PyNodeSharedPointerPool(const PyNodeSharedPointerPool&) = delete; - PyNodeSharedPointerPool& operator=(const PyNodeSharedPointerPool&) = delete; - PyNodeSharedPointerPool(PyNodeSharedPointerPool&&) = delete; - PyNodeSharedPointerPool& operator=(PyNodeSharedPointerPool&&) = delete; - - std::unordered_map> grad_fns_; -}; - -void clear_grad_fns_for_next_edges(at::Tensor target, std::vector saved_tensors) { - // For leaf tensor, there will be a AccumulateGrad (gradient function) created, which owns a - // reference to the tensor. - // For any user saved tensors (with save_for_backward), if the tensor is leaf, we put the map - // {AccumulateGrad*, Tensor*} into grad_fn_to_tensor_map. - std::unordered_map grad_fn_to_tensor_map; - for (auto& t : saved_tensors) { - auto grad_fn = t.grad_fn(); - if (!grad_fn) { - grad_fn = torch::autograd::impl::try_get_grad_accumulator(t); - if (grad_fn) { - TORCH_CHECK(grad_fn_to_tensor_map.find(grad_fn.get()) == grad_fn_to_tensor_map.end(), - "found AccumulateGrad* is used by more than one tensors."); - grad_fn_to_tensor_map.insert({grad_fn.get(), &t}); - } - } - } - - const auto& gradient_func_sptr = target.grad_fn(); - for (auto& edge : gradient_func_sptr->next_edges()) { - torch::autograd::Node* node_func = edge.function.get(); - // If we find the next gradient function is AccumulateGrad, we will check whether its owned - // tensors is in ctx.save_tensors or not. If yes, we skip it; otherwise, we clean the edge, which - // will release the AccumulateGrad function. - if (dynamic_cast(node_func)) { - if (grad_fn_to_tensor_map.find(node_func) != grad_fn_to_tensor_map.end()) { - // skip the edges that connect to saved_tensors. 
Because when unpack ctx.saved_tensors using - // following code in backward: - // input, = ctx.saved_tensors - // there is such a check: if the saved tensor is a leaf and requires grad, it should have grad accumulator. - // If we clean the edge, then an exception "RuntimeError: No grad accumulator for a saved leaf!" will be thrown - continue; - } else { - edge.function.reset(); - } - } - } -} - -void register_grad_fn_and_remove_from_autograd(size_t ctx_address, at::Tensor target) { - torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); - PyNodeSharedPointerPool::GetInstance().RegisterGradFuncAndRemoveFromAutoGrad(ctx_address, autograd_meta); -} - -void unregister_grad_fn(size_t ctx_address) { - PyNodeSharedPointerPool::GetInstance().UnRegisterGradFunc(ctx_address); -} - -// Supposed to be cleared on python program exit to resolve the following issue: -// When training program exits, PyNodeSharedPointerPool destructor is called, if grad_fns_ is not empty, -// PyNode::release_variables() will be called. -// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L168) -// On The other hand, there is a known issue when acquiring GIL in pybind11 destructors, there will be -// probably a deadlock issue. (https://github.com/pybind/pybind11/issues/1446) -// The resolution here, we remove all maintained states before the program exits. - -// A known existing issue: when forward functions are called repeatedly without corresponding backward calls, -// grad functions keep accumulating without releasing, there might be memory (bound to those gradient functions) leaks. -// Ideally this usually won't happen in real training cases, so it should be fine. - -// We CANNOT explicitly clear grad functions before each forward pass to mitigate the known issue above. -// For example: -// loss1 = forward_run(inputs1) -// loss2 = forward_run(inputs2) -// loss = loss1 + loss2 -// loss.backward() -// If we clear grad functions at the beginning of the second `forward_run`, when `loss.backward()` runs, -// the backward path of `loss1` will fail to run PythonOpGrad ops (if there is any). 
-void clear_all_grad_fns() { - PyNodeSharedPointerPool::GetInstance().ClearAll(); -} - -bool get_materialize_grads(at::Tensor target) { - torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); - const auto& grad_fn = autograd_meta->grad_fn_; - auto py_node_fn = dynamic_cast(grad_fn.get()); - TORCH_CHECK(py_node_fn != nullptr, "grad_fn is not PyNode type."); - THPFunction* py_fn = (THPFunction*)py_node_fn->obj; - return py_fn->materialize_grads; -} - -std::vector are_tensors_marked_as_dirty(at::Tensor target, std::vector tensors_to_check) { - torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); - const auto& grad_fn = autograd_meta->grad_fn_; - auto py_node_fn = dynamic_cast(grad_fn.get()); - TORCH_CHECK(py_node_fn != nullptr, "grad_fn is not PyNode type."); - THPFunction* py_fn = (THPFunction*)py_node_fn->obj; - std::vector are_tensors_marked_dirty(tensors_to_check.size(), false); - if (!py_fn->dirty_tensors) - return are_tensors_marked_dirty; - - Py_ssize_t num_dirty = PyTuple_GET_SIZE(py_fn->dirty_tensors); - for (const auto j : c10::irange(tensors_to_check.size())) { - bool is_tensor_marked_dirty = false; - for (const auto i : c10::irange(num_dirty)) { - PyObject* obj = PyTuple_GET_ITEM(py_fn->dirty_tensors, i); - const auto& tensor = THPVariable_Unpack(obj); - if (tensor.is_same(tensors_to_check[j])) { - is_tensor_marked_dirty = true; - break; - } - } - - are_tensors_marked_dirty[j] = is_tensor_marked_dirty; - } - - return are_tensors_marked_dirty; -} +size_t get_custom_function_forward_runner() { return reinterpret_cast(&custom_function_forward_runner); } +size_t get_custom_function_backward_runner() { return reinterpret_cast(&custom_function_backward_runner); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("register_grad_fn_and_remove_from_autograd", ®ister_grad_fn_and_remove_from_autograd, - "Increase grad_fn shared pointer reference."); - m.def("unregister_grad_fn", &unregister_grad_fn, "Release grad_fn shared pointer reference."); m.def("clear_all_grad_fns", &clear_all_grad_fns, "Clear all grad_fn shared pointer references."); - m.def("clear_grad_fns_for_next_edges", &clear_grad_fns_for_next_edges, - "Remove reference on next edges' gradient functions."); - m.def("get_materialize_grads", &get_materialize_grads, "Return whether materialize_grads is enabled or not."); - m.def("are_tensors_marked_as_dirty", &are_tensors_marked_as_dirty, "Return whether the tensors are marked dirty or not."); + m.def("get_custom_function_forward_runner", &get_custom_function_forward_runner, "Get custom function forward runner."); + m.def("get_custom_function_backward_runner", &get_custom_function_backward_runner, "Get custom function backward runner."); } diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index 244557c3c1072..b4a518d573998 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. 
# __init__.py + from onnxruntime.training.utils.ptable import PTable from onnxruntime.training.utils.torch_io_helper import ( ORTModelInputOutputSchemaType, @@ -10,6 +11,11 @@ extract_data_and_schema, unflatten_data_using_schema, ) +from onnxruntime.training.utils.torch_profile_utils import ( + nvtx_function_decorator, + torch_nvtx_range_pop, + torch_nvtx_range_push, +) from onnxruntime.training.utils.torch_type_map import ( onnx_dtype_to_pytorch_dtype, pytorch_scalar_type_to_pytorch_dtype, @@ -22,6 +28,9 @@ "ORTModelInputOutputSchemaType", "extract_data_and_schema", "unflatten_data_using_schema", + "torch_nvtx_range_push", + "torch_nvtx_range_pop", + "nvtx_function_decorator", "pytorch_type_to_onnx_dtype", "onnx_dtype_to_pytorch_dtype", "pytorch_scalar_type_to_pytorch_dtype", diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index 61f3b20224a72..e6004319ef5ea 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -17,7 +17,10 @@ from onnxruntime.training.utils import ( ORTModelInputOutputType, extract_data_and_schema, + nvtx_function_decorator, pytorch_type_to_onnx_dtype, + torch_nvtx_range_pop, + torch_nvtx_range_push, unflatten_data_using_schema, ) @@ -173,6 +176,7 @@ def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir=None, sta raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.") +@nvtx_function_decorator def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.parameter.Parameter]: """Retrieve the parameters for this module. @@ -187,6 +191,7 @@ def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.par return partitioned_params +@nvtx_function_decorator def _get_all_zero_stage3_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: """Retrieve all the parameters that are offloaded.""" from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus @@ -199,6 +204,10 @@ def _get_all_zero_stage3_params(module: torch.nn.Module) -> Dict[str, torch.nn.p return all_offloaed_params +# Used to cache the map avoid repeated loop up (X us) overhead during training. +_ModuleToParametersRefs: Dict[torch.nn.Module, List[torch.nn.parameter.Parameter]] = OrderedDict() + + class ORTZeROOffloadPreForwardFunction(torch.autograd.Function): """This function is a common bridge to call original PyTorch's pre_forward_function""" @@ -227,8 +236,7 @@ def forward( tensor_list: the list of tensors, the first args_tensor_count tensors are args, the next kwargs_tensor_count tensors are kwargs, the rest are the parameters for offload. """ - args_tensors = tensor_list[:args_tensor_count] - kwargs_tensors = tensor_list[args_tensor_count : args_tensor_count + kwargs_tensor_count] + torch_nvtx_range_push("ORTZeROOffloadPreForwardFunction::forward") # For PyTorch runs, the sizes are all 0, it does not need a gradient because # param._detach().requires_grad_(False) is called. 
@@ -241,41 +249,31 @@ def forward( ctx.dtypes = [p.dtype for p in passed_in_param_tensors] ctx.devices = [p.device for p in passed_in_param_tensors] - args = unflatten_data_using_schema(args_tensors, args_schema) - kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema) - # We will re-retrieve the parameter tensors other than use the one passed in input (of size 0 for # those partitioned params). # This is required for ORT run because in ORT graph, the tensor of size 0 will always be size 0 # (this step is not necessary for PyTorch run, because PyTorch will re-use the same tensor # while .data got updated to full-sized data after pre_forward_with_kwargs_function is called). - partitioned_params = _get_params_for_current_module(module) + if module not in _ModuleToParametersRefs: + _ModuleToParametersRefs[module] = _get_params_for_current_module(module) + partitioned_params = _ModuleToParametersRefs[module] ctx.partitioned_params = partitioned_params - assert len(partitioned_params) == len(passed_in_param_tensors) - - f_ret = pre_forward_with_kwargs_function(module, args, kwargs) - - if f_ret is None: - updated_args, updated_kwargs = args, kwargs - else: - assert isinstance(f_ret, tuple) - updated_args, updated_kwargs = f_ret - + pre_forward_with_kwargs_function(module) ctx.module = module - - updated_args_tensors, _ = extract_data_and_schema(updated_args) - updated_kwargs_tensors, _ = extract_data_and_schema(updated_kwargs) - - rets = tuple(updated_args_tensors + updated_kwargs_tensors) + rets = tuple(tensor_list[: args_tensor_count + kwargs_tensor_count]) rets += tuple([p.detach().requires_grad_(p.requires_grad) for p in partitioned_params]) # PyTorch exporter does not support an empty list of tensors, so we have this check. assert len(rets) != 0 + + torch_nvtx_range_pop() return rets @staticmethod def backward(ctx, *grads): + torch_nvtx_range_push("ORTZeROOffloadPreForwardFunction::backward") + updated_grads = grads input_count = len(updated_grads) - len(ctx.partitioned_params) @@ -302,6 +300,7 @@ def backward(ctx, *grads): zero_grads = updated_grads[:input_count] + tuple(passed_in_param_grad) + torch_nvtx_range_pop() return (None, None, None, None, None, None, *zero_grads) @staticmethod @@ -381,6 +380,8 @@ def forward( output_tensors: the list of tensors. """ + torch_nvtx_range_push("ORTZeROOffloadPostForwardFunction::forward") + outputs = unflatten_data_using_schema(output_tensors, output_schema) # STAGE3WARN#3: _post_forward_module_hook's second argument `input is not used, so we just pass a None here. 
@@ -394,15 +395,20 @@ def forward( ctx.module = module ctx.pre_backward_function = pre_backward_function rets = [o.detach().requires_grad_(o.requires_grad) for o in updated_output_tensors] + torch_nvtx_range_pop() return tuple(rets) @staticmethod def backward(ctx, *grads): + torch_nvtx_range_push("ORTZeROOffloadPostForwardFunction::backward") + updated_args = grads if ctx.pre_backward_function is not None: ret = ctx.pre_backward_function(ctx.module, grads) if ret is not None: updated_args = ret + + torch_nvtx_range_pop() return (None, None, None, None, *updated_args) @staticmethod @@ -467,6 +473,7 @@ def __init__(self, offloader, one_time_init: _ZeROOffloadOneTimeInitializer, ena self._functions = _ZeROOffloadFunctions(one_time_init, self._offloader) self._enable_debug_info = enable_debug_info + @nvtx_function_decorator def pre_forward_module_apply_impl( self, run_rtx: RuntimeStates, @@ -499,17 +506,14 @@ def pre_forward_module_apply_impl( args_tensor_count = len(args_tensors) kwargs_tensor_count = len(kwargs_tensors) - def _wrap_pre_forward_module_hook(module, args, kwargs): - rets = _pre_forward_module_hook(module, args) - updated_args, updated_kwargs = args, kwargs - if rets is not None: - updated_args = rets + @nvtx_function_decorator + def _wrap_pre_forward_module_hook(module): + empty = [] + _pre_forward_module_hook(module, *empty) # STAGE3WARN#5: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration. module.ds_grads_remaining = 0 - return updated_args, updated_kwargs - # Need to pass the parameters as input to let the exporter trace the related weights for # current ORTZeROOffloadPreForwardFunction partitioned_params = _get_params_for_current_module(module) @@ -545,6 +549,7 @@ def _wrap_pre_forward_module_hook(module, args, kwargs): return updated_args, updated_kwargs + @nvtx_function_decorator def post_forward_module_apply_impl( self, run_rtx: RuntimeStates, @@ -563,6 +568,7 @@ def post_forward_module_apply_impl( _post_forward_module_hook = self._functions.get("_post_forward_module_hook") + @nvtx_function_decorator def _wrap_post_forward_module_hook(module, input, outputs): # STAGE3WARN#6: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here. 
from deepspeed.runtime.zero.partition_parameters import is_zero_param @@ -580,7 +586,11 @@ def _wrap_post_forward_module_hook(module, input, outputs): self._check_all_tensor(outputs_tensors, module, "post_forward_module_apply_impl input check") updated_outputs_tensors = ORTZeROOffloadPostForwardFunction.apply( - module, _wrap_post_forward_module_hook, None, outputs_schema, *outputs_tensors + module, + _wrap_post_forward_module_hook, + None, + outputs_schema, + *outputs_tensors, ) self._check_all_tensor(updated_outputs_tensors, module, "post_forward_module_apply_impl output check") @@ -598,6 +608,7 @@ def _wrap_post_forward_module_hook(module, input, outputs): return args, updated_outputs + @nvtx_function_decorator def post_forward_outmost_module_apply_impl( self, run_rtx: RuntimeStates, @@ -611,7 +622,11 @@ def post_forward_outmost_module_apply_impl( self._check_all_tensor(outputs_tensors, module, "post_forward_outmost_module_apply_impl input check") updated_outputs_tensors = ORTZeROOffloadPostForwardFunction.apply( - module, _end_of_forward_hook, None, outputs_schema, *outputs_tensors + module, + _end_of_forward_hook, + None, + outputs_schema, + *outputs_tensors, ) self._check_all_tensor(updated_outputs_tensors, module, "post_forward_outmost_module_apply_impl output check") @@ -620,6 +635,7 @@ def post_forward_outmost_module_apply_impl( updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema) return args, updated_outputs + @nvtx_function_decorator def _check_all_tensor(self, tensor_list: Tuple[torch.Tensor], module: torch.nn.Module, name: str): if not self._enable_debug_info: return diff --git a/orttraining/orttraining/python/training/utils/torch_io_helper.py b/orttraining/orttraining/python/training/utils/torch_io_helper.py index 6d7d978e90054..34cc1ca942a8c 100644 --- a/orttraining/orttraining/python/training/utils/torch_io_helper.py +++ b/orttraining/orttraining/python/training/utils/torch_io_helper.py @@ -10,6 +10,8 @@ import torch +from onnxruntime.training.utils.torch_profile_utils import nvtx_function_decorator + class PrimitiveType: """Helper class for Python primitive types.""" @@ -122,6 +124,7 @@ def _warn_of_constant_inputs(data): ) +@nvtx_function_decorator def extract_data_and_schema( data: ORTModelInputOutputType, constant_as_tensor=False, device: Optional[torch.device] = None ) -> Tuple[List[torch.Tensor], ORTModelInputOutputSchemaType]: @@ -230,6 +233,7 @@ def _flatten_from_data(data: ORTModelInputOutputType, prefix_name: str = ""): return flatten_tensor_data, schemas +@nvtx_function_decorator def unflatten_data_using_schema( data: List[torch.Tensor], schema: ORTModelInputOutputSchemaType ) -> ORTModelInputOutputType: diff --git a/orttraining/orttraining/python/training/utils/torch_profile_utils.py b/orttraining/orttraining/python/training/utils/torch_profile_utils.py new file mode 100644 index 0000000000000..382d7dac142fe --- /dev/null +++ b/orttraining/orttraining/python/training/utils/torch_profile_utils.py @@ -0,0 +1,28 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +import torch + + +def torch_nvtx_range_push(msg): + if hasattr(torch.cuda.nvtx, "range_push"): + torch.cuda.nvtx.range_push(msg) + + +def torch_nvtx_range_pop(): + if hasattr(torch.cuda.nvtx, "range_pop"): + torch.cuda.nvtx.range_pop() + + +def nvtx_function_decorator(func): + """Function decorator to record the start and end of NVTX range.""" + + def wrapped_fn(*args, **kwargs): + torch_nvtx_range_push(func.__qualname__) + ret_val = func(*args, **kwargs) + torch_nvtx_range_pop() + return ret_val + + return wrapped_fn diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index 958c7d94c4241..bd4fce2cde144 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -1533,9 +1533,8 @@ def _run_step(model, input): import warnings - for index in range(10): - count = 0 - with warnings.catch_warnings(record=True) as w: + for _ in range(10): + with warnings.catch_warnings(record=True): input = torch.randn(output_size, device=device, dtype=torch.float) pt_prediction = _run_step(pt_model, input) ort_prediction = _run_step(ort_model, input) @@ -1543,16 +1542,6 @@ def _run_step(model, input): assert_values_are_close(ort_prediction, pt_prediction, rtol=1e-04, atol=1.0) assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-5) - for i in range(len(w)): - msg = str(w[i].message) - if "Add input index to _GlobalOpKernelInfoMap" in msg: - count += 1 - - if index == 0: - assert count == 2 - else: - assert count == 0 - class DupNamedFunction(torch.autograd.Function): @staticmethod diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc index 41f4a41a7c38a..3c5ac56cb139a 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc @@ -51,6 +51,9 @@ void PythonOpBase::Init(const OpKernelInfo& info) { ORT_THROW_IF_ERROR(info.GetAttr("func_name", &name_)); is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); + + safe_run_mode_enabled_ = static_cast(info.GetAttrOrDefault("safe_run_mode", static_cast(1))); + ORT_THROW_IF_ERROR(info.GetAttr("input_convention", &input_convention_)); input_requires_grads_ = info.GetAttrsOrDefault( @@ -144,7 +147,8 @@ void PythonOpBase::RunForward(OpKernelContext* context, // Invoke Python calls. TorchProxy::GetInstance().Forward( name_, - OrtTorchFunctionPool::GetInstance().GetForwardCore(name_), + safe_run_mode_enabled_ ? 
OrtTorchFunctionPool::GetInstance().GetForwardCore(name_) + : OrtTorchFunctionPool::GetInstance().GetUnsafeForwardCore(name_), input_requires_grads_, args, arg_positions_, @@ -153,6 +157,7 @@ void PythonOpBase::RunForward(OpKernelContext* context, is_training_mode_, all_output_to_tensor_input_reuse_map_, kernel_invoke_id_, + safe_run_mode_enabled_, diff_ctx, returned_ortvalues); @@ -301,7 +306,8 @@ void PythonOpBase::SetOtherOutputs(OpKernelContext* context, std::vector().DataRaw(); - const void* input_tensor_address = context->Input(all_output_to_tensor_input_reuse_map_[output_index])->DataRaw(); + const void* input_tensor_address = + context->Input(all_output_to_tensor_input_reuse_map_[output_index])->DataRaw(); ORT_ENFORCE(tensor_address == input_tensor_address, "PythonOp inplace tensor address mismatch, output index: ", output_index, ", input index: ", all_output_to_tensor_input_reuse_map_[output_index]); @@ -327,7 +333,7 @@ void PythonOpGradBase::Init(const OpKernelInfo& info) { output_tensor_requires_grads_ = info.GetAttrsOrDefault("output_tensor_requires_grads", std::vector()); ORT_ENFORCE(output_tensor_types_.size() == output_tensor_requires_grads_.size(), "backward tensor output count mismatch"); - + safe_run_mode_enabled_ = static_cast(info.GetAttrOrDefault("safe_run_mode", static_cast(1))); std::vector tensor_output_to_tensor_input_alias_map = info.GetAttrsOrDefault("tensor_reuse_map", std::vector((info.node().OutputDefs().size()), -1)); @@ -371,6 +377,7 @@ void PythonOpGradBase::RunBackward(OpKernelContext* context, const_arg_positions_, all_output_to_tensor_input_reuse_map_, kernel_invoke_id_, + safe_run_mode_enabled_, returned_ortvalues); OrtTorchFunctionPool::GetInstance().UnregisterContext(*context_index_ptr); diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h index d4a53a223abf1..4353859b56735 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h @@ -149,6 +149,8 @@ class PythonOpBase { // Output types of MyReLU.apply(...). std::vector output_tensor_types_; + bool safe_run_mode_enabled_{true}; + private: void AddPrimitiveTypeScalarArgs(); void AddInputTupleArgs(); @@ -193,6 +195,8 @@ class PythonOpGradBase { // Memory reuse map for all outputs. 
std::vector all_output_to_tensor_input_reuse_map_; + bool safe_run_mode_enabled_{true}; + private: void SetPositions(); diff --git a/setup.py b/setup.py index 44c97937ebe2a..0c2eb19e82c87 100644 --- a/setup.py +++ b/setup.py @@ -488,7 +488,7 @@ def finalize_options(self): ) package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"] - package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"] + package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc", "*.h"] package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"] package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops"] = [ "*.cpp", diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index e7138e628a52b..bdce0991d6b86 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -10,9 +10,13 @@ stages: UseWebPoolName: true WebCpuPoolName: 'Onnxruntime-Win-CPU-2022' -# This stage is to test if the combined build works on +# The follow section has 12 different build jobs that can be divided into 3 groups: +# 1. Default CPU build with normal win32 linking, without ORT extension +# 2. Default CPU build with wcos linking(use apiset), without ORT extension +# 3. Default CPU build with normal win32 linking with ORT extension +# Each group has 4 jobs that cover: # o Windows ARM64 -# o Windows ARM64EC +# o Windows ARM # o Windows x64 # o Windows x86 # Now we don't have coverage for ARM64EC yet. Will add it. @@ -24,12 +28,26 @@ stages: buildArch: x86 msbuildPlatform: Win32 packageName: x86 - buildparameter: --use_extensions --enable_onnx_tests + buildparameter: --enable_onnx_tests runTests: true buildJava: false buildNodejs: false ort_build_pool_name: 'onnxruntime-Win-CPU-2022' +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_arm_default + buildArch: x64 + msbuildPlatform: arm + packageName: arm + buildparameter: --arm --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe + runTests: false + buildJava: false + buildNodejs: false + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + - template: templates/win-ci.yml parameters: DoCompliance: false @@ -38,7 +56,7 @@ stages: buildArch: x64 msbuildPlatform: arm64 packageName: arm64 - buildparameter: --build_nodejs --arm64 --use_extensions --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe + buildparameter: --build_nodejs --arm64 --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe runTests: false buildJava: false buildNodejs: true @@ -52,6 +70,126 @@ stages: buildArch: x64 msbuildPlatform: x64 packageName: x64 + buildparameter: --build_java --build_nodejs --enable_onnx_tests + runTests: true + buildJava: true + buildNodejs: true + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_x86_wcos + artifact_name_suffix: '-wcos' + buildArch: x86 + msbuildPlatform: Win32 + packageName: x86 + buildparameter: --enable_onnx_tests --enable_wcos + runTests: true + buildJava: false + buildNodejs: false + ort_build_pool_name: 
'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_arm_wcos + artifact_name_suffix: '-wcos' + buildArch: x64 + msbuildPlatform: arm + packageName: arm + buildparameter: --arm --enable_onnx_tests --enable_wcos --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe + runTests: false + buildJava: false + buildNodejs: false + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_arm64_wcos + artifact_name_suffix: '-wcos' + buildArch: x64 + msbuildPlatform: arm64 + packageName: arm64 + buildparameter: --build_nodejs --enable_wcos --arm64 --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe + runTests: false + buildJava: false + buildNodejs: true + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_x64_wcos + artifact_name_suffix: '-wcos' + buildArch: x64 + msbuildPlatform: x64 + packageName: x64 + buildparameter: --build_java --build_nodejs --enable_onnx_tests --enable_wcos + runTests: true + buildJava: true + buildNodejs: true + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_x86_extension + artifact_name_suffix: '-extension' + buildArch: x86 + msbuildPlatform: Win32 + packageName: x86 + buildparameter: --enable_onnx_tests + runTests: true + buildJava: false + buildNodejs: false + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_arm_extension + artifact_name_suffix: '-extension' + buildArch: x64 + msbuildPlatform: arm + packageName: arm + buildparameter: --arm --use_extensions --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe + runTests: false + buildJava: false + buildNodejs: false + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_arm64_extension + artifact_name_suffix: '-extension' + buildArch: x64 + msbuildPlatform: arm64 + packageName: arm64 + buildparameter: --build_nodejs --arm64 --use_extensions --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe + runTests: false + buildJava: false + buildNodejs: true + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_x64_extension + artifact_name_suffix: '-extension' + buildArch: x64 + msbuildPlatform: x64 + packageName: x64 buildparameter: --build_java --build_nodejs --use_extensions --enable_onnx_tests runTests: true buildJava: true diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 9ef1aed55d58c..537175f6bec73 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 
1.0.128 + version: 1.0.129 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.128 + version: 1.0.129 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index fd5f61b82a5a8..89c481f267e64 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -193,7 +193,7 @@ stages: - template: nodejs-artifacts-package-and-publish-steps-windows.yml parameters: arch: ${{ parameters.packageName }} - artifactName: 'drop-onnxruntime-nodejs-win-${{ parameters.packageName }}' + artifactName: 'drop-onnxruntime-nodejs-win-${{ parameters.packageName }}${{ parameters.artifact_name_suffix }}' DoEsrp: ${{ parameters.DoEsrp }} #Upload protoc.exe, which will be used in nuget build for generating C# files @@ -260,7 +260,7 @@ stages: displayName: 'Publish Java temp binaries' inputs: pathtoPublish: '$(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}' - artifactName: 'drop-onnxruntime-java-win-${{ parameters.packageName }}' + artifactName: 'drop-onnxruntime-java-win-${{ parameters.packageName }}${{parameters.artifact_name_suffix}}' - ${{ if eq(parameters['DoCompliance'], 'true') }}: - task: CredScan@3