From 1db1c750488cd6602ea2fa741678b5bd1b16da5f Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 15 Dec 2023 06:33:19 +0800 Subject: [PATCH 01/11] [WebNN EP] WebNN only supports 4-D input and weight for Conv/ConvTranspose (#18703) --- .../webnn/builders/impl/conv_op_builder.cc | 43 +++++++++++++------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index b37340624f850..e94db2faa80a6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -293,22 +293,39 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - const auto& weight_name = input_defs[1]->Name(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + LOGS(logger, VERBOSE) << "Cannot get input's shape."; + return false; + } + + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s input dimension: " << input_size + << ". Only conv 2d is supported."; + return false; + } + + std::vector weight_shape; + if (!GetShape(*input_defs[1], weight_shape, logger)) { + LOGS(logger, VERBOSE) << "Cannot get weight's shape."; + return false; + } + + const auto weight_size = weight_shape.size(); + if (weight_size != 4) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s weight dimension: " << weight_size + << ". Only conv 2d is supported."; + return false; + } + // WebNN CPU backend (XNNPACK) requires the filter operand to be a constant. // https://github.com/google/XNNPACK/blob/master/src/subgraph/convolution-2d.c#L739 - if (device_type == WebnnDeviceType::CPU) { - if (Contains(initializers, weight_name)) { - const auto& tensor = *initializers.at(weight_name); - if (tensor.dims().size() != 4) { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] dimension: " << tensor.dims().size() - << " Only conv 2d is supported."; - return false; - } - } else { - LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; - return false; - } + if (device_type == WebnnDeviceType::CPU && !Contains(initializers, input_defs[1]->Name())) { + LOGS(logger, VERBOSE) << "The weight of " << op_type << " [" << name << "] must be known"; + return false; } + return true; } From 6d5ee4d69bd7aac085bd8dca5a391227e628948d Mon Sep 17 00:00:00 2001 From: zesongw Date: Fri, 15 Dec 2023 06:33:44 +0800 Subject: [PATCH 02/11] [WebNN EP] Use explicit padding (#18688) WebNN will remove autoPad option, we need to use explicit padding values. Compute padding values of autopad(same-upper, same-lower) for Op Pool, Conv and ConvTranspose. --- .../webnn/builders/impl/builder_utils.cc | 42 ++--- .../webnn/builders/impl/builder_utils.h | 3 +- .../webnn/builders/impl/conv_op_builder.cc | 153 +++++++++--------- .../webnn/builders/impl/pool_op_builder.cc | 34 ++-- 4 files changed, 111 insertions(+), 121 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc index 516ac7464345b..d147ffbbd181f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc @@ -19,9 +19,10 @@ common::Status ComputeConvPads(const std::vector input_shape, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, - std::vector& pads_out) { - const int64_t input_size_y = input_shape[2]; - const int64_t input_size_x = input_shape[3]; + std::vector& pads_out, + bool use_nchw) { + const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1]; + const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2]; const int64_t stride_y = onnx_strides[0]; const int64_t stride_x = onnx_strides[1]; const int64_t dilation_y = onnx_dilations[0]; @@ -53,32 +54,17 @@ common::Status HandleAutoPad(const std::vector input_shape, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, - AutoPadType& auto_pad_type_out) { - auto_pad_type_out = auto_pad_type; - if (auto_pad_type == AutoPadType::NOTSET && onnx_dilations == std::vector{1, 1}) { - { - std::vector same_upper_pads; - ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, - onnx_pads, onnx_strides, onnx_dilations, - AutoPadType::SAME_UPPER, same_upper_pads)); - if (onnx_pads == same_upper_pads) { - auto_pad_type_out = AutoPadType::SAME_UPPER; - return Status::OK(); - } - } - - { - std::vector same_lower_pads; - ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, - onnx_pads, onnx_strides, onnx_dilations, - AutoPadType::SAME_LOWER, same_lower_pads)); - if (onnx_pads == same_lower_pads) { - auto_pad_type_out = AutoPadType::SAME_LOWER; - return Status::OK(); - } - } + std::vector& pads_out, + bool use_nchw) { + if (AutoPadType::SAME_UPPER == auto_pad_type) { + ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, + onnx_pads, onnx_strides, onnx_dilations, + AutoPadType::SAME_UPPER, pads_out, use_nchw)); + } else { + ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, + onnx_pads, onnx_strides, onnx_dilations, + AutoPadType::SAME_LOWER, pads_out, use_nchw)); } - return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h index 76acbca0536ea..cb7c3c6955664 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h @@ -21,7 +21,8 @@ common::Status HandleAutoPad(const std::vector input_shape, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, - AutoPadType& auto_pad_type_out) ORT_MUST_USE_RESULT; + std::vector& pads_out, + bool use_nchw) ORT_MUST_USE_RESULT; } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index e94db2faa80a6..df0d54e3fd4b4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -44,7 +44,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, const Node& node, emscripten::val& options, const std::vector& strides, const std::vector& dilations, - const std::vector& pads, + std::vector& pads, const logging::Logger& logger) { NodeAttrHelper helper(node); const auto group = helper.Get("group", static_cast(1)); @@ -55,29 +55,85 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, options.set("dilations", emscripten::val::array(dilations)); options.set("groups", group); // Add Padding. - // Usually using autopadding is more efficient than using explicit padding. - // Try to see if we can map explicit padding to auto padding. std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - AutoPadType auto_pad_type; - ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], - helper.Get("pads", std::vector{0, 0, 0, 0}), - helper.Get("strides", std::vector{1, 1}), - helper.Get("dilations", std::vector{1, 1}), - StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), - auto_pad_type)); - if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { - if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER - options.set("autoPad", emscripten::val("same-lower")); + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); + if (node.OpType() == "Conv") { + // Calculate explicit padding for autoPad. + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + std::vector pads_out; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], + helper.Get("pads", std::vector{0, 0, 0, 0}), + helper.Get("strides", std::vector{1, 1}), + helper.Get("dilations", std::vector{1, 1}), + auto_pad_type, + pads_out, + model_builder.GetPreferredLayout() == DataLayout::NCHW)); + std::transform(pads_out.begin(), pads_out.end(), pads.begin(), + [](int64_t pad) -> int32_t { return static_cast(pad); }); + } + } else if (node.OpType() == "ConvTranspose") { + // When the 'output_shape' is specificed, the 'output_padding' values + // in options.outputPadding are ignored. + std::vector dim; + std::vector output_padding{0, 0}; + if (helper.HasAttr("output_shape")) { + // Default value of 'output_shape' will be ignore as we already check if + // it's existed. + dim = helper.Get("output_shape", std::vector{-1, -1}); + // Extract the height and width. + std::vector output_shape; + if (dim.size() == 2) { + output_shape = dim; + } else if (dim.size() == 4) { + output_shape = {dim[2], dim[3]}; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output shape"); + } + // Padding values are auto generated. + if (helper.HasAttr("kernel_shape")) { + std::vector kernel_shape = helper.Get("kernel_shape", std::vector{-1, -1}); + std::vector total_padding(2); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + for (size_t i = 0; i < 2; i++) { + // Get the dimensions of H and W. + // For NHWC layout, the dimensions of H and W correspond to index 1 and 2. + // For NCHW layout, the dimensions of H and W correspond to index 2 and 3. + if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + total_padding[i] = strides[i] * (narrow(input_shape[i + 1]) - 1) + + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; + } else { + ORT_RETURN_IF_NOT(model_builder.GetPreferredLayout() == DataLayout::NCHW, + "WebNN GPU backend preferred layout should be NCHW."); + total_padding[i] = strides[i] * (narrow(input_shape[i + 2]) - 1) + + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; + } + } + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + pads[0] = total_padding[0] / 2; + pads[1] = total_padding[0] - pads[0]; + pads[2] = total_padding[1] / 2; + pads[3] = total_padding[1] - pads[2]; + if (AutoPadType::SAME_LOWER == auto_pad_type) { + std::swap(pads[0], pads[1]); + std::swap(pads[2], pads[3]); + } + } + } + options.set("outputSizes", emscripten::val::array(output_shape)); } else { - options.set("autoPad", emscripten::val("same-upper")); + output_padding = helper.Get("output_padding", std::vector{0, 0}); + options.set("outputPadding", emscripten::val::array(output_padding)); } } else { - // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], - // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. - const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; - options.set("padding", emscripten::val::array(padding)); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "conv_op_builder only supports Op Conv and ConvTranspose."); } + // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], + // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. + const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; + options.set("padding", emscripten::val::array(padding)); // Add bias if present. if (input_defs.size() > 2) { @@ -198,17 +254,17 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const auto strides = helper.Get("strides", std::vector{1, 1}); const auto dilations = helper.Get("dilations", std::vector{1, 1}); auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - const auto& weight = input_defs[1]->Name(); + const auto& weight_name = input_defs[1]->Name(); + emscripten::val options = emscripten::val::object(); + ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); if (op_type == "Conv") { - emscripten::val options = emscripten::val::object(); - ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); int groups = options["groups"].as(); std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { bool depthwise = (groups == input_shape[3] && groups != 1); options.set("inputLayout", emscripten::val("nhwc")); - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight, !depthwise)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise)); if (!depthwise) { options.set("filterLayout", emscripten::val("ohwi")); } else { @@ -219,61 +275,10 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N output = model_builder.GetBuilder().call("conv2d", input, filter, options); } else { - emscripten::val options = emscripten::val::object(); - ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { options.set("inputLayout", emscripten::val("nhwc")); options.set("filterLayout", emscripten::val("ohwi")); - ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight, false)); - } - - // When the 'output_shape' is specificed, the 'output_padding' values - // in options.outputPadding are ignored. - std::vector dim; - std::vector output_padding{0, 0}; - if (helper.HasAttr("output_shape")) { - // Default value of 'output_shape' will be ignore as we already check if - // it's existed. - dim = helper.Get("output_shape", std::vector{-1, -1}); - // Extract the height and width. - std::vector output_shape; - if (dim.size() == 2) { - output_shape = dim; - } else if (dim.size() == 4) { - output_shape = {dim[2], dim[3]}; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output shape"); - } - // Padding values are auto generated. - if (helper.HasAttr("kernel_shape")) { - std::vector kernel_shape = helper.Get("kernel_shape", std::vector{-1, -1}); - std::vector total_padding(2); - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - for (size_t i = 0; i < 2; i++) { - // Get the dimensions of H and W. - // For NHWC layout, the dimensions of H and W correspond to index 1 and 2. - // For NCHW layout, the dimensions of H and W correspond to index 2 and 3. - if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { - total_padding[i] = strides[i] * (narrow(input_shape[i + 1]) - 1) + - output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; - } else { - ORT_RETURN_IF_NOT(model_builder.GetPreferredLayout() == DataLayout::NCHW, - "WebNN GPU backend preferred layout should be NCHW."); - total_padding[i] = strides[i] * (narrow(input_shape[i + 2]) - 1) + - output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; - } - } - pads[0] = total_padding[0] - (total_padding[0] / 2); - pads[1] = total_padding[0] / 2; - pads[2] = total_padding[1] - (total_padding[1] / 2); - pads[3] = total_padding[1] / 2; - options.set("padding", emscripten::val::array(pads)); - } - options.set("outputSizes", emscripten::val::array(output_shape)); - } else { - output_padding = helper.Get("output_padding", std::vector{0, 0}); - options.set("outputPadding", emscripten::val::array(output_padding)); + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false)); } emscripten::val filter = model_builder.GetOperand(input_defs[1]->Name()); output = model_builder.GetBuilder().call("convTranspose2d", input, filter, options); diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index ae7c111c1fe78..739c3b3f38def 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -81,28 +81,26 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto onnx_kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); const auto onnx_strides = helper.Get("strides", std::vector{1, 1}); const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - + auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - AutoPadType auto_pad_type; - ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, onnx_kernel_shape[0], onnx_kernel_shape[1], - onnx_pads, onnx_strides, {1, 1} /* dilations */, - StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), - auto_pad_type)); - + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { - if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER - options.set("autoPad", "same-lower"); - } else { - options.set("autoPad", "same-upper"); - } - } else { - const std::vector pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], - // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. - const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; - options.set("padding", emscripten::val::array(padding)); + std::vector pads_out; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, onnx_kernel_shape[0], onnx_kernel_shape[1], + onnx_pads, + helper.Get("strides", std::vector{1, 1}), + helper.Get("dilations", std::vector{1, 1}), + auto_pad_type, + pads_out, + model_builder.GetPreferredLayout() == DataLayout::NCHW)); + std::transform(pads_out.begin(), pads_out.end(), pads.begin(), + [](int64_t pad) -> int32_t { return static_cast(pad); }); } + // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], + // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. + const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; + options.set("padding", emscripten::val::array(padding)); const auto ceil_mode = helper.Get("ceil_mode", 0); options.set("roundingType", ceil_mode == 0 ? emscripten::val("floor") From b42d4b8ea650c7b384bfbac1c7edc292c60747a6 Mon Sep 17 00:00:00 2001 From: Yueqing Zhang Date: Fri, 15 Dec 2023 06:43:41 +0800 Subject: [PATCH 03/11] [VitisAI] 1. api compatbile 2. dynamic load onnx (#18470) ### Description 1. Add a backward-compatible API for compiling model. 2. Run-time load vitisai-ep.dll ### Motivation and Context --------- Co-authored-by: Yueqing Zhang Co-authored-by: Zhenze Wang --- cmake/onnxruntime_providers_vitisai.cmake | 10 +- .../core/providers/vitisai/imp/global_api.cc | 270 ++++++++++-------- .../onnxruntime_vitisai_ep.h | 46 --- .../vitisai/include/vaip/global_api.h | 10 + .../vitisai/onnxruntime_vitisai_ep_stub.cc | 30 -- .../vitisai/vitisai_execution_provider.cc | 45 ++- .../vitisai/vitisai_execution_provider.h | 31 +- .../vitisai/vitisai_provider_factory.cc | 37 +-- .../vitisai_provider_factory_creator.h | 3 - .../python/onnxruntime_pybind_state_common.h | 10 - 10 files changed, 199 insertions(+), 293 deletions(-) delete mode 100644 onnxruntime/core/providers/vitisai/include/onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h delete mode 100644 onnxruntime/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake index 7ac4a82c89a76..0951c2d02664d 100644 --- a/cmake/onnxruntime_providers_vitisai.cmake +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -15,16 +15,10 @@ "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" ) - list(REMOVE_ITEM onnxruntime_providers_vitisai_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc") source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs}) onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto) - onnxruntime_add_shared_library(onnxruntime_vitisai_ep ${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc) - onnxruntime_add_include_to_target(onnxruntime_vitisai_ep onnxruntime_common) - target_include_directories(onnxruntime_vitisai_ep PRIVATE "${ONNXRUNTIME_ROOT}" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include") - target_link_libraries(onnxruntime_providers_vitisai PUBLIC onnxruntime_vitisai_ep PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json ) - target_compile_definitions(onnxruntime_vitisai_ep - PRIVATE "-DONNXRUNTIME_VITISAI_EP_STUB=1" "-DONNXRUNTIME_VITISAI_EP_EXPORT_DLL=1") + target_link_libraries(onnxruntime_providers_vitisai PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json) if(NOT MSVC) target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) endif(NOT MSVC) @@ -49,4 +43,4 @@ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 59bdd43ec997e..b629c8eff9097 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -2,6 +2,10 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. #include "vaip/global_api.h" + +#include +#include + #include "./vai_assert.h" #include "core/common/exceptions.h" #include "core/common/logging/logging.h" @@ -10,10 +14,10 @@ #include "core/graph/model.h" #include "core/session/ort_env.h" +#include "core/session/onnxruntime_cxx_api.h" -#include +#include -#include "core/session/onnxruntime_cxx_api.h" #include "vaip/dll_safe.h" #include "vaip/vaip_ort_api.h" #include "vaip/graph.h" @@ -24,28 +28,107 @@ #include "./attr_proto.h" #include "./register_xir_ops.h" -#include "onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h" - #include "onnxruntime_config.h" #include "version_info.h" // version_info.hpp.in using namespace onnxruntime; +using json = nlohmann::json; + +// The filename extension for a shared library is different per platform +#ifdef _WIN32 +#define LIBRARY_PREFIX +#define LIBRARY_EXTENSION ORT_TSTR(".dll") +#elif defined(__APPLE__) +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".dylib" +#else +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".so" +#endif + vaip_core::OrtApiForVaip* create_org_api_hook(); +struct OrtVitisAIEpAPI { + void (*initialize_onnxruntime_vitisai_ep)(vaip_core::OrtApiForVaip* api, std::vector& ret_domain); + std::vector>* (*compile_onnx_model_3)(const std::string& model_path, + const onnxruntime::Graph& graph, + const char* json_config); + std::vector>* (*compile_onnx_model_with_options)( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); + void Ensure() { + if (handle_) return; + auto full_path = Env::Default().GetRuntimePath() + + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); + ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, true, &handle_)); + ORT_THROW_IF_ERROR(Env::Default().GetSymbolFromLibrary( + handle_, "initialize_onnxruntime_vitisai_ep", reinterpret_cast(&initialize_onnxruntime_vitisai_ep))); + auto status1 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", + reinterpret_cast(&compile_onnx_model_with_options)); + auto status2 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", + reinterpret_cast(&compile_onnx_model_3)); + if (!status1.IsOK() && !status2.IsOK()) { + ::onnxruntime::LogRuntimeError(0, status1, __FILE__, static_cast(__FUNCTION__), __LINE__); + ORT_THROW(status1); + } + } + + private: + void* handle_{}; +}; + +static OrtVitisAIEpAPI s_library_vitisaiep; +static std::string config_to_json_str(const onnxruntime::ProviderOptions& config) { + auto iter = config.find("config_file"); + if (iter == config.end()) { + std::cerr << "Error: Key 'config_file' not found in config" << std::endl; + return ""; + } + const auto& filename = config.at("config_file"); + std::ifstream f(filename); + if (!f.is_open()) { + std::cerr << "Error: Failed to open file: " << filename << std::endl; + return ""; + } + nlohmann::json data; + try { + data = nlohmann::json::parse(f); + } catch (const std::exception& e) { + std::cerr << "Error: Failed to parse JSON from file: " << filename << ", Reason: " << e.what() << std::endl; + return ""; + } + for (const auto& entry : config) { + data[entry.first] = entry.second; + } + try { + return data.dump(); + } catch (const std::exception& e) { + std::cerr << "Error: Failed to convert JSON data to string, Reason: " << e.what() << std::endl; + return ""; + } +} +vaip_core::DllSafe>> compile_onnx_model_with_options( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options) { + if (s_library_vitisaiep.compile_onnx_model_with_options) { + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph, options)); + } else { + auto json_str = config_to_json_str(options); + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_3(model_path, graph, json_str.c_str())); + } +} std::vector initialize_vitisai_ep() { + s_library_vitisaiep.Ensure(); Status status = Status::OK(); try { - OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, "onnxruntime-vitisai-ep"}; + OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, + "onnxruntime-vitisai-ep"}; std::ignore = OrtEnv::GetInstance(lm_info, status); } catch (onnxruntime::OnnxRuntimeException& /*e*/) { } auto domains = std::vector(); domains.reserve(100); - onnxruntime_vitisai_ep::initialize_onnxruntime_vitisai_ep(create_org_api_hook(), domains); - auto& domainToVersionRangeInstance = - ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); - if (domainToVersionRangeInstance.Map().find("com.xilinx") == - domainToVersionRangeInstance.Map().end()) { + s_library_vitisaiep.initialize_onnxruntime_vitisai_ep(create_org_api_hook(), domains); + auto& domainToVersionRangeInstance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); + if (domainToVersionRangeInstance.Map().find("com.xilinx") == domainToVersionRangeInstance.Map().end()) { vaip::register_xir_ops(domains); } @@ -68,17 +151,14 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.model_delete = [](Model* model) { delete model; }; the_global_api.model_clone = [](const Model& model) -> Model* { auto& logger = logging::LoggingManager::DefaultLogger(); - auto model_proto = - const_cast(model).ToProto(); + auto model_proto = const_cast(model).ToProto(); auto file_path = model.ModelPath().ToPathString(); auto ret = std::make_unique(std::move(model_proto), file_path, nullptr, logger); auto status = ret->MainGraph().Resolve(); vai_assert(status.IsOK(), status.ErrorMessage()); return ret.release(); }; - the_global_api.model_set_meta_data = [](Model& model, const std::string& key, - const std::string& value) - -> void { + the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) -> void { const_cast(model.MetaData())[key] = value; }; the_global_api.model_get_meta_data = [](const Model& model, @@ -97,14 +177,9 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return m.find(key) != m.end() ? 1 : 0; }; - the_global_api.model_main_graph = [](Model& model) -> Graph& { - return model.MainGraph(); - }; - the_global_api.graph_get_model = [](const Graph& graph) -> const Model& { - return graph.GetModel(); - }; - the_global_api.graph_get_inputs_unsafe = - [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.model_main_graph = [](Model& model) -> Graph& { return model.MainGraph(); }; + the_global_api.graph_get_model = [](const Graph& graph) -> const Model& { return graph.GetModel(); }; + the_global_api.graph_get_inputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { auto ret = std::vector(); auto inputs = graph.GetInputs(); for (auto input : inputs) { @@ -113,47 +188,35 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { } return vaip_core::DllSafe(std::move(ret)); }; - the_global_api.graph_get_outputs_unsafe = - [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.graph_get_outputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { return vaip_core::DllSafe(graph.GetOutputs()); }; - the_global_api.graph_set_outputs = - [](Graph& graph, gsl::span outputs) -> void { + the_global_api.graph_set_outputs = [](Graph& graph, gsl::span outputs) -> void { return graph.SetOutputs(outputs); }; - the_global_api.graph_get_node_arg = - [](const Graph& graph, const std::string& name) -> const NodeArg* { + the_global_api.graph_get_node_arg = [](const Graph& graph, const std::string& name) -> const NodeArg* { return graph.GetNodeArg(name); }; the_global_api.graph_producer_node = [](const Graph& graph, const std::string& name) -> const Node* { return graph.GetProducerNode(name); }; - the_global_api.graph_get_node = [](const Graph& graph, - size_t index) -> const Node* { - return graph.GetNode(index); - }; + the_global_api.graph_get_node = [](const Graph& graph, size_t index) -> const Node* { return graph.GetNode(index); }; the_global_api.graph_save = vaip::graph_save; the_global_api.graph_fuse = vaip::graph_fuse; the_global_api.graph_remove_node = vaip::graph_remove_node; - the_global_api.graph_add_node = - [](Graph& graph, const std::string& name, const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - vaip_core::NodeAttributes& attributes, - const std::string& domain) -> Node& { - return vaip::graph_add_node( - graph, name, op_type, description, input_args, output_args, - std::move(reinterpret_cast(attributes)), - domain); - }; - - the_global_api.graph_get_all_initialized_tensors = - [](const Graph& graph) -> const InitializedTensorSet& { + the_global_api.graph_add_node = [](Graph& graph, const std::string& name, const std::string& op_type, + const std::string& description, const std::vector& input_args, + const std::vector& output_args, + vaip_core::NodeAttributes& attributes, const std::string& domain) -> Node& { + return vaip::graph_add_node(graph, name, op_type, description, input_args, output_args, + std::move(reinterpret_cast(attributes)), domain); + }; + + the_global_api.graph_get_all_initialized_tensors = [](const Graph& graph) -> const InitializedTensorSet& { return graph.GetAllInitializedTensors(); }; @@ -166,66 +229,46 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { }; the_global_api.graph_get_consumer_nodes_unsafe = - [](const Graph& graph, - const std::string& node_arg_name) -> vaip_core::DllSafe> { + [](const Graph& graph, const std::string& node_arg_name) -> vaip_core::DllSafe> { return vaip_core::DllSafe(graph.GetConsumerNodes(node_arg_name)); }; - the_global_api.graph_nodes_unsafe = - [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { auto& node_refererence = graph.Nodes(); - std::vector nodes((size_t)graph.NumberOfNodes(), nullptr); - std::transform(node_refererence.begin(), node_refererence.end(), - nodes.begin(), [](const Node& n) { return &n; }); + std::vector nodes(static_cast(graph.NumberOfNodes()), nullptr); + std::transform(node_refererence.begin(), node_refererence.end(), nodes.begin(), [](const Node& n) { return &n; }); return vaip_core::DllSafe(std::move(nodes)); }; - the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { - return graph.Name(); + the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { return graph.Name(); }; + the_global_api.graph_reverse_dfs_from = [](const Graph& graph, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& stop) { + graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); }; - the_global_api.graph_reverse_dfs_from = - [](const Graph& graph, gsl::span from, - const std::function& enter, - const std::function& leave, - const std::function& stop) { - graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); - }; // node the_global_api.node_get_inputs_unsafe = vaip::node_get_inputs; the_global_api.node_get_output_node_args_unsafe = vaip::node_get_output_node_args; - the_global_api.node_op_type = [](const Node& node) -> const std::string& { - return node.OpType(); - }; - the_global_api.node_op_domain = [](const Node& node) -> const std::string& { - return node.Domain(); - }; - the_global_api.node_get_index = [](const Node& node) -> size_t { - return (size_t)node.Index(); - }; - the_global_api.node_get_name = [](const Node& node) -> const std::string& { - return node.Name(); - }; - the_global_api.node_description = [](const Node& node) -> const std::string& { - return node.Description(); - }; + the_global_api.node_op_type = [](const Node& node) -> const std::string& { return node.OpType(); }; + the_global_api.node_op_domain = [](const Node& node) -> const std::string& { return node.Domain(); }; + the_global_api.node_get_index = [](const Node& node) -> size_t { return static_cast(node.Index()); }; + the_global_api.node_get_name = [](const Node& node) -> const std::string& { return node.Name(); }; + the_global_api.node_description = [](const Node& node) -> const std::string& { return node.Description(); }; - the_global_api.node_get_attributes = - [](Node& node) -> vaip_core::NodeAttributes& { - return reinterpret_cast( - node.GetMutableAttributes()); + the_global_api.node_get_attributes = [](Node& node) -> vaip_core::NodeAttributes& { + return reinterpret_cast(node.GetMutableAttributes()); }; the_global_api.node_type_is_fused = [](const Node& node) { return node.NodeType() == onnxruntime::Node::Type::Fused; }; - the_global_api.node_get_function_body = - [](const Node& node) -> const onnxruntime::Graph& { + the_global_api.node_get_function_body = [](const Node& node) -> const onnxruntime::Graph& { assert(node.GetFunctionBody() != nullptr); return node.GetFunctionBody()->Body(); }; // node_arg - the_global_api.node_arg_get_name_unsafe = - [](const NodeArg& node_arg) -> const std::string& { + the_global_api.node_arg_get_name_unsafe = [](const NodeArg& node_arg) -> const std::string& { return node_arg.Name(); }; the_global_api.node_arg_clone = vaip::node_arg_clone; @@ -236,8 +279,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.node_arg_set_shape_i64 = vaip::node_arg_set_shape_i64; the_global_api.node_arg_get_denotation_unsafe = vaip::node_arg_get_denotation; the_global_api.node_arg_set_denotation = vaip::node_arg_set_denotation; - the_global_api.node_arg_get_const_data_as_tensor = - vaip::node_arg_get_const_data_as_tensor; + the_global_api.node_arg_get_const_data_as_tensor = vaip::node_arg_get_const_data_as_tensor; the_global_api.node_arg_get_element_type = vaip::node_arg_get_element_type; the_global_api.node_arg_set_element_type = [](NodeArg& node_arg, int type) { @@ -299,16 +341,13 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { }; /// attr proto the_global_api.attr_proto_delete = [](onnx::AttributeProto* v) { delete v; }; - the_global_api.attr_proto_clone = - [](const onnx::AttributeProto& v) -> onnx::AttributeProto* { + the_global_api.attr_proto_clone = [](const onnx::AttributeProto& v) -> onnx::AttributeProto* { return new onnx::AttributeProto(v); }; - the_global_api.attr_proto_get_name = - [](const onnx::AttributeProto& attr_proto) -> const std::string& { + the_global_api.attr_proto_get_name = [](const onnx::AttributeProto& attr_proto) -> const std::string& { return attr_proto.name(); }; - the_global_api.attr_proto_set_name = [](onnx::AttributeProto* attr_proto, - const std::string& name) { + the_global_api.attr_proto_set_name = [](onnx::AttributeProto* attr_proto, const std::string& name) { attr_proto->set_name(name); }; the_global_api.attr_proto_new_int = vaip::attr_proto_new_int; @@ -325,17 +364,14 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.attr_proto_get_ints = vaip::attr_proto_get_ints; the_global_api.attr_proto_get_floats = vaip::attr_proto_get_floats; the_global_api.attr_proto_get_strings = vaip::attr_proto_get_strings; - the_global_api.attr_proto_get_type = - [](const onnx::AttributeProto& attr) -> int { return attr.type(); }; + the_global_api.attr_proto_get_type = [](const onnx::AttributeProto& attr) -> int { return attr.type(); }; /// node attributes the_global_api.node_attributes_new = []() { return reinterpret_cast(new NodeAttributes()); }; - the_global_api.node_attributes_add = [](vaip_core::NodeAttributes& p, - onnx::AttributeProto&& attr) { - reinterpret_cast(p).insert_or_assign(attr.name(), - std::move(attr)); + the_global_api.node_attributes_add = [](vaip_core::NodeAttributes& p, onnx::AttributeProto&& attr) { + reinterpret_cast(p).insert_or_assign(attr.name(), std::move(attr)); }; the_global_api.node_attributes_delete = [](vaip_core::NodeAttributes* p) { delete reinterpret_cast(p); @@ -349,7 +385,8 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { } return &it->second; }; - the_global_api.node_attributes_get_keys = [](vaip_core::NodeAttributes& p) -> vaip_core::DllSafe> { + the_global_api.node_attributes_get_keys = + [](vaip_core::NodeAttributes& p) -> vaip_core::DllSafe> { auto ret = std::vector(); auto& attr = reinterpret_cast(p); ret.reserve(attr.size()); @@ -359,34 +396,29 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return vaip_core::DllSafe(std::move(ret)); }; /// tensor proto - the_global_api.tensor_proto_get_shape_unsafe = [](const onnx::TensorProto& t) -> vaip_core::DllSafe> { + the_global_api.tensor_proto_get_shape_unsafe = + [](const onnx::TensorProto& t) -> vaip_core::DllSafe> { return vaip_core::DllSafe>(vaip::tensor_proto_get_shape(t)); }; - the_global_api.tensor_proto_data_type = - [](const onnx::TensorProto& t) -> int { return t.data_type(); }; + the_global_api.tensor_proto_data_type = [](const onnx::TensorProto& t) -> int { return t.data_type(); }; the_global_api.tensor_proto_delete = [](onnx::TensorProto* tp) { delete tp; }; - the_global_api.tensor_proto_new_floats = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{ - vaip::tensor_proto_new_floats(name, shape, data)}; + the_global_api.tensor_proto_new_floats = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { + return new onnx::TensorProto{vaip::tensor_proto_new_floats(name, shape, data)}; }; - the_global_api.tensor_proto_new_i32 = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { + the_global_api.tensor_proto_new_i32 = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { return new onnx::TensorProto{vaip::tensor_proto_new_i32(name, shape, data)}; }; - the_global_api.tensor_proto_new_i64 = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { + the_global_api.tensor_proto_new_i64 = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { return new onnx::TensorProto{vaip::tensor_proto_new_i64(name, shape, data)}; }; - the_global_api.tensor_proto_new_i8 = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { + the_global_api.tensor_proto_new_i8 = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { return new onnx::TensorProto{vaip::tensor_proto_new_i8(name, shape, data)}; }; the_global_api.tensor_proto_raw_data_size = vaip::tensor_proto_raw_data_size; diff --git a/onnxruntime/core/providers/vitisai/include/onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h b/onnxruntime/core/providers/vitisai/include/onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h deleted file mode 100644 index 82f665429c24c..0000000000000 --- a/onnxruntime/core/providers/vitisai/include/onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. -#pragma once -#include -#include -#if defined(_WIN32) -#if ONNXRUNTIME_VITISAI_EP_EXPORT_DLL == 1 -#define ONNXRUNTIME_VITISAI_EP_DLL_SPEC __declspec(dllexport) -#else -#define ONNXRUNTIME_VITISAI_EP_DLL_SPEC __declspec(dllimport) -#endif -#else -#define ONNXRUNTIME_VITISAI_EP_DLL_SPEC __attribute__((visibility("default"))) -#endif - -#ifndef USE_VITISAI -#define USE_VITISAI /* mimic VITISAI EP in ORT */ -#endif - -namespace vaip_core { -class ExecutionProvider; -struct OrtApiForVaip; -template -class DllSafe; -} // namespace vaip_core -namespace onnxruntime { -class Graph; -} -struct OrtCustomOpDomain; -namespace onnxruntime_vitisai_ep { - -ONNXRUNTIME_VITISAI_EP_DLL_SPEC void -initialize_onnxruntime_vitisai_ep(vaip_core::OrtApiForVaip* api, - std::vector& ret_domain); -ONNXRUNTIME_VITISAI_EP_DLL_SPEC -vaip_core::DllSafe>> -compile_onnx_model_3(const std::string& model_path, - const onnxruntime::Graph& graph, const char* json_config); -ONNXRUNTIME_VITISAI_EP_DLL_SPEC -int optimize_onnx_model(const std::filesystem::path& model_path_in, - const std::filesystem::path& model_path_out, - const char* json_config); -} // namespace onnxruntime_vitisai_ep - -extern "C" ONNXRUNTIME_VITISAI_EP_DLL_SPEC const vaip_core::OrtApiForVaip* -get_the_global_api(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 8da3882b5af99..c446ab3aefcc5 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -2,6 +2,16 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. #pragma once +#include +#include +#include + #include "core/session/onnxruntime_cxx_api.h" +#include "core/framework/provider_options.h" +#include "vaip/my_ort.h" +#include "vaip/dll_safe.h" +#include "vaip/custom_op.h" std::vector initialize_vitisai_ep(); +vaip_core::DllSafe>> compile_onnx_model_with_options( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); diff --git a/onnxruntime/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc b/onnxruntime/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc deleted file mode 100644 index 8244c36f822a4..0000000000000 --- a/onnxruntime/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. -#include "vaip/dll_safe.h" -#include "vaip/vaip_ort_api.h" -#include "vaip/custom_op.h" -#include "onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h" -#include -#include -using namespace std; - -namespace onnxruntime_vitisai_ep { -static void my_abort() { - cerr << "please install VitisAI package." << endl; - abort(); -} -using namespace vaip_core; -void initialize_onnxruntime_vitisai_ep(OrtApiForVaip* /*api*/, std::vector& /*domain*/) { - my_abort(); - return; -} // namespace onnxruntime_vitisai_ep -DllSafe>> -compile_onnx_model_3(const std::string& /*model_path*/, const Graph& /*graph*/, - const char* /*json_config*/) { - if (1) { // suppress dead code warning - my_abort(); - } - return DllSafe>>(); -} - -} // namespace onnxruntime_vitisai_ep diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 32ee6ff652aac..5f20b32cd6dc4 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -15,8 +15,6 @@ #include "core/session/custom_ops.h" #include "core/session/inference_session.h" -#include "onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h" - using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -24,8 +22,7 @@ namespace onnxruntime { constexpr const char* VITISAI = "VITISAI"; static vaip_core::DllSafe>> compile_onnx_model( - const onnxruntime::GraphViewer& graph_viewer, - const logging::Logger& logger, const char* json_config) { + const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { #ifndef _WIN32 auto model_path = graph_viewer.ModelPath().ToPathString(); #else @@ -33,12 +30,13 @@ static vaip_core::DllSafe strconverter; auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); #endif - return onnxruntime_vitisai_ep::compile_onnx_model_3(model_path, graph_viewer.GetGraph(), json_config); + return compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options); } + struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { - op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), - reinterpret_cast(&info)); + op_kernel_ = + op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), reinterpret_cast(&info)); } ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } @@ -55,8 +53,7 @@ struct MyCustomOpKernel : OpKernel { void* op_kernel_; }; -VitisAIExecutionProvider::VitisAIExecutionProvider( - const VitisAIExecutionProviderInfo& info) +VitisAIExecutionProvider::VitisAIExecutionProvider(const ProviderOptions& info) : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) { custom_op_domains_ = initialize_vitisai_ep(); registry_ = std::make_shared(); @@ -77,7 +74,8 @@ void VitisAIExecutionProvider::CreateKernelRegistry() { } } def_builder.Provider(onnxruntime::kVitisAIExecutionProvider); - KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, + std::unique_ptr& out) -> Status { out = std::make_unique(info, *op); return Status::OK(); }; @@ -89,9 +87,8 @@ void VitisAIExecutionProvider::CreateKernelRegistry() { std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return registry_; } -std::vector> -VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const { +std::vector> VitisAIExecutionProvider::GetCapability( + const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { if (graph.IsSubgraph()) { // VITIS AI EP not support sungraph. Assigned to CPU. return {}; @@ -100,9 +97,7 @@ VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // Only compiling a model once is currently supported return {}; } - auto opt_str = info_.get_json_config_str(); // String - execution_providers_ = - std::make_unique(compile_onnx_model(graph, *GetLogger(), opt_str)); + execution_providers_ = std::make_unique(compile_onnx_model(graph, *GetLogger(), info_)); auto result = vaip::GetComputeCapabilityOps(graph, execution_providers_.get(), vitisai_optypes_); size_t index = 0u; for (auto& ep : **execution_providers_) { @@ -112,16 +107,14 @@ VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, return result; } -common::Status VitisAIExecutionProvider::Compile( - const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { +common::Status VitisAIExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) { for (const auto& fused_node_graph : fused_nodes_and_graphs) { NodeComputeInfo compute_info; const onnx::AttributeProto* attr = graph_utils::GetNodeAttribute(fused_node_graph.fused_node, "index"); assert(attr != nullptr); size_t index = (size_t)attr->i(); - compute_info.create_state_func = [this, index](ComputeContext* context, - FunctionState* state) { + compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; return 0; @@ -129,15 +122,11 @@ common::Status VitisAIExecutionProvider::Compile( compute_info.release_state_func = [](FunctionState state) { if (state) { - delete reinterpret_cast( - state); + delete reinterpret_cast(state); } }; - compute_info.compute_func = [](FunctionState state, const OrtApi* api, - OrtKernelContext* context) { - reinterpret_cast( - state) - ->Compute(api, context); + compute_info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + reinterpret_cast(state)->Compute(api, context); return Status::OK(); }; node_compute_funcs.push_back(compute_info); diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index 5bdfc8c18fb6d..e86b53339d4d2 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -4,6 +4,10 @@ #pragma once #include +#include +#include +#include +#include #include "core/framework/execution_provider.h" #include "core/framework/customregistry.h" @@ -18,34 +22,19 @@ class ExecutionProvider; } // namespace vaip_core namespace onnxruntime { -// Information needed to construct execution providers. -struct VitisAIExecutionProviderInfo { - VitisAIExecutionProviderInfo(const ProviderOptions& provider_options); - - const char* get_json_config_str() const { - return json_config_.c_str(); - } - - private: - ProviderOptions provider_options_; - const std::string json_config_; -}; - // Logical device representation. class VitisAIExecutionProvider : public IExecutionProvider { public: - explicit VitisAIExecutionProvider(const VitisAIExecutionProviderInfo& info); + explicit VitisAIExecutionProvider(const ProviderOptions& info); ~VitisAIExecutionProvider() = default; - std::vector> - GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override; + std::vector> GetCapability(const onnxruntime::GraphViewer& graph, + const IKernelLookup& /*kernel_lookup*/) const override; int GetDeviceId() const { return 0; } - common::Status Compile( - const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) override; + common::Status Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) override; std::shared_ptr GetKernelRegistry() const override; private: @@ -54,7 +43,7 @@ class VitisAIExecutionProvider : public IExecutionProvider { using my_ep_uptr_t = std::shared_ptr; // we have to hide the implementation by forward declaration. mutable my_ep_uptr_t execution_providers_; - VitisAIExecutionProviderInfo info_; + ProviderOptions info_; std::vector custom_op_domains_; std::shared_ptr registry_; std::set vitisai_optypes_; diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 763a3efd1b35b..4c416124ca8f2 100755 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -3,56 +3,37 @@ #include "vitisai_provider_factory_creator.h" +#include +#include + #include "vaip/global_api.h" #include "./vitisai_execution_provider.h" #include "core/framework/execution_provider.h" #include "core/session/abi_session_options_impl.h" -#include "nlohmann/json.hpp" -#include -#include -#include +#include "core/providers/shared_library/provider_host_api.h" using namespace onnxruntime; -using json = nlohmann::json; namespace onnxruntime { -static std::string ConfigToJsonStr(const std::unordered_map& config) { - const auto& filename = config.at("config_file"); - std::ifstream f(filename); - json data = json::parse(f); - for (const auto& entry : config) { - data[entry.first] = entry.second; - } - return data.dump(); -} - -VitisAIExecutionProviderInfo::VitisAIExecutionProviderInfo(const ProviderOptions& provider_options) : provider_options_(provider_options), json_config_{ConfigToJsonStr(provider_options)} {} - struct VitisAIProviderFactory : IExecutionProviderFactory { - VitisAIProviderFactory(const VitisAIExecutionProviderInfo& info) : info_(info) {} + VitisAIProviderFactory(const ProviderOptions& info) : info_(info) {} ~VitisAIProviderFactory() = default; std::unique_ptr CreateProvider() override; private: - VitisAIExecutionProviderInfo info_; + ProviderOptions info_; }; std::unique_ptr VitisAIProviderFactory::CreateProvider() { return std::make_unique(info_); } -std::shared_ptr -CreateExecutionProviderFactory_VITISAI(const VitisAIExecutionProviderInfo& info) { - initialize_vitisai_ep(); - return std::make_shared(info); -} - -std::shared_ptr VitisAIProviderFactoryCreator::Create(const ProviderOptions& provider_options) { +std::shared_ptr VitisAIProviderFactoryCreator::Create( + const ProviderOptions& provider_options) { initialize_vitisai_ep(); - auto info = VitisAIExecutionProviderInfo{provider_options}; - return std::make_shared(info); + return std::make_shared(provider_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h b/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h index 9e0583275d1b6..9bb7cfa062a0f 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h @@ -9,9 +9,6 @@ #include "core/framework/provider_options.h" namespace onnxruntime { - -struct VitisAIExecutionProviderInfo; - struct VitisAIProviderFactoryCreator { static std::shared_ptr Create(const ProviderOptions& provider_options); }; diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index a5bcbce89bac6..6827f2c9dfd91 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -85,13 +85,6 @@ struct OrtStatus { #define BACKEND_TVM "" #endif -#if USE_VITISAI -#define BACKEND_VITISAI "-VITISAI" -#include "core/providers/vitisai/vitisai_execution_provider.h" -#else -#define BACKEND_VITISAI "" -#endif - #if USE_OPENBLAS #define BACKEND_OPENBLAS "-OPENBLAS" #else @@ -451,9 +444,6 @@ std::shared_ptr CreateExecutionProviderFactory_Dnnl(c std::shared_ptr CreateExecutionProviderFactory_Tvm(const tvm::TvmEPOptions& info); std::shared_ptr CreateExecutionProviderFactory_Tvm(const char* params); #endif -std::shared_ptr CreateExecutionProviderFactory_VITISAI(const char* backend_type, int device_id, - const char* export_runtime_module, - const char* load_runtime_module); std::shared_ptr CreateExecutionProviderFactory_ACL(int use_arena); std::shared_ptr CreateExecutionProviderFactory_ArmNN(int use_arena); std::shared_ptr CreateExecutionProviderFactory_DML(int device_id); From cbad4fe49bfada781059659f555fcde49fbae37f Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 14 Dec 2023 16:15:07 -0800 Subject: [PATCH 04/11] Update absl and googletest (#18827) ### Description Update absl and googletest to their latest version to include some cmake changes: 1. A googletest's cmake change that will allow using external absl and re2. 2. Nullability enhancements that will allow our clang-based static analysis detecting many kinds of null pointer errors. ### Motivation and Context To fix a C4744 link warning in our Windows pipelines. ``` LINK : warning C4744: 'static char const absl::lts_20230802::base_internal::FastTypeTag::dummy_var' has different type in 'd:\a\_work\_temp\abseil_cpp\abseil-cpp-20230802.0\absl\flags\parse.cc' and 'd:\a\_work\1\b\relwithdebinfo\_deps\googletest-src\googletest\src\gtest-all.cc': 'signed char' and 'unsigned char' [D:\a\_work\1\b\RelWithDebInfo\onnxruntime_mlas_test.vcxproj] LINK : warning C4744: 'static char const absl::lts_20230802::base_internal::FastTypeTag,class std::allocator > >::dummy_var' has different type in 'd:\a\_work\_temp\abseil_cpp\abseil-cpp-20230802.0\absl\flags\parse.cc' and 'd:\a\_work\1\b\relwithdebinfo\_deps\googletest-src\googletest\src\gtest-all.cc': 'signed char' and 'unsigned char' [D:\a\_work\1\b\RelWithDebInfo\onnxruntime_mlas_test.vcxproj] LINK : warning C4744: 'static char const absl::lts_20230802::base_internal::FastTypeTag,class std::allocator > >::dummy_var' has different type in 'd:\a\_work\_temp\abseil_cpp\abseil-cpp-20230802.0\absl\flags\internal\usage.cc' and 'd:\a\_work\1\b\relwithdebinfo\_deps\googletest-src\googletest\src\gtest-all.cc': 'signed char' and 'unsigned char' [D:\a\_work\1\b\RelWithDebInfo\onnxruntime_mlas_test.vcxproj] LINK : warning C4744: 'static char const absl::lts_20230802::base_internal::FastTypeTag::dummy_var' has different type in 'd:\a\_work\_temp\abseil_cpp\abseil-cpp-20230802.0\absl\flags\internal\flag.cc' and 'd:\a\_work\1\b\relwithdebinfo\_deps\googletest-src\googletest\src\gtest-all.cc': 'signed char' and 'unsigned char' [D:\a\_work\1\b\RelWithDebInfo\onnxruntime_mlas_test.vcxproj] LINK : warning C4744: 'static char const absl::lts_20230802::base_internal::FastTypeTag,class std::allocator > >::dummy_var' has different type in 'd:\a\_work\_temp\abseil_cpp\abseil-cpp-20230802.0\absl\flags\internal\flag.cc' and 'd:\a\_work\1\b\relwithdebinfo\_deps\googletest-src\googletest\src\gtest-all.cc': 'signed char' and 'unsigned char' [D:\a\_work\1\b\RelWithDebInfo\onnxruntime_mlas_test.vcxproj] LINK : warning C4744: 'static char const absl::lts_20230802::base_internal::FastTypeTag::dummy_var' has different type in 'd:\a\_work\_temp\abseil_cpp\abseil-cpp-20230802.0\absl\flags\internal\flag.cc' and 'd:\a\_work\1\b\relwithdebinfo\_deps\googletest-src\googletest\src\gtest-all.cc': 'signed char' and 'unsigned char' [D:\a\_work\1\b\RelWithDebInfo\onnxruntime_mlas_test.vcxproj] ``` --- cgmanifests/generated/cgmanifest.json | 4 ++-- cmake/deps.txt | 4 ++-- .../github/azure-pipelines/templates/download-deps.yml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 5a016717f7d1e..137ea8a50c011 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -36,7 +36,7 @@ "component": { "type": "git", "git": { - "commitHash": "3abf3298b6b43acc8556b1342ffb6de4a85fb30f", + "commitHash": "dcd5bd5fd593e31465af3d9ef291d26c646b0a4f", "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" }, "comments": "abseil_cpp" @@ -126,7 +126,7 @@ "component": { "type": "git", "git": { - "commitHash": "b3a9ba2b8e975550799838332803d468797ae2e1", + "commitHash": "530d5c8c84abd2a46f38583ee817743c9b3a42b4", "repositoryUrl": "https://github.com/google/googletest.git" }, "comments": "googletest" diff --git a/cmake/deps.txt b/cmake/deps.txt index 8a9ccef6f8181..ff07803013071 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -12,7 +12,7 @@ # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI. # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # -abseil_cpp;https://github.com/abseil/abseil-cpp/archive/3abf3298b6b43acc8556b1342ffb6de4a85fb30f.zip;d6da50a47c1268b5d6d5405b7fc21258ccd84d31 +abseil_cpp;https://github.com/abseil/abseil-cpp/archive/dcd5bd5fd593e31465af3d9ef291d26c646b0a4f.zip;6cc204586014e189f5c0fe3274f83162fa7c700c cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 @@ -27,7 +27,7 @@ fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908 google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 -googletest;https://github.com/google/googletest/archive/b3a9ba2b8e975550799838332803d468797ae2e1.zip;0ac421f2ec11af38b0fff0f1992184032731a8bc +googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034 googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 9ef1aed55d58c..537175f6bec73 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.128 + version: 1.0.129 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.128 + version: 1.0.129 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. From 5eda79bdd3f138d599d5d0dda75b76096ea62a93 Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 15 Dec 2023 13:32:19 +0800 Subject: [PATCH 05/11] Improve perf for stage3 training (#18099) ### Improve perf for stage3 training - first wave Port existing PythonOp/PythonOpGrad python runner to C++, also introduce an unsafe run mode (to skip inplace, save for backward, materrialized grad detection on the fly). This reduce the overhead from XX~XXX us to X ~ lower end of XX us . In LLAMA2 7B training with 8x32GV100, we have observed 6.7% gains over PyTorch. (1.59 v.s. 1.49it/s) Peak memory also dropped from 31GB to 28GB. ### Motivation and Context --- .../torch/custom_function_register.cc | 64 +- .../torch/custom_function_register.h | 30 +- .../core/framework/torch/torch_proxy.cc | 285 +++---- .../core/framework/torch/torch_proxy.h | 4 +- .../core/graph/gradient_builder.cc | 1 + .../core/graph/training_op_defs.cc | 18 + .../python/orttraining_pybind_state.cc | 6 +- .../ortmodule/_custom_autograd_function.py | 5 +- .../_custom_autograd_function_exporter.py | 8 +- .../_custom_autograd_function_runner.py | 707 ------------------ .../ortmodule/_zero_stage3_compatibility.py | 58 +- .../cpu/torch_interop_utils/ctx_pool.cc | 23 + .../cpu/torch_interop_utils/ctx_pool.h | 96 +++ .../torch_interop_utils/custom_function_bw.cc | 174 +++++ .../torch_interop_utils/custom_function_bw.h | 16 + .../torch_interop_utils/custom_function_fw.cc | 516 +++++++++++++ .../torch_interop_utils/custom_function_fw.h | 16 + .../custom_function_shared.cc | 213 ++++++ .../custom_function_shared.h | 89 +++ .../cpu/torch_interop_utils/fake_ctx.py | 13 + .../cpu/torch_interop_utils/setup.py | 21 +- .../torch_interop_utils.cc | 189 +---- .../python/training/utils/__init__.py | 9 + .../utils/hooks/_zero_offload_subscriber.py | 76 +- .../python/training/utils/torch_io_helper.py | 4 + .../training/utils/torch_profile_utils.py | 28 + .../orttraining_test_ortmodule_autograd.py | 15 +- .../torch_custom_function_kernel_base.cc | 13 +- .../torch/torch_custom_function_kernel_base.h | 4 + setup.py | 2 +- 30 files changed, 1520 insertions(+), 1183 deletions(-) delete mode 100644 orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py create mode 100644 orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.cc create mode 100644 orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.h create mode 100644 orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc create mode 100644 orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.h create mode 100644 orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc create mode 100644 orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.h create mode 100644 orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.cc create mode 100644 orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.h create mode 100644 orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/fake_ctx.py create mode 100644 orttraining/orttraining/python/training/utils/torch_profile_utils.py diff --git a/orttraining/orttraining/core/framework/torch/custom_function_register.cc b/orttraining/orttraining/core/framework/torch/custom_function_register.cc index 1a51da3daa27f..9ab3fdb0b7c0a 100644 --- a/orttraining/orttraining/core/framework/torch/custom_function_register.cc +++ b/orttraining/orttraining/core/framework/torch/custom_function_register.cc @@ -88,11 +88,14 @@ void OrtTorchFunctionPool::RegisterTorchAutogradFunction( PythonObjectPtr forward(PyObject_GetAttrString(obj, "apply"), PythonObjectDeleter); PythonObjectPtr backward(PyObject_GetAttrString(obj, "backward"), PythonObjectDeleter); + PythonObjectPtr unsafe_forward(PyObject_GetAttrString(obj, "forward"), PythonObjectDeleter); ORT_ENFORCE(forward.get(), "apply attribute not found when registering ", key); ORT_ENFORCE(backward.get(), "backward attribute not found when registering ", key); + ORT_ENFORCE(unsafe_forward.get(), "forward attribute not found when registering ", key); RegisterEntry(mutex_, key, forward.get(), forward_core_pool_); RegisterEntry(mutex_, key, backward.get(), backward_core_pool_); + RegisterEntry(mutex_, key, unsafe_forward.get(), unsafe_forward_core_pool_); } void OrtTorchFunctionPool::RegisterShapeInferenceFunction(const std::string& key, @@ -105,46 +108,27 @@ void OrtTorchFunctionPool::RegisterInputAliasFunction(const std::string& key, RegisterEntry(mutex_, key, obj, input_alias_function_pool_); } -static void RegisterEntry( - std::mutex& mutex, - PyObject* obj, - PythonObjectPtr& storage) { - std::lock_guard lock(mutex); - // Basic checks. - ORT_ENFORCE(obj, "Cannot register NULL PyObject*."); - - // Skip registration if storage already stores a Python object. - if (storage.get() != nullptr) { - return; - } - - // Own the Python object. - Py_INCREF(obj); - PythonObjectPtr ptr(obj, PythonObjectDeleter); - - // If an obj has been registered, this old ownership is automatically released - // after this move-assignment. Then, the "storage" owns the new object. - storage = std::move(ptr); +void OrtTorchFunctionPool::RegisterForwardRunner(size_t function_address) { + void* p_forward_runner_func = reinterpret_cast(function_address); + forward_runner_ = reinterpret_cast(p_forward_runner_func); } -void OrtTorchFunctionPool::RegisterForwardRunner(PyObject* obj) { - RegisterEntry(mutex_, obj, forward_runner_); +void OrtTorchFunctionPool::RegisterBackwardRunner(size_t function_address) { + void* p_backward_runner_func = reinterpret_cast(function_address); + backward_runner_ = reinterpret_cast(p_backward_runner_func); } -void OrtTorchFunctionPool::RegisterBackwardRunner(PyObject* obj) { - RegisterEntry(mutex_, obj, backward_runner_); -} +CustomFunctionRunnerType OrtTorchFunctionPool::GetForwardRunner() { + ORT_ENFORCE(forward_runner_, + "Forward runner cannot be NULL. Did you forget to register it by calling RegisterForwardRunner(...)?"); -PyObject* OrtTorchFunctionPool::GetForwardRunner() { - std::lock_guard lock(mutex_); - ORT_ENFORCE(forward_runner_.get(), "Forward runner cannot be NULL. Do you forget register it by calling RegisterForwardRunner(...)?"); - return forward_runner_.get(); + return forward_runner_; } -PyObject* OrtTorchFunctionPool::GetBackwardRunner() { - std::lock_guard lock(mutex_); - ORT_ENFORCE(backward_runner_.get(), "backward runner cannot be NULL. Do you forget register it by calling RegisterBackwardRunner(...)?"); - return backward_runner_.get(); +CustomFunctionRunnerType OrtTorchFunctionPool::GetBackwardRunner() { + ORT_ENFORCE(backward_runner_, + "backward runner cannot be NULL. Did you forget to register it by calling RegisterBackwardRunner(...)?"); + return backward_runner_; } PyObject* OrtTorchFunctionPool::GetForwardCore(const std::string& key) { @@ -163,6 +147,14 @@ PyObject* OrtTorchFunctionPool::GetBackwardCore(const std::string& key) { return iter->second.get(); } +PyObject* OrtTorchFunctionPool::GetUnsafeForwardCore(const std::string& key) { + ORT_ENFORCE(!key.empty(), "Cannot be empty string."); + std::lock_guard lock(mutex_); + auto iter = unsafe_forward_core_pool_.find(key); + ORT_ENFORCE(iter != unsafe_forward_core_pool_.end(), "No unsafe forward registered for ", key); + return iter->second.get(); +} + std::optional OrtTorchFunctionPool::TryGettingShapeInferenceFunction(const std::string& key) { ORT_ENFORCE(!key.empty(), "Cannot be empty string."); std::lock_guard lock(mutex_); @@ -201,10 +193,9 @@ int64_t OrtTorchFunctionPool::RegisterContext(PyObject* autograd_context) { autograd_context, "autograd_context_register"); ORT_ENFORCE(autograd_context, "Cannot register NULL autograd context."); - Py_INCREF(autograd_context); func_context_pool_.insert({index_, PythonObjectPtr(autograd_context, PythonObjectDeleter)}); - // We don't need increase the context refcnt because PyTorch already did it during .apply(). + return index_; } @@ -227,14 +218,13 @@ PyObject* OrtTorchFunctionPool::GetContext(int64_t context_index) { } void OrtTorchFunctionPool::UnRegisterGlobalFunctions() { - forward_runner_.reset(); - backward_runner_.reset(); func_context_pool_.clear(); } void OrtTorchFunctionPool::UnRegisterModelSpecificFunctions() { forward_core_pool_.clear(); backward_core_pool_.clear(); + unsafe_forward_core_pool_.clear(); shape_inference_function_pool_.clear(); input_alias_function_pool_.clear(); miscellaneous_const_input_pool_.clear(); diff --git a/orttraining/orttraining/core/framework/torch/custom_function_register.h b/orttraining/orttraining/core/framework/torch/custom_function_register.h index d51cc7dadc1af..67a991ea2cce3 100644 --- a/orttraining/orttraining/core/framework/torch/custom_function_register.h +++ b/orttraining/orttraining/core/framework/torch/custom_function_register.h @@ -13,6 +13,16 @@ namespace onnxruntime { namespace language_interop_ops { namespace torch { +typedef std::vector (*CustomFunctionRunnerType)(const char* func_name_char, + void* callback, + const std::vector& requires_grads, + const std::vector& tensor_type_flags, + const bool is_training_mode, + const std::vector& inplace_map, + const char* kernel_invoke_id_char, + const bool safe_run_mode_enabled, + const std::vector& tensor_args); + class OrtTorchFunctionPool final { public: static OrtTorchFunctionPool& GetInstance() { @@ -34,6 +44,9 @@ class OrtTorchFunctionPool final { // 2. Caller of GetBackwardCore should not decrease the reference count of the returned object. PyObject* GetBackwardCore(const std::string& key); // The "key" is the "name" attribute in PythonOpGrad. + // Return a borrowed reference to the stored Python function running in safe mode. + PyObject* GetUnsafeForwardCore(const std::string& key); // The "key" is the "name" attribute in PythonOp. + // Shape inference function is used to infer output shape of a PythonOp. void RegisterShapeInferenceFunction(const std::string& key, PyObject* obj); // Return a borrowed reference to the stored Python function, if it exists; otherwise, return nullptr. @@ -67,15 +80,15 @@ class OrtTorchFunctionPool final { // ForwardRunner/BackwardRunner are "glue" codes written in Python that interacting // with C++ kernels during Python function invoking. // This function creates new ownership to "obj". - void RegisterForwardRunner(PyObject* obj); + void RegisterForwardRunner(size_t function_address); // This function creates new ownership to "obj". - void RegisterBackwardRunner(PyObject* obj); - // Return a borrowed reference to a Python function, which + void RegisterBackwardRunner(size_t function_address); + // Return a borrowed reference to a c++ function, which // is responsible for executing autograd.Function.apply. - PyObject* GetForwardRunner(); - // Return a borrowed reference to a Python function, which + CustomFunctionRunnerType GetForwardRunner(); + // Return a borrowed reference to a c++ function, which // is responsible for executing autograd.Function.apply. - PyObject* GetBackwardRunner(); + CustomFunctionRunnerType GetBackwardRunner(); // The reason we provide this unregister api is: // A static OrtTorchFunctionPool instance will be destructed after @@ -97,11 +110,12 @@ class OrtTorchFunctionPool final { void UnRegisterGlobalFunctions(); void UnRegisterModelSpecificFunctions(); - PythonObjectPtr forward_runner_; - PythonObjectPtr backward_runner_; + CustomFunctionRunnerType forward_runner_; + CustomFunctionRunnerType backward_runner_; std::unordered_map forward_core_pool_; std::unordered_map backward_core_pool_; + std::unordered_map unsafe_forward_core_pool_; std::unordered_map shape_inference_function_pool_; std::unordered_map input_alias_function_pool_; diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.cc b/orttraining/orttraining/core/framework/torch/torch_proxy.cc index f36f913366a37..1cd01ae16deea 100644 --- a/orttraining/orttraining/core/framework/torch/torch_proxy.cc +++ b/orttraining/orttraining/core/framework/torch/torch_proxy.cc @@ -12,7 +12,10 @@ namespace onnxruntime::language_interop_ops::torch { -void PythonObjectDeleter(PyObject* ptr) { Py_XDECREF(ptr); }; +void PythonObjectDeleter(PyObject* ptr) { + GilGuard gil; + Py_XDECREF(ptr); +} PyObject* Ort_PyTuple_New(const size_t len, const std::string& log_tag) { PyObject* item = PyTuple_New(len); @@ -20,34 +23,11 @@ PyObject* Ort_PyTuple_New(const size_t len, const std::string& log_tag) { return item; } -void Ort_PyTuple_SetItem_Incref(PyObject* py_tuple, size_t index, PyObject* item, const std::string& log_tag) { - RefCountTracker::GetInstance().TrackPyObject(RefCountTracker::ObjCategory::PythonCallArgs, item, log_tag); - Py_INCREF(item); - PyTuple_SetItem(py_tuple, index, item); -} - void Ort_PyTuple_SetItem_NoIncref(PyObject* py_tuple, size_t index, PyObject* item, const std::string& log_tag) { RefCountTracker::GetInstance().TrackPyObject(RefCountTracker::ObjCategory::PythonCallArgs, item, log_tag); PyTuple_SetItem(py_tuple, index, item); } -PyObject* Ort_PyList_New(const size_t len, const std::string& log_tag) { - PyObject* item = PyList_New(len); - RefCountTracker::GetInstance().TrackPyObject(RefCountTracker::ObjCategory::PythonCallArgs, item, log_tag); - return item; -} - -void Ort_PyList_SetItem_Incref(PyObject* py_list, size_t index, PyObject* item, const std::string& log_tag) { - RefCountTracker::GetInstance().TrackPyObject(RefCountTracker::ObjCategory::PythonCallArgs, item, log_tag); - Py_INCREF(item); - PyList_SetItem(py_list, index, item); -} - -void Ort_PyList_SetItem_NoIncref(PyObject* py_list, size_t index, PyObject* item, const std::string& log_tag) { - RefCountTracker::GetInstance().TrackPyObject(RefCountTracker::ObjCategory::PythonCallArgs, item, log_tag); - PyList_SetItem(py_list, index, item); -} - void CheckArguments( const size_t len, const std::vector& requires_grads, @@ -92,87 +72,51 @@ void CheckArguments( // len: the number of input arguments. // tensor_indices: if tensor_indices[i] is j, // then the j-th input argument should be a tensor. -PyObject* CreateTensorFlags( - const size_t len, - const std::vector& tensor_indices) { - PyObject* flags = Ort_PyList_New(len, "tensor_flags_list"); - - // First we fill the list with 0. Later we will - // assign 1's to tensors' corresponding positions. - for (size_t i = 0; i < len; ++i) { - PyObject* zero = PyLong_FromLong(0); - Ort_PyList_SetItem_NoIncref(flags, i, zero, std::to_string(__LINE__)); - } - +std::vector CreateTensorFlags(const size_t len, const std::vector& tensor_indices) { + std::vector flags(len, 0); for (const auto i : tensor_indices) { - PyObject* one = PyLong_FromLong(1); - Ort_PyList_SetItem_NoIncref(flags, i, one, std::to_string(__LINE__)); + flags[i] = 1; } return flags; } -// flags[i] corresponds to the i-th input of apply/backward. -PyObject* CreateRequiresGradFlags( - const std::vector& requires_grads) { - PyObject* flags = Ort_PyList_New(requires_grads.size(), "require_grads_list"); - for (size_t i = 0; i < requires_grads.size(); ++i) { - PyObject* value; - if (requires_grads.at(i) != 0) { - value = Py_True; - } else { - value = Py_False; - } - Ort_PyList_SetItem_Incref(flags, i, value, std::to_string(__LINE__)); - } - return flags; -} - -PyObject* CreateInplaceMap( - const std::vector& inplace_map) { - PyObject* inplace_map_obj = Ort_PyList_New(inplace_map.size(), "inplace_map"); - - for (size_t output_index = 0; output_index < inplace_map.size(); ++output_index) { - PyObject* input_index = PyLong_FromLong(inplace_map[output_index]); - Ort_PyList_SetItem_NoIncref(inplace_map_obj, output_index, input_index, std::to_string(__LINE__)); - } - - return inplace_map_obj; -} - -void InvokeRunner( - PyObject* callback_runner, - PyObject* args, - bool is_training_mode, - void** diff_ctx, - std::vector& returned_ortvalues) { - PythonObjectPtr result_ptr(PyObject_CallObject(callback_runner, args), PythonObjectDeleter); - - if (PyErr_Occurred()) { - PyErr_Print(); - ORT_THROW("Python function execution fails with the above information."); - } - - ORT_ENFORCE(PyTuple_Check(result_ptr.get()), "Python function must return a tuple."); - +void ProcessReturnValues(std::vector& results, + bool is_training_mode, + bool safe_run_mode_enabled, + void** diff_ctx, + std::vector& returned_ortvalues) { size_t i = 0; if (diff_ctx) { // Assume that the first input element in the returned tuple is autograd context // from Pytorch. - PyObject* py_obj = PyTuple_GetItem(result_ptr.get(), 0); + ORT_ENFORCE(results.size() > 0, "The returned tuple should have at least one element."); + PyObject* py_obj = results[0]; if (is_training_mode) { if (py_obj == Py_None) { LOGS_DEFAULT(VERBOSE) << "Under training mode, autograd context found to be Py_None."; } else { + GilGuard guard; + const auto refcnt = Py_REFCNT(py_obj); - // We don't need do ref increase here because, python returns tensor.grad_fn as part of - // tuple, who increased the refcnt already (and tensor persist until the backward kernels completed). - // Pytorch also increases refcnt before apply() return, so we should expect refcount >= 2. - // We say "at least" 2 because user could increase the context refcnt as well in their autograd forward() - // and backward() functions. - ORT_ENFORCE(refcnt >= 2, "Ref count of context should be 2, but actually it's ", refcnt, "."); - if (refcnt > 2) { - LOGS_DEFAULT(VERBOSE) << "Autograd context refcnt > 2, refcnt: " << refcnt; + if (safe_run_mode_enabled) { + // For safe_run_mode_enabled, we expect refcnt >= 2. + // 1. shared_ptr is maintained in torch_interop_utils::PyNodeSharedPointerPool. PyNode is owning + // the context, e.g. THPFunction*. + // 2. results own another reference to the context, while the ownership will be ended after `Invoke` completed. + ORT_ENFORCE(refcnt >= 2, "Ref count of context should be 2, but actually it's ", refcnt, "."); + + // Own one reference!!! + Py_INCREF(py_obj); + + if (refcnt > 2) { + LOGS_DEFAULT(VERBOSE) << "Autograd context refcnt > 2, refcnt: " << refcnt; + } + } else { + ORT_ENFORCE(refcnt == 1, "Ref count of context should be 1, but actually it's ", refcnt, "."); + + // Own one reference!!! + Py_INCREF(py_obj); } } } else { @@ -184,12 +128,13 @@ void InvokeRunner( // i is 1 if the first element is autograd context. Otherwise, i is 0, so we read from the // first element. - for (; i < static_cast(PyTuple_Size(result_ptr.get())); ++i) { - PyObject* dl_tensor_pointer = PyTuple_GetItem(result_ptr.get(), i); + for (; i < results.size(); ++i) { + PyObject* dl_tensor_pointer = results[i]; if (dl_tensor_pointer == Py_None) { OrtValue empty_ort_value; returned_ortvalues.push_back(empty_ort_value); } else { + GilGuard guard; ORT_ENFORCE(Py_REFCNT(dl_tensor_pointer) == 1, "Ref count of dl_tensor_pointer should be 1."); // Todo (pengwa): be noted we did not pass whether tensor is bool or not. // Currently we assume we don't pass boolean data. @@ -198,73 +143,44 @@ void InvokeRunner( } } -PythonObjectPtr CreatePythonCallArguments( - PyObject* callback, - const size_t len, - const std::vector& requires_grads, - const std::vector>& tensor_args, - const std::vector& tensor_indices, - const std::vector& obj_args, - const std::vector& obj_indices, - const bool is_training_mode, - const std::vector& inplace_map, - const std::string& invoke_id, - const std::string& func_name) { - ORT_ENFORCE(PyCallable_Check(callback), "Forward callback is not callable."); - // The number of variables before those of - // autograd.Function.apply and autograd.Function.backward. - // The extra variables are used to configure the launch - // forward and backward runners. - constexpr int64_t num_control_args = 7; - - // All arguments created for Python call will be destroyed along with PythonObjectPtr. - PythonObjectPtr args(Ort_PyTuple_New(num_control_args + len, "forward_arguments_tuple"), PythonObjectDeleter); - PyObject* tensor_flags = CreateTensorFlags(len, tensor_indices); - PyObject* requires_grad_flags = CreateRequiresGradFlags(requires_grads); - - Ort_PyTuple_SetItem_Incref(args.get(), 0, callback, "callback_function"); - Ort_PyTuple_SetItem_NoIncref(args.get(), 1, requires_grad_flags, "requires_grad_flags"); - Ort_PyTuple_SetItem_NoIncref(args.get(), 2, tensor_flags, "tensor_flags"); - PyObject* is_training_mode_arg = is_training_mode ? Py_True : Py_False; - Ort_PyTuple_SetItem_Incref(args.get(), 3, is_training_mode_arg, "is_training_mode"); - - PyObject* inplace_map_arg = CreateInplaceMap(inplace_map); - Ort_PyTuple_SetItem_NoIncref(args.get(), 4, inplace_map_arg, "inplace_map"); - - PyObject* kernel_invoke_id_arg = PyBytes_FromStringAndSize(invoke_id.c_str(), invoke_id.size()); - Ort_PyTuple_SetItem_NoIncref(args.get(), 5, kernel_invoke_id_arg, "kernel_invoke_id_arg"); - - PyObject* func_name_arg = PyBytes_FromStringAndSize(func_name.c_str(), func_name.size()); - Ort_PyTuple_SetItem_NoIncref(args.get(), 6, func_name_arg, "func_name_arg"); +void PrepareCallArguments(const std::vector>& tensor_args, + const std::vector& tensor_indices, + const std::vector& obj_args, + const std::vector& obj_indices, + std::vector& args, + std::vector& tensor_flags) { + const size_t len = tensor_args.size() + obj_args.size(); + tensor_flags = CreateTensorFlags(len, tensor_indices); + args.resize(len, nullptr); // Tensor inputs to call autograd.Function.apply or autograd.Function.backward. - for (size_t i = 0; i < tensor_args.size(); ++i) { - if (!tensor_args[i].has_value()) { - Ort_PyTuple_SetItem_Incref(args.get(), num_control_args + tensor_indices[i], Py_None, - "non_tensor_args"); - continue; - } + { + GilGuard guard; + for (size_t i = 0; i < tensor_args.size(); ++i) { + if (!tensor_args[i].has_value()) { + Py_INCREF(Py_None); + args[tensor_indices[i]] = Py_None; + continue; + } - // Wrap with DLPack, then transfer to Python for its release. - PyObject* dl_tensor = training::framework::torch::ToDlpack(tensor_args[i].value()); - Ort_PyTuple_SetItem_NoIncref(args.get(), num_control_args + tensor_indices[i], dl_tensor, - "dltensor"); - } + // Wrap with DLPack, then transfer to Python for its release. + PyObject* dl_tensor = training::framework::torch::ToDlpack(tensor_args[i].value()); + args[tensor_indices[i]] = dl_tensor; + } - // Non-tensor inputs to call autograd.Function.apply or autograd.Function.backward. - for (size_t i = 0; i < obj_args.size(); ++i) { - PyObject* pyobj = reinterpret_cast(obj_args[i]); - Ort_PyTuple_SetItem_Incref(args.get(), num_control_args + obj_indices[i], pyobj, - "const_args"); + // Non-tensor inputs to call autograd.Function.apply or autograd.Function.backward. + for (size_t i = 0; i < obj_args.size(); ++i) { + PyObject* pyobj = reinterpret_cast(obj_args[i]); + Py_INCREF(pyobj); + args[obj_indices[i]] = pyobj; + } } - - return args; } void Invoke( const std::string& func_name, - PyObject* runner, - PyObject* callback, + const CustomFunctionRunnerType& runner, + void* callback, const std::vector& requires_grads, const std::vector>& tensor_args, const std::vector& tensor_indices, @@ -273,30 +189,40 @@ void Invoke( const bool is_training_mode, const std::vector& inplace_map, const std::string& invoke_id, + bool safe_run_mode_enabled, void** diff_ctx, std::vector& returned_ortvalues) { const auto len = tensor_args.size() + obj_args.size(); CheckArguments(len, requires_grads, tensor_args, tensor_indices, obj_args, obj_indices); - RefCountTracker::GetInstance().Reset(); - { - PythonObjectPtr args = CreatePythonCallArguments( - callback, - len, - requires_grads, - tensor_args, - tensor_indices, - obj_args, - obj_indices, - is_training_mode, - inplace_map, - invoke_id, - func_name); - - RefCountTracker::GetInstance().DumpDetails("Before Invoke Python Call"); - InvokeRunner(runner, args.get(), is_training_mode, diff_ctx, returned_ortvalues); + std::vector args; + std::vector tensor_flags; + PrepareCallArguments(tensor_args, tensor_indices, obj_args, obj_indices, args, tensor_flags); + + std::vector results; + + std::vector raii_args; + raii_args.reserve(args.size()); + for (auto& arg : args) { + raii_args.emplace_back(arg, PythonObjectDeleter); + } + + results = runner(func_name.c_str(), + callback, + requires_grads, + tensor_flags, + is_training_mode, + inplace_map, + invoke_id.c_str(), + safe_run_mode_enabled, + args); + + std::vector raii_results; + raii_results.reserve(results.size()); + for (auto& arg : results) { + raii_results.emplace_back(arg, PythonObjectDeleter); } - RefCountTracker::GetInstance().DumpDetails("After Python Call Completed"); + ProcessReturnValues(results, is_training_mode, safe_run_mode_enabled, diff_ctx, returned_ortvalues); } void TorchProxy::Forward( @@ -310,6 +236,7 @@ void TorchProxy::Forward( const bool is_training_mode, const std::vector& inplace_map, const std::string& invoke_id, + bool safe_run_mode_enabled, void** diff_ctx, std::vector& returned_ortvalues) { // Semantically, this lock uniquely takes the ownership of TorchProxy @@ -317,12 +244,12 @@ void TorchProxy::Forward( // can be run at one time. std::lock_guard lock(mutex_); // Python-related calls should happen only if guard is alive. - GilGuard guard; - auto runner = OrtTorchFunctionPool::GetInstance().GetForwardRunner(); + CustomFunctionRunnerType runner = OrtTorchFunctionPool::GetInstance().GetForwardRunner(); + Invoke( func_name, runner, - reinterpret_cast(callback), + callback, requires_grads, tensor_args, tensor_indices, @@ -331,6 +258,7 @@ void TorchProxy::Forward( is_training_mode, inplace_map, invoke_id, + safe_run_mode_enabled, diff_ctx, returned_ortvalues); } @@ -344,30 +272,30 @@ void TorchProxy::Backward( const std::vector& obj_indices, const std::vector& inplace_map, const std::string& invoke_id, + bool safe_run_mode_enabled, std::vector& returned_ortvalues) { // Semantically, this lock uniquely takes the ownership of TorchProxy // so that there will be only one of TorchProxy::Forward TorchProxy::Backward // can be run at one time. std::lock_guard lock(mutex_); - // Python-related calls should happen only if guard is alive. - GilGuard guard; - auto runner = OrtTorchFunctionPool::GetInstance().GetBackwardRunner(); - + CustomFunctionRunnerType runner = OrtTorchFunctionPool::GetInstance().GetBackwardRunner(); // Pass all zero since backward inputs don't require gradients. const auto all_input_count = tensor_args.size() + obj_args.size(); const std::vector requires_grads(all_input_count, 0); + Invoke( func_name, runner, - reinterpret_cast(callback), + callback, requires_grads, tensor_args, tensor_indices, obj_args, obj_indices, - true /* is_training_mode */, + false /* is_training_mode */, inplace_map, invoke_id, + safe_run_mode_enabled, nullptr /* context to store */, returned_ortvalues); } @@ -377,6 +305,9 @@ void TorchProxy::RunInputAliasFunction( const std::string& node_proto_str, std::vector& fw_output_to_input_alias_map, std::vector& bw_output_to_input_alias_map) { + // Python-related calls should happen only if guard is alive. + GilGuard guard; + PyObject* input_alias_func = reinterpret_cast(input_alias_function); ORT_ENFORCE(PyCallable_Check(input_alias_func), "input_alias_func is not callable."); diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.h b/orttraining/orttraining/core/framework/torch/torch_proxy.h index 1d5cc1dd69095..450a5048aea44 100644 --- a/orttraining/orttraining/core/framework/torch/torch_proxy.h +++ b/orttraining/orttraining/core/framework/torch/torch_proxy.h @@ -50,6 +50,7 @@ class TorchProxy { const bool is_training_mode, const std::vector& inplace_map, const std::string& invoke_id, + bool safe_run_mode_enabled, void** diff_ctx, std::vector& returned_ortvalues); @@ -62,7 +63,8 @@ class TorchProxy { const std::vector& obj_indices, const std::vector& inplace_map, const std::string& invoke_id, - std::vector& return_args); + bool safe_run_mode_enabled, + std::vector& returned_ortvalues); /** * @brief Run given function to get output to input reuse map. diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 755a8e49d9d12..e675b55c8af8f 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1804,6 +1804,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetPythonOpGradient) { ORT_ENFORCE(utils::HasString(src_attrs.at("func_name"))); attrs.push_back(MakeAttribute("func_name", src_attrs.at("func_name").s())); attrs.push_back(MakeAttribute("output_convention", src_attrs.at("input_convention").s())); + attrs.push_back(MakeAttribute("safe_run_mode", src_attrs.at("safe_run_mode").i())); // input_tensor_types[i] store the type of autograd.Function.apply's ith output. // Note that PythonOpGrad's 0-th input is the Python context generated by PythonOp. diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 8d3f76be20c65..a62ca611b8e7e 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -3938,6 +3938,15 @@ Return true if all elements are true and false otherwise. "comment", "comment only for debugging purposes.", AttributeProto::STRING, false) + .Attr( + "safe_run_mode", + "Indicate if the function is running in safe mode or not. " + "Safe mode support common use cases of PyTorch ctx for example, save for backward, mark as dirty," + "or materialize gradient. In this mode, inplace operation is detected on the fly. " + "Unsafe mode is used to run the function faster not considering the above ctx usage." + "Additional requirement running in this mode: provide correct input alias map.", + AttributeProto::INT, + static_cast(1)) .TypeConstraint( "T", OpSchema::all_tensor_types(), @@ -4096,6 +4105,15 @@ Return true if all elements are true and false otherwise. "comment only for debugging purposes.", AttributeProto::STRING, false) + .Attr( + "safe_run_mode", + "Indicate if the function is running in safe mode or not. " + "Safe mode support common use cases of PyTorch ctx for example, save for backward, mark as dirty," + "or materialize gradient. In this mode, inplace operation is detected on the fly. " + "Unsafe mode is used to run the function faster not considering the above ctx usage." + "Additional requirement running in this mode: provide correct input alias map.", + AttributeProto::INT, + static_cast(1)) .TypeConstraint( "T", OpSchema::all_tensor_types(), diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index a5f46d88e4e8b..0c2bfa19e1671 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -316,16 +316,18 @@ void addObjectMethodsForTraining(py::module& m) { m.def("register_forward_runner", [](py::object obj) -> void { #ifdef ENABLE_TRAINING_TORCH_INTEROP + size_t function_address = py::cast(obj); auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance(); - pool.RegisterForwardRunner(obj.ptr()); + pool.RegisterForwardRunner(function_address); #else ORT_UNUSED_PARAMETER(obj); #endif }); m.def("register_backward_runner", [](py::object obj) -> void { #ifdef ENABLE_TRAINING_TORCH_INTEROP + size_t function_address = py::cast(obj); auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance(); - pool.RegisterBackwardRunner(obj.ptr()); + pool.RegisterBackwardRunner(function_address); #else ORT_UNUSED_PARAMETER(obj); #endif diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py index fece1be20c96a..d9d1c467a10c1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py @@ -52,10 +52,9 @@ def enable_custom_autograd_support(to_enable=True): if to_enable is True and custom_autograd_function_enabler.state is False: if custom_autograd_function_enabler.already_enabled is False: # Initialize static objects needed to run custom autograd.Function's. - from ._custom_autograd_function_runner import call_python_backward_function, call_python_forward_function - register_forward_runner(call_python_forward_function) - register_backward_runner(call_python_backward_function) + register_forward_runner(torch_interop_utils.get_custom_function_forward_runner()) + register_backward_runner(torch_interop_utils.get_custom_function_backward_runner()) # Unregister all python functions automatically upon normal interpreter termination. atexit.register(unregister_python_functions) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 8efbe16d7d61d..f10416a9bb0f4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -71,10 +71,10 @@ def symbolic_wrapper(fn): def register_custom_function_schema_supplementary(kclass: torch.autograd.Function) -> None: - """Register a shape inference function for a torch.autograd.Function if there is staticmethod - "infer_shape" defined. + """Register schema summplementaries, for example custom shape inference function and + alias input function for a custom autograd.Function. - The signature of the shape inference function should be: + 1. The signature of the shape inference function should be: @staticmethod def infer_shape( node: onnx.NodeProto, @@ -91,7 +91,7 @@ def infer_shape( Be noted: we only pass in tensor inputs, and return tensor outputs, non-tensor inputs/outputs are ignored. - The signature of the alias input function should be: + 2. The signature of the alias input function should be: @staticmethod def alias_input(node_proto_str: str) -> Tuple[List[int], List[int]]: fw_alias_map = [1, -1, -1] diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py deleted file mode 100644 index dd32e2aced561..0000000000000 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ /dev/null @@ -1,707 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - - -import sys -import warnings -from collections import OrderedDict -from typing import Callable, Dict, List, Optional, Tuple, Union - -import torch -from torch.utils.dlpack import from_dlpack, to_dlpack - -from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils - -from ._fallback import ORTModuleFallbackException, ORTModuleIOError, _FallbackManager, wrap_exception # noqa: F401 -from ._utils import get_rank - - -def _log_warning(message: str): - """Configure the logger for PythonOp runner according to following rules. - 1. If multiple processes are used, the rank will be appended - to the logger name. - 2. The logger will be disabled for non-zero ranks. - """ - if get_rank() == 0: - warnings.warn(f"[rank-{get_rank()}] {message}") - - -class CustomFuncOpKernelInfo: - """Store the kernel-specific information retrieved with the first-time run.""" - - def __init__(self, kernel_invoke_id: str): - # kernel_invoke_id is a string contains session thread id, op kernel creation time stamp in ms, a random int, - # and address of op_kernel pointer. This can guarantee the uniqueness of the key in case of multiple - # instances of a same named PythonOp/PythonOpGrad in one session, or multiple sessions. - self.kernel_invoke_id = kernel_invoke_id - - # For the tensors generated from ORT backend, there is special handling here: - # 1. For the first time run for the kernel (the uniqueness of the kernel is defined by kernel_invoke_id), - # all such tensors will be cloned in case they are saved in context (but ORT backend is not aware of the - # reference, may release the content of the tensor before it is needed in backward). Once - # `autograd.Function.apply` completes, by checking the existence of the tensor in the saved_tensors, - # `_GlobalOpKernelInfoMap` is updated to save the input indices that are saved in context. - # 2. For the subsequent runs, if the input index is in `tensor_input_indices_to_save_in_ctx`, the tensor - # will be cloned before fed into `autograd.Function.apply` as input. - self.tensor_input_indices_to_save_in_ctx: Optional[List[int]] = None - - # To align with PyTorch `ctx.set_materialize_grads(False|True)`` - # materialize_grads_config is a map from output index to (device, dtype, shape) of the output tensor, used - # for materializing the gradient of the output tensor in backward. - self.materialize_grads: bool = False - self.materialize_grads_config: Optional[Dict[int, Tuple[torch.device, torch.dtype, torch.shape]]] = None - - # For the tensors generated from ORT backend, there is special handling here: - # 1. For the first time run for the kernel (the uniqueness of the kernel is defined by kernel_invoke_id), - # all such tensors will be cloned (with gradient) in case they are marked as dirty (if not cloned, but marked - # as dirty, PyTorch will complain the tensor is a leaf, should not be used for inplace update). Once - # `autograd.Function.apply` completes, by checking the existence of the tensor in the dirty_tensors, - # `_GlobalOpKernelInfoMap` is updated to save the input indices that are marked as dirty. - # 2. For the subsequent runs, if the input index is in `tensor_input_indices_for_mark_dirty`, the tensor - # will be cloned (with gradient) before fed into `autograd.Function.apply` as input. - self.tensor_input_indices_for_mark_dirty: Optional[List[int]] = None - - # A list of output indices that needs to be clone before returned, due to inplace update analysis. - self.output_indices_for_clone: Optional[List[int]] = None - - -# Store the kernel-specific information that cannot be retrieved and saved by PyTorch exporter. -# For the infos that can only be retrieved with real run, we try to collect them in the first time run. -# key: kernel_invoke_id, value: CustomFuncOpKernelInfo. -_GlobalOpKernelInfoMap: Dict[str, CustomFuncOpKernelInfo] = {} - - -def _process_inplace_outputs( - kernel_info: CustomFuncOpKernelInfo, - func_name: str, - input_tensors_of_kernel_run: Dict[int, Union[torch.Tensor, None]], - all_outputs_of_kernel_run: List[Union[torch.Tensor, any]], - all_outputs_to_tensor_inputs_reuse_map: List[int], - raw_input_tensors_used_inplace: Dict[int, Union[torch.Tensor, None]], - is_backward=False, -): - """Special handling for in-place reusing in forward or backward. - - Args: - kernel_info: kernel-specific information. - func_name: name of the autograd.Function. - input_tensors_of_kernel_run: all tensor input tensors used to run the autograd.Function forward/backward. - all_outputs_of_kernel_run: all outputs of the autograd.Function forward/backward. - all_outputs_to_tensor_inputs_reuse_map: a list of the same length of kernel outputs, each element representing - which input index it is reusing. If there is no reuse, the value is -1. - raw_input_tensors_used_inplace: a dict of raw input tensors marked as inplace in - `all_outputs_to_tensor_inputs_reuse_map`, the key is the tensor input index, value is the raw input tensor. - is_backward: indicates if this is backward or forward. - - Procedures: - 1. Detect all outputs to tensor inputs reuse mapping. - 2. Validate the detected inplace_map with the registered inplace_map in ORT. For the output tensor, - 2.0 If the reuse mapping value is the same in both inplace_map and detected inplace_map: - 2.0.1 Most likely, we don't need to do anything, except 2.0.2. - 2.0.2 Conditions: - > During forward run, - > The output tensor is reusing one of input tensors, - > The raw input tensor to be reused given from ORT is copied to run the forward kernels - (for two possible reasons: - a. the first time forward run, all inputs will be copied to detect - `tensor_input_indices_to_save_in_ctx`; - b. for every iteration, the input needs to be cloned because it is in - `tensor_input_indices_to_save_in_ctx`). - - In this case, need to copy the output tensor back to the raw input tensor, to make it compatible with - ORT statistically planned buffer reuse. - 2.1 If the reuse mapping value is NOT equal in both inplace_map and detected inplace_map: - 2.1.1 If the detected reuse input index is -1 (e.g. there is NO buffer reuse for this output), - while user specified reuse input index is NOT -1 (ORT planned the reuse), we raise an error. - 2.1.2 If the detected reuse input index is NOT -1 (e.g. there is buffer reuse for this output), - while user specified reuse input index is -1 (ORT did not plan the reuse). We will try to clone the - output tensor before returning to ORT, to align with ORT's NO Buffer reuse plan; otherwise, once the - input buffer is released by ORT memory planner, the output tensor read/write will be corrupted. - Raise a warning to notify users to update inplace_map explicitly for performance consideration. - 2.1.3 Other cases (for example user gives a wrong mapping index compared with detected ones), raise an - error. - 3. Do copies for 2.1.2 cases. - 4. Do copies for 2.0.2 cases. - """ - - log_prefix = f"{func_name}->{'Backward' if is_backward else 'Forward'}: " - input_tensor_address_list = [ - t.data_ptr() if isinstance(t, torch.Tensor) else -1 for t in input_tensors_of_kernel_run.values() - ] - if is_backward: - input_tensor_address_list = [-1, *input_tensor_address_list] # skip the context input - - is_first_time_init = kernel_info.output_indices_for_clone is None - # If this is the first time run, collect runtime tensor reuse mapping. - if is_first_time_init: - # Procedure 1: Detect all outputs to tensor inputs reuse mapping, according to `all_outputs_of_kernel_run` and - # `input_tensors_of_kernel_run`. - assert len(all_outputs_to_tensor_inputs_reuse_map) == len(all_outputs_of_kernel_run), ( - f"{log_prefix}all_outputs_to_tensor_inputs_reuse_map and kernel run outputs should have the same length." - f"all_outputs_to_tensor_inputs_reuse_map: {all_outputs_to_tensor_inputs_reuse_map}, " - f"kernel run outputs: {all_outputs_of_kernel_run}" - ) - - # Detect all outputs to tensor inputs reuse mapping. - detected_reuse_map = [-1] * (len(all_outputs_of_kernel_run)) - for output_index, arg in enumerate(all_outputs_of_kernel_run): - if not isinstance(arg, torch.Tensor): - continue - if arg.data_ptr() in input_tensor_address_list: - input_index = input_tensor_address_list.index(arg.data_ptr()) - detected_reuse_map[output_index] = input_index - - # Procedure 2: Validate the detected inplace_map with the registered inplace_map in ORT. - output_indices_for_clone = ( - [] - ) # collect the output indices that need to be cloned before returned in case 2.1.2. - for output_index, (detected_inplace_index, inplace_index) in enumerate( - zip(detected_reuse_map, all_outputs_to_tensor_inputs_reuse_map) - ): - if inplace_index == detected_inplace_index: - continue - - if ( - inplace_index in raw_input_tensors_used_inplace - and raw_input_tensors_used_inplace[inplace_index] is None - ): - # Use specified inplace input index, but the input tensor is None, which means the input is not - # a tensor, so we don't do further checks. - continue - - # If users register inplace_map (alloc planner will do buffer reuse), - # but detected inplace_map indicates it is NO inplace reusing, we raise an error. - if inplace_index != -1 and detected_inplace_index == -1: - raise RuntimeError( - f"{log_prefix}Fatal: " - f"ONNX Op attribute 'tensor_reuse_map' indicates {output_index}-th output is reusing input " - f"{inplace_index}, but detected inplace_map indicates it is NOT reusing any input. " - "Please update inplace_map explicitly to make it consistent " - f"to avoid undefined behavior due to ORT's memory reuse plan. " - f"inplace_map: {all_outputs_to_tensor_inputs_reuse_map}, " - f"detected inplace_map: {detected_reuse_map}" - ) - - if inplace_index == -1 and detected_inplace_index != -1: - output_indices_for_clone.append(output_index) - continue - - raise RuntimeError( - f"{log_prefix}Fatal: " - f"ONNX Op attribute 'inplace_map' indicates {inplace_index}-th output is reusing " - f"input index {detected_inplace_index}, but detected inplace_map indicates it is reusing " - f"input index {inplace_index}. Please update inplace_map explicitly to avoid undefined behavior " - f"due to memory reuse. inplace_map: {all_outputs_to_tensor_inputs_reuse_map}, " - f"detected inplace_map: {detected_reuse_map}" - ) - - kernel_info.output_indices_for_clone = output_indices_for_clone - - assert kernel_info.output_indices_for_clone is not None - - # Procedure 3: Do copies for 2.1.2 cases. - for output_index in kernel_info.output_indices_for_clone: - _log_warning( - f"{log_prefix}ONNX Op attribute " - f"'tensor_reuse_map' doesn't indicate {output_index}-th output is reusing any input, " - f"but detected inplace_map indicates it is reusing some input index. " - "A clone will be done before returning to ORT, to align with ORT's NO Buffer reuse plan. " - "Please update inplace_map explicitly to avoid such a copy." - ) - all_outputs_of_kernel_run[output_index] = all_outputs_of_kernel_run[output_index].detach().clone() - - # Procedure 4: Do copies for 2.0.2 cases. - if is_backward is False and ( - is_first_time_init - or kernel_info.tensor_input_indices_to_save_in_ctx - or kernel_info.tensor_input_indices_for_mark_dirty - ): - for raw_tensor_input_index, raw_input_tensor in raw_input_tensors_used_inplace.items(): - # raw_input_tensor can be None for backward run, but backward won't go here. - if not isinstance(raw_input_tensor, torch.Tensor): - continue - - # We did not do the check with tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty - # because even for those tensor indices not in - # tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty, we still need to do the - # copy for the first-time run. - if raw_input_tensor.data_ptr() == input_tensor_address_list[raw_tensor_input_index]: - # If the raw input tensor is not copied, we don't need this handling. - continue - - copied = False # for each tensor, we don't do the copy once. - output_indices_reusing_current_raw_input = [ - output_index - for output_index, input_index in enumerate(all_outputs_to_tensor_inputs_reuse_map) - if input_index == raw_tensor_input_index - ] - output_tensor_address = all_outputs_of_kernel_run[output_indices_reusing_current_raw_input[0]].data_ptr() - for output_index in output_indices_reusing_current_raw_input: - assert ( - output_tensor_address == all_outputs_of_kernel_run[output_index].data_ptr() - ), "Outputs reusing the same input tensor should have the same address." - - if not copied: - # Only need a copy once. - # Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False. - raw_input_tensor.requires_grad = False - raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index]) - _log_warning( - f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}. " - f"{'Provide output to input reuse mapping to avoid the copy overhead.' if not is_first_time_init else ''}" - ) - copied = True - - all_outputs_of_kernel_run[output_index] = raw_input_tensor - - -def _get_context(forward_tensor_outputs: List[torch.Tensor]) -> Tuple[any, Optional[torch.Tensor]]: - """Search for context among all outputs. - - Note 1: All forward outputs of torch.autograd.Function shared the same gradient function pointer, - so here we just get the first tensor having grad_fn attribute. - (https://github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/custom_function.cpp#L267) - - Note 2: Context can be None because NOT all torch.autograd.Function's are differentiable. The function - https://github.com/PyTorch/PyTorch/blob/d701357d921ef167d42c125e65b6f7da6be3ad0f/torch/csrc/autograd/custom_function.cpp#L209? - means if all output of the forward function is not differentiable, then grad_fn will be None (not be set). - - For example, - class Bar(torch.autograd.Function): - # A non-differentiable autograd Function whose forward output - # doesn't have grad_fn attribute. - @staticmethod - def forward(ctx, x): - y = torch.ones_like(x) - return y - - @staticmethod - def backward(ctx, dy): - dx = torch.zeros_like(dy) - return dx - - Returns: - ctx: context of the autograd.Function. - tensor: a tensor that owns the context. - - """ - ctx = None - first_tensor_output = None - for arg in forward_tensor_outputs: - if not isinstance(arg, torch.Tensor) or not hasattr(arg, "grad_fn"): - continue - - if arg.grad_fn is None: - # For the following case, it is possible grad_fn exists, but its value is None, - # so we need to continue to search for the first tensor having a non-None grad_fn. - # - # >>> w = torch.randn(5, 6) - # >>> hasattr(w, "grad_fn") - # True - # >>> w.grad_fn is None - # True - # >>> w, ... = CustomFunc.apply(w) # where CustomFunc forward just return w and other tensors. - # - # Then hasattr(w, "grad_fn") is True, but w.grad_fn is None. - continue - # Use the first context we see because all of arg's share the same one. - ctx = arg.grad_fn - first_tensor_output = arg - break - if first_tensor_output is not None: - assert ctx is not None, "ctx should not be None if first_tensor_output is not None." - return (ctx, first_tensor_output) - - -def _finalize_training_mode_forward( - kernel_invoke_id: str, - func_name: str, - input_tensors_used_for_fw_run: Dict[int, torch.Tensor], - forward_output_tensors: List[Union[torch.Tensor, None]], -): - """Complete the epilogue of forward runner for training mode. - - Args: - kernel_invoke_id: kernel_invoke_id of the PythonOp kernel unique id. - input_tensors_from_ort: input tensors generated from ORT backend. - forward_output_tensors: output tensors of the autograd.Function. - - Things to do: - 1. Try to get context from forward output tensors. - 2. Remove the gradient functions between the current autograd.Function and its input's gradient function, because - in ORT we don't depend on PyTorch's autograd engine. - 3. Register the current autograd.Function's gradient function into our PyNodeSharedPointerPool. - 4. Save kernel-specific information into _GlobalOpKernelInfoMap in the first-time kernel run. - """ - - ctx, tensor_owning_ctx = _get_context(forward_output_tensors) - - kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id] - - # ctx being None in training mode means the forward function is not differentiable, so backward is not needed. - if ctx is None: - # If this is the first time run, collect kernel-specific information. - if kernel_info.tensor_input_indices_to_save_in_ctx is None: - kernel_info.tensor_input_indices_to_save_in_ctx = [] - - if kernel_info.tensor_input_indices_for_mark_dirty is None: - kernel_info.tensor_input_indices_for_mark_dirty = [] - - return None - - # Filter out the None in the saved_tensors. - saved_tensors = [t for t in ctx.saved_tensors if t is not None] - - ctx.fw_kernel_invoke_id = kernel_invoke_id - - # If this is the first time run, collect kernel-specific information. - if kernel_info.tensor_input_indices_to_save_in_ctx is None: - kernel_info.tensor_input_indices_to_save_in_ctx = [] - if len(saved_tensors): - # Check tensors generated by ORT are in the saved_tensors or not. - # If yes, save the input index of the tensor in the _GlobalOpKernelInfoMap. - kernel_info.tensor_input_indices_to_save_in_ctx = [ - tensor_input_index - for tensor_input_index, tensor in input_tensors_used_for_fw_run.items() - if any(tensor is saved_tensor for saved_tensor in saved_tensors) - ] - _log_warning( - f"{func_name}: Add input index to _GlobalOpKernelInfoMap, to avoid extra copy in every iteration." - ) - kernel_info.materialize_grads = torch_interop_utils.get_materialize_grads(tensor_owning_ctx) - kernel_info.materialize_grads_config = OrderedDict() - if kernel_info.materialize_grads: - for output_index, tensor in enumerate(forward_output_tensors): - if isinstance(tensor, torch.Tensor): - kernel_info.materialize_grads_config[output_index] = ( - tensor.device, - tensor.dtype, - tensor.shape, - ) - - if kernel_info.tensor_input_indices_for_mark_dirty is None: - kernel_info.tensor_input_indices_for_mark_dirty = [] - # Check tensors generated by ORT are marked as dirty(for inplace update) or not. - # If yes, save the input index of the tensor in the _GlobalOpKernelInfoMap. - are_tensors_marked_as_dirty = torch_interop_utils.are_tensors_marked_as_dirty( - tensor_owning_ctx, [t for t in input_tensors_used_for_fw_run.values()] - ) - kernel_info.tensor_input_indices_for_mark_dirty = [ - tensor_input_index - for is_dirty, (tensor_input_index, tensor) in zip( - are_tensors_marked_as_dirty, input_tensors_used_for_fw_run.items() - ) - if is_dirty is True - ] - _log_warning(f"{func_name}: Add input index to _GlobalOpKernelInfoMap, to support leaf node do inplace update.") - - # FORWARD BACKWARD FUNCTION CONNECTIONS - # input_1 (leaf, constructed by from_dlpack) <----reference---- AccumulateGrad gradient function - # ↓ ↑ - # autograd.Function apply() ------------> autograd.Function backward() - # ↓ | ↑ - # output_1, output_2 --- shared_ptr --- ↑ - # ↓ previous gradient function - - # We remove the edges starting between current autograd.Function's gradient function and - # it's input's gradient function (e.g. AccumulateGrad gradient function), then - # AccumulateGrad gradient function will be destroyed, releasing the reference to input_1 - # (https://github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/functions/accumulate_grad.cpp#L21). - # The next edges are stored in Node, with which we can get next gradient function. - # https://github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L527 - torch_interop_utils.clear_grad_fns_for_next_edges(tensor_owning_ctx, saved_tensors) - - # This is mainly to hold grad_fn references by registering it into our PyNodeSharedPointerPool. - torch_interop_utils.register_grad_fn_and_remove_from_autograd(id(ctx), tensor_owning_ctx) - - return ctx - - -def call_python_forward_function( - forward_function: Callable, - requires_grad_flags: List[bool], - tensor_type_flags: List[int], - is_training_mode: bool, - inplace_map: List[int], - kernel_invoke_id: str, - func_name: Union[bytes, str], - *args, -): - """ - This function bridges the gap between ORT variables and autograd.Function.apply. - It conducts basic casting from ORT to PyTorch (before calling "forward_function") and from PyTorch to ORT - (after calling "forward_function"). It also enable autograd in PyTorch. It formats returned outputs, - for example, dropping None's from forward_function's output list. - - The major difference between call_python_forward_function and call_python_backward_function is that - in the forward one, we have extra code to process autograd context from PyTorch. - - Args: - forward_function: pointer to autograd.Function.apply (e.g., MyReLU.apply). - requires_grad_flags: requires_grad_flags[i] indicates if the i-th arg needs gradient. - tensor_type_flags: tensor_type_flags[i] indicates the type of the i-th arg, 0 - non-tensor, 1 - tensor. - is_training_mode: indicates if this model is running under training mode. - inplace_map: a list of the same length of kernel outputs, each element represents which input index - it is reusing. If there is no reuse, the value is -1. - args: inputs to "backward_function". - """ - - try: - func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name - # If this is the first time run, collect runtime tensor reuse mapping. - is_first_time_run = kernel_invoke_id not in _GlobalOpKernelInfoMap - if is_first_time_run: - kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id) - _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info - - kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id] - - tensor_input_indices_to_save_in_ctx = kernel_info.tensor_input_indices_to_save_in_ctx - tensor_input_indices_for_mark_dirty = kernel_info.tensor_input_indices_for_mark_dirty - - # Collect the tensor address for all inputs used for run forward, used for reuse detection. - tensor_input_index = 0 - # If the input is reused, we need to save the raw input tensor for special handling. - raw_input_tensors_used_inplace = OrderedDict() # Orders matter here. - input_tensors_used_for_fw_run = OrderedDict() # Orders matter here. - - wrapped_args = [] - for _, (grad_flag, tensor_flag, arg) in enumerate(zip(requires_grad_flags, tensor_type_flags, args)): - if tensor_flag: - # Assume it's a DLPack tensor and convert it to PyTorch tensor. - wrapped_arg = from_dlpack(arg) - - if tensor_input_index in inplace_map: - raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg - - # Only requires gradient when running under training mode - # and the associated tensor has grad_flag=True (i.e., - # "requires_grad=True" in the original PyTorch script). - wrapped_arg.requires_grad = is_training_mode and grad_flag - - # Note1: - # If it's first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, we do the - # copy for all tensors. Otherwise, we only copy the tensors whose indices are in - # tensor_input_indices_to_save_in_ctx. - # Note2: - # For inference mode, we don't need to do the copy because ctx will be None, - # so nothing will be saved for ctx. - # Note3: - # To fix this issue: - # "a leaf Variable that requires grad has been used in an in-place operation." - # If it's first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the - # copy for all tensors to generate grad for it. Otherwise, we only clone (to generate grad) for - # the tensors whose indices are in tensor_input_indices_for_mark_dirty. - if is_training_mode: - if is_first_time_run: - with torch.set_grad_enabled(True): - wrapped_arg = wrapped_arg.clone() - else: - is_input_index_saved_in_ctx = ( - tensor_input_indices_to_save_in_ctx is None - or tensor_input_index in tensor_input_indices_to_save_in_ctx - ) - is_input_index_marked_dirty = ( - tensor_input_indices_for_mark_dirty is None - or tensor_input_index in tensor_input_indices_for_mark_dirty - ) - if is_input_index_saved_in_ctx or is_input_index_marked_dirty: - # when with grad, the leaf tensor after clone will not be leaf. - with torch.set_grad_enabled(is_input_index_marked_dirty): - wrapped_arg = wrapped_arg.clone() - wrapped_arg.requires_grad = is_training_mode and grad_flag - - wrapped_args.append(wrapped_arg) - input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg - - tensor_input_index += 1 - else: - # Use non-tensor as is. It's a PyObject*. - wrapped_args.append(arg) - - with torch.set_grad_enabled(is_training_mode): - # Run autograd.Function.apply(...). - # TODO(pengwa): looks like we are assuming all outputs will be either Tensor or None. - # We should revisit if it is possible to support other types of output, for example int, or, etc. - # But that might also require some work in backend. - result = forward_function(*wrapped_args) - - results = [] - if isinstance(result, torch.Tensor): - results = [result] - elif isinstance(result, (tuple, list)): - results = [r for r in result] - else: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule does not support the following model output type {type(result)}."), - ) - - ctx = None - if is_training_mode: - ctx = _finalize_training_mode_forward( - kernel_invoke_id, func_name, input_tensors_used_for_fw_run, results - ) - - final_rets = [ctx] - final_rets.extend(results) - - _process_inplace_outputs( - kernel_info, - func_name, - input_tensors_used_for_fw_run, - final_rets, - inplace_map, - raw_input_tensors_used_inplace, - ) - - dlpacks = [final_rets[0]] - dlpacks.extend(list(to_dlpack(value) if value is not None else None for value in final_rets[1:])) - - # Inside the returned list, the first element is context and the rest - # are DLPack tensors. - return tuple(dlpacks) - except Exception as e: - # Flush buffers. Otherwise, calling this from C++ may lose them. - print("Exception happens when running ", forward_function) - sys.stdout.flush() - sys.stderr.flush() - raise wrap_exception(ORTModuleFallbackException, e) # noqa: B904 - - -def call_python_backward_function( - backward_function: Callable, - requires_grad_flags: List[bool], - tensor_type_flags: List[int], - is_training_mode: bool, - inplace_map: List[int], - kernel_invoke_id: str, - func_name: Union[bytes, str], - *args, -): - """ - This function bridges the gap between ORT variables and autograd.Function.backward. - It conducts basic casting from ORT to PyTorch (before calling "backward_function") - and from PyTorch to ORT (after calling "backward_function"). It formats returned - outputs, example, dropping None's from backward_function's output list. - - Args: - backward_function: pointer to autograd.Function.backward (e.g., MyReLU.backward). - requires_grad_flags: requires_grad_flags[i] indicates if the i-th arg needs gradient. - tensor_type_flags: tensor_type_flags[i] indicates the type of the i-th arg. - is_training_mode: indicates if this model is running under training mode. - inplace_map: a list of the same length of kernel outputs, each element represents which input index - it is reusing. If there is no reuse, the value is -1. - args: inputs to "backward_function". - """ - func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name - with torch.no_grad(): - - def wrap_all_outputs(result): - if isinstance(result, torch.Tensor): - return [to_dlpack(result)] - elif isinstance(result, (tuple, list)): - return [to_dlpack(value) if value is not None else None for value in result] - else: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule does not support the following model output type {type(result)}."), - ) - - try: - # If this is the first time run, collect runtime tensor reuse mapping. - if kernel_invoke_id not in _GlobalOpKernelInfoMap: - kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id) - _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info - - kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id] - - # Backward inputs should not require gradients. - assert all(grad_flag == 0 for grad_flag in requires_grad_flags) - - # Prepare inputs for calling Python function. - ctx = args[0] - fw_kernel_invoke_id = ctx.fw_kernel_invoke_id - wrapped_args = [] - - # Collect the tensor address for all inputs used for run backward, used for reuse detection. - tensor_input_index = 1 # skip the context input - # If input is reused, we need to save the raw input tensor for special handling. - raw_input_tensors_used_inplace = OrderedDict() # Orders matter here. - input_tensors_used_for_bw_run = OrderedDict() # Orders matter here. - for grad_input_index, (grad_flag, tensor_flag, arg) in enumerate( - zip(requires_grad_flags, tensor_type_flags, args) - ): - # If an input is a tensor, it is possible we get a None also when it is optional as grad input. - if tensor_flag: - if arg is None: - if _GlobalOpKernelInfoMap[fw_kernel_invoke_id].materialize_grads: - config = _GlobalOpKernelInfoMap[fw_kernel_invoke_id].materialize_grads_config - # ignore the first input, which is the ctx. - device, dtype, shape = config[grad_input_index - 1] - wrapped_arg = torch.zeros(shape, device=device, dtype=dtype) - else: - wrapped_arg = arg - - if grad_input_index in inplace_map: - raw_input_tensors_used_inplace[tensor_input_index] = arg - - else: - # Assume it's a DLPack tensor# and convert it to PyTorch tensor. - wrapped_arg = from_dlpack(arg) - - if grad_input_index in inplace_map: - raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg - - # This may include None values. - input_tensors_used_for_bw_run[tensor_input_index] = wrapped_arg - - if wrapped_arg is not None: - # Only requires gradient when running under training mode - # and the associated tensor has grad_flag=True (i.e., - # "requires_grad=True" in the original PyTorch script). - wrapped_arg.requires_grad = is_training_mode and grad_flag - - wrapped_args.append(wrapped_arg) - tensor_input_index += 1 - else: - # Use non-tensor as is. It's a PyObject*. - wrapped_args.append(arg) - - # Call Python function. - result = backward_function(*wrapped_args) - - # Extract results as DLPack tensor list. - if isinstance(result, torch.Tensor): - result = [result] - elif isinstance(result, (tuple, list)): - result = list(result) - else: - raise wrap_exception( - ORTModuleIOError, - TypeError(f"ORTModule does not support the following model output type {type(result)}."), - ) - - _process_inplace_outputs( - kernel_info, - func_name, - input_tensors_used_for_bw_run, - result, - inplace_map, - raw_input_tensors_used_inplace, - is_backward=True, - ) - - wrapped_returned_args = wrap_all_outputs(result) - - torch_interop_utils.unregister_grad_fn(id(ctx)) - - return tuple(wrapped_returned_args) - except Exception as e: - # Flush buffers. Otherwise, calling this from C++ may lose them. - print("Exception happens when running ", backward_function) - sys.stdout.flush() - sys.stderr.flush() - raise wrap_exception(ORTModuleFallbackException, e) # noqa: B904 diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index d076ecacd6ba5..ff110c431d300 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -24,6 +24,10 @@ STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE = TensorProto.FLOAT STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE = [1] +DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME = "deepspeed.runtime.zero.parameter_offload.PreBackwardFunction" +DEEPSPEED_POST_BACKWARD_FUNCTION_NAME = "deepspeed.runtime.zero.parameter_offload.PostBackwardFunction" +DEEPSPEED_LINEAR_FUNCTION_NAME = "deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3" + def post_processing_enable_zero_stage3_compat( exported_model: ModelProto, @@ -74,7 +78,10 @@ def _get_func_name(node: NodeProto) -> Optional[str]: STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, ) - from onnxruntime.training.utils.hooks._zero_offload_subscriber import ORTZeROOffloadPreForwardFunction + from onnxruntime.training.utils.hooks._zero_offload_subscriber import ( + ORTZeROOffloadPostForwardFunction, + ORTZeROOffloadPreForwardFunction, + ) pre_forward_function_name = get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction) @@ -111,9 +118,10 @@ def _get_func_name(node: NodeProto) -> Optional[str]: if input_name == graph_input.name: index_offset_on_python_op_input.append(i) - assert ( - len(index_offset_on_python_op_input) == 1 - ), f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input} for node {pre_forward_pythonop_node.name}, input {graph_input.name}, {pre_forward_pythonop_node.input}" + assert len(index_offset_on_python_op_input) == 1, ( + f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input} for " + f"node {pre_forward_pythonop_node.name}, input {graph_input.name}, {pre_forward_pythonop_node.input}" + ) reverse_index_among_inputs = index_offset_on_python_op_input[0] - len(pre_forward_pythonop_node.input) @@ -170,6 +178,34 @@ def _get_func_name(node: NodeProto) -> Optional[str]: exported_model.graph.input.insert(offset, new_input) exported_model.graph.node.insert(0, weight_pull_node) + # Update safe_run_mode attribute for PythonOp. + from onnxruntime.training.utils.hooks._subscriber_manager import _IncrementStep + + _allowed_unsafe_run_python_op_names = [ + get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction), + get_fully_qualified_class_name(ORTZeROOffloadPostForwardFunction), + func_full_qual_name, + DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME, + DEEPSPEED_POST_BACKWARD_FUNCTION_NAME, + DEEPSPEED_LINEAR_FUNCTION_NAME, + get_fully_qualified_class_name(_IncrementStep), + ] + + for node in exported_model.graph.node: + if node.op_type == "PythonOp": + func_name = None + safe_run_mode_attr = None + for attr in node.attribute: + if attr.name == "func_name": + func_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s + if attr.name == "safe_run_mode": + safe_run_mode_attr = attr + + if func_name in _allowed_unsafe_run_python_op_names: + if safe_run_mode_attr: + node.attribute.remove(safe_run_mode_attr) + node.attribute.append(helper.make_attribute("safe_run_mode", 0)) + return exported_model @@ -227,12 +263,8 @@ def _simple_pass_through_infer_shape( ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes - register_shape_inference_function( - "deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _simple_pass_through_infer_shape - ) - register_shape_inference_function( - "deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _simple_pass_through_infer_shape - ) + register_shape_inference_function(DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME, _simple_pass_through_infer_shape) + register_shape_inference_function(DEEPSPEED_POST_BACKWARD_FUNCTION_NAME, _simple_pass_through_infer_shape) def _linear_infer_shape( node: NodeProto, @@ -246,7 +278,7 @@ def _linear_infer_shape( output_shape[-1] = shape2[-2] return [output_shape], [tensor_input_dtypes[0]] - register_shape_inference_function("deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3", _linear_infer_shape) + register_shape_inference_function(DEEPSPEED_LINEAR_FUNCTION_NAME, _linear_infer_shape) def _register_alias_input_functions(): @@ -274,8 +306,8 @@ def _alias_input(node_proto_str: str): return fw_alias_map, bw_alias_map - register_input_alias_function("deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _alias_input) - register_input_alias_function("deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _alias_input) + register_input_alias_function(DEEPSPEED_PRE_BACKWARD_FUNCTION_NAME, _alias_input) + register_input_alias_function(DEEPSPEED_POST_BACKWARD_FUNCTION_NAME, _alias_input) def _create_weight_retrieval_pythonop( diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.cc new file mode 100644 index 0000000000000..fa54b4929c784 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ctx_pool.h" +#include + +void register_grad_fn_and_remove_from_autograd(py::object ctx, at::Tensor target) { + uint32_t y = reinterpret_cast(ctx.ptr()); + size_t ctx_address = static_cast(y); + + torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); + PyNodeSharedPointerPool::GetInstance().RegisterGradFuncAndRemoveFromAutoGrad(ctx_address, autograd_meta); +} + +void unregister_grad_fn(py::object ctx) { + uint32_t y = reinterpret_cast(ctx.ptr()); + size_t ctx_address = static_cast(y); + PyNodeSharedPointerPool::GetInstance().UnRegisterGradFunc(ctx_address); +} + +void clear_all_grad_fns() { + PyNodeSharedPointerPool::GetInstance().ClearAll(); +} diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.h b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.h new file mode 100644 index 0000000000000..e7b101d987d7a --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/ctx_pool.h @@ -0,0 +1,96 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +// In PyTorch forward run (e.g. THPFunction_apply), ctx of type THPFunction* (which is also a PyObject*) +// is created (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L673). +// The ctx is used to run user-defined forward function and backward function as the first +// parameter. The same time, a cdata of type std::shared_ptr is created +// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L677), +// cdata is owned by: +// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor owns +// shared_pointer; TensorImpl owns std::unique_ptr; AutogradMeta +// manages grad_/grad_fn_/grad_accumulator_. Among them, grad_fn_ is std::shared_ptr, +// e.g, the so called gradient function.) +// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/variable.h#L194 +// b). the consumer operator of forward run outputs, will let its own PyNode/Node (gradient function) +// owns the grad_fn_ (of type std::shared_ptr) of all inputs that require grad. +// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L263 +// BUT, if we run torch computation within PythonOp, b) is lost. So for some cases, where forward outputs +// are not used and freed before backward function runs, the grad_fn_ (std::shared_ptr) references +// in a) will be released. Without b)'s reference, grad_fn_ release PyNode as reference count reach 0; +// Then when PythonOpGrad runs, segment fault. +// +// So we add b)'s reference in this Pool when forward run returns; dereference from this Pool when backward +// completes, then ~PyNode() is called, which subsequently calls ~THPFunction() destroying ctx. +class PyNodeSharedPointerPool { + public: + static PyNodeSharedPointerPool& GetInstance() { + static PyNodeSharedPointerPool pool; + return pool; + } + + void RegisterGradFuncAndRemoveFromAutoGrad(const size_t& ctx_address, + torch::autograd::AutogradMeta* autograd_meta) { + auto it = grad_fns_.find(ctx_address); + TORCH_CHECK(it == grad_fns_.end(), "should not register grad_fn twice for ctx ", ctx_address); + + // Add new entry if key hasn't been registered. + // After this, the grad_fn_ is removed from torch autograd. + grad_fns_.emplace(ctx_address, std::move(autograd_meta->grad_fn_)); + TORCH_CHECK(autograd_meta->grad_fn_ == nullptr, "fail to remove grad_fn_ from torch autograd for ctx ", + ctx_address); + } + + void UnRegisterGradFunc(const size_t& ctx_address) { + auto it = grad_fns_.find(ctx_address); + TORCH_CHECK(it != grad_fns_.end(), "fail to find grad_fn for ctx ", ctx_address); + + grad_fns_.erase(ctx_address); + } + + void ClearAll() { + grad_fns_.clear(); + } + + private: + PyNodeSharedPointerPool(){}; + ~PyNodeSharedPointerPool(){}; + + PyNodeSharedPointerPool(const PyNodeSharedPointerPool&) = delete; + PyNodeSharedPointerPool& operator=(const PyNodeSharedPointerPool&) = delete; + PyNodeSharedPointerPool(PyNodeSharedPointerPool&&) = delete; + PyNodeSharedPointerPool& operator=(PyNodeSharedPointerPool&&) = delete; + + std::unordered_map> grad_fns_; +}; + +void register_grad_fn_and_remove_from_autograd(py::object ctx, at::Tensor target); + +void unregister_grad_fn(py::object ctx); + +// Supposed to be cleared on python program exit to resolve the following issue: +// When training program exits, PyNodeSharedPointerPool destructor is called, if grad_fns_ is not empty, +// PyNode::release_variables() will be called. +// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L168) +// On The other hand, there is a known issue when acquiring GIL in pybind11 destructors, there will be +// probably a deadlock issue. (https://github.com/pybind/pybind11/issues/1446) +// The resolution here, we remove all maintained states before the program exits. + +// A known existing issue: when forward functions are called repeatedly without corresponding backward calls, +// grad functions keep accumulating without releasing, there might be memory (bound to those gradient functions) leaks. +// Ideally this usually won't happen in real training cases, so it should be fine. + +// We CANNOT explicitly clear grad functions before each forward pass to mitigate the known issue above. +// For example: +// loss1 = forward_run(inputs1) +// loss2 = forward_run(inputs2) +// loss = loss1 + loss2 +// loss.backward() +// If we clear grad functions at the beginning of the second `forward_run`, when `loss.backward()` runs, +// the backward path of `loss1` will fail to run PythonOpGrad ops (if there is any). +void clear_all_grad_fns(); diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc new file mode 100644 index 0000000000000..88e93b26e0e22 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.cc @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ctx_pool.h" +#include "custom_function_shared.h" +#include "custom_function_bw.h" + +#include +#include +#include + +#ifdef NVTX3_ENABLED +#include +#endif + +std::vector custom_function_backward_runner(const char* func_name_char, + void* callback, + const std::vector& requires_grad_flags, + const std::vector& tensor_type_flags, + const bool is_training_mode, + const std::vector& inplace_map, + const char* kernel_invoke_id_char, + const bool safe_run_mode_enabled, + const std::vector& args) { + pybind11::gil_scoped_acquire gil; + + try { + std::string func_name(func_name_char); + std::string kernel_invoke_id(kernel_invoke_id_char); + bool is_backward = true; + std::string log_prefix = func_name + " -> " + (is_backward ? "Backward " : "Forward "); + + at::AutoGradMode enable_grad(false); + auto it = KernelInfoStore::GetInstance().GetKernelInfoMap().find(kernel_invoke_id); + if (it == KernelInfoStore::GetInstance().GetKernelInfoMap().end()) { + KernelInfoStore::GetInstance().GetKernelInfoMap().emplace( + kernel_invoke_id, + CustomFuncOpKernelInfo(kernel_invoke_id, safe_run_mode_enabled)); + } + + CustomFuncOpKernelInfo& kernel_info = KernelInfoStore::GetInstance().GetKernelInfoMap().at(kernel_invoke_id); + + std::unordered_map raw_input_tensors_used_inplace; + std::unordered_map input_tensors_used_for_bw_run; + + int tensor_input_index = 0; + std::vector raii_call_args; + raii_call_args.reserve(args.size()); + py::object ctx = py::reinterpret_borrow(args[0]); + raii_call_args.push_back(ctx); + for (size_t arg_index = 1; arg_index < args.size(); ++arg_index) { + if (tensor_type_flags[arg_index] != 1) { + raii_call_args.push_back(py::reinterpret_borrow(args[arg_index])); + continue; + } + + at::Tensor tensor; + bool is_dlpack = PyCapsule_IsValid(args[arg_index], "dltensor") != 0; + if (is_dlpack) { + tensor = torch::utils::tensor_fromDLPack(args[arg_index]); + } else { + TORCH_CHECK(args[arg_index] == Py_None, "Only None is supported for non-tensor input."); + PyObject* fw_kernel_invoke_id = PyObject_GetAttrString(ctx.ptr(), "fw_kernel_invoke_id"); + std::string fw_kernel_invoke_id_str = + py::cast(py::reinterpret_borrow(fw_kernel_invoke_id)); + CustomFuncOpKernelInfo& fw_kernel_info = + KernelInfoStore::GetInstance().GetKernelInfoMap().at(fw_kernel_invoke_id_str); + if (fw_kernel_info.materialize_grads) { + auto& config = fw_kernel_info.materialize_grads_config.at(arg_index - 1); + tensor = at::zeros(std::get<0>(config), std::get<1>(config)); // shift by 1 to skip context input. + } + } + + if (kernel_info.safe_run_enabled) { + bool is_input_used_inplace = std::find(inplace_map.begin(), inplace_map.end(), arg_index) != + inplace_map.end(); + if (is_input_used_inplace) { + raw_input_tensors_used_inplace[tensor_input_index] = tensor; + } + input_tensors_used_for_bw_run[tensor_input_index] = tensor; + } + + if (tensor.defined()) { + raii_call_args.push_back(py::reinterpret_steal(THPVariable_Wrap(tensor))); + } else { + raii_call_args.push_back(py::none()); + } + + tensor_input_index++; + } + + py::tuple call_args = py::cast(raii_call_args); + PyObject* result_pyobj; + { + at::AutoGradMode enable_grad(false); + result_pyobj = PyObject_CallObject(reinterpret_cast(callback), call_args.ptr()); + } + + if (PyErr_Occurred()) { + PyErr_Print(); + throw std::runtime_error("Python function execution fails with the above information."); + } + + if (!result_pyobj) { + throw std::runtime_error("Get null result"); + } + + py::object ret = py::reinterpret_steal(result_pyobj); + + std::vector all_outputs_of_kernel_run; + if (THPVariable_Check(ret.ptr())) { + all_outputs_of_kernel_run.push_back(ret); + } else { + TORCH_CHECK(PyTuple_Check(ret.ptr()), "Python function must return a tuple."); + all_outputs_of_kernel_run = ret.cast>(); + } + + if (kernel_info.safe_run_enabled) { + if (kernel_info.is_first_run) { + // key: tensor data address; + // value: if the tensor is defined it records the tensor input index, otherwise, -1. + std::unordered_map input_tensor_address_to_tensor_input_index_map; + input_tensor_address_to_tensor_input_index_map.reserve(input_tensors_used_for_bw_run.size()); + for (auto& input : input_tensors_used_for_bw_run) { + if (input.second.defined()) { + input_tensor_address_to_tensor_input_index_map.insert( + {{static_cast(reinterpret_cast(input.second.data_ptr())), + input.first + 1}}); /* skip the ctx input*/ + } + } + + detect_memory_reuse_once(kernel_info, + input_tensor_address_to_tensor_input_index_map, + all_outputs_of_kernel_run /*all_outputs_of_kernel_run*/, + inplace_map /*all_outputs_to_tensor_inputs_reuse_map*/, + raw_input_tensors_used_inplace, + log_prefix); + } + + process_inplace_outputs(kernel_info, + func_name, + input_tensors_used_for_bw_run, + inplace_map /*all_outputs_to_tensor_inputs_reuse_map*/, + raw_input_tensors_used_inplace, + is_backward /*is_backward*/, + log_prefix, + all_outputs_of_kernel_run /*all_outputs_of_kernel_run*/); + + unregister_grad_fn(ctx); + } + + std::vector rets; + for (auto& py_obj : all_outputs_of_kernel_run) { + PyObject* obj = py_obj.ptr(); + + if (!THPVariable_Check(obj)) { + Py_INCREF(obj); + rets.push_back(obj); + continue; + } + + DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(obj)); + rets.push_back(PyCapsule_New(dlMTensor, "dltensor", dlpack_capsule_destructor)); + } + + if (kernel_info.is_first_run) { + kernel_info.is_first_run = false; + } + return rets; + } catch (const std::exception& e) { + std::cerr << "custom_function_backward_runner failed with " << e.what() << std::endl; + throw; + } +} diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.h b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.h new file mode 100644 index 0000000000000..415f7cc1e5295 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_bw.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +std::vector custom_function_backward_runner(const char* func_name_char, + void* callback, + const std::vector& requires_grad_flags, + const std::vector& tensor_type_flags, + const bool is_training_mode, + const std::vector& inplace_map, + const char* kernel_invoke_id_char, + const bool safe_run_mode_enabled, + const std::vector& args); diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc new file mode 100644 index 0000000000000..9e24022b8448d --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.cc @@ -0,0 +1,516 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ctx_pool.h" +#include "custom_function_shared.h" +#include "custom_function_fw.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef NVTX3_ENABLED +#include +#endif + +static void clear_grad_fns_for_next_edges(at::Tensor& target, + std::vector& saved_tensors) { + // For leaf tensor, there will be a AccumulateGrad (gradient function) created, which owns a + // reference to the tensor. + // For any user saved tensors (with save_for_backward), if the tensor is leaf, we put the map + // {AccumulateGrad*, Tensor*} into grad_fn_to_tensor_map. + std::unordered_map grad_fn_to_tensor_map; + for (auto& t : saved_tensors) { + auto grad_fn = t.grad_fn(); + if (!grad_fn) { + grad_fn = torch::autograd::impl::try_get_grad_accumulator(t); + if (grad_fn) { + TORCH_CHECK(grad_fn_to_tensor_map.find(grad_fn.get()) == grad_fn_to_tensor_map.end(), + "found AccumulateGrad* is used by more than one tensors."); + grad_fn_to_tensor_map.insert({grad_fn.get(), &t}); + } + } + } + + const auto& gradient_func_sptr = target.grad_fn(); + for (auto& edge : gradient_func_sptr->next_edges()) { + torch::autograd::Node* node_func = edge.function.get(); + // If we find the next gradient function is AccumulateGrad, we will check whether its owned + // tensors is in ctx.save_tensors or not. If yes, we skip it; otherwise, we clean the edge, which + // will release the AccumulateGrad function. + if (dynamic_cast(node_func)) { + if (grad_fn_to_tensor_map.find(node_func) != grad_fn_to_tensor_map.end()) { + // skip the edges that connect to saved_tensors. Because when unpack ctx.saved_tensors using + // following code in backward: + // input, = ctx.saved_tensors + // there is such a check: if the saved tensor is a leaf and requires grad, it should have grad accumulator. + // If we clean the edge, then an exception "RuntimeError: No grad accumulator for a saved leaf!" will be thrown + continue; + } else { + edge.function.reset(); + } + } + } +} + +static std::vector are_tensors_marked_as_dirty(at::Tensor& target, + std::vector& tensors_to_check) { + torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); + const auto& grad_fn = autograd_meta->grad_fn_; + auto py_node_fn = dynamic_cast(grad_fn.get()); + TORCH_CHECK(py_node_fn != nullptr, "grad_fn is not PyNode type."); + THPFunction* py_fn = (THPFunction*)py_node_fn->obj; + std::vector are_tensors_marked_dirty(tensors_to_check.size(), false); + if (!py_fn->dirty_tensors) + return are_tensors_marked_dirty; + + Py_ssize_t num_dirty = PyTuple_GET_SIZE(py_fn->dirty_tensors); + for (const auto j : c10::irange(tensors_to_check.size())) { + bool is_tensor_marked_dirty = false; + for (const auto i : c10::irange(num_dirty)) { + PyObject* obj = PyTuple_GET_ITEM(py_fn->dirty_tensors, i); + const auto& tensor = THPVariable_Unpack(obj); + if (tensor.is_same(tensors_to_check[j])) { + is_tensor_marked_dirty = true; + break; + } + } + + are_tensors_marked_dirty[j] = is_tensor_marked_dirty; + } + + return are_tensors_marked_dirty; +} + +std::optional try_to_get_tensor_owning_context(const py::tuple& forward_output_tensors) { + py::object ctx = py::none(); + std::optional first_tensor_output; + + for (size_t i = 0; i < forward_output_tensors.size(); ++i) { + PyObject* obj = forward_output_tensors[i].ptr(); + if (!THPVariable_Check(obj)) { + continue; + } + + at::Tensor t = THPVariable_Unpack(obj); + if (!t.grad_fn()) { + continue; + } + + // Be noted, in Python, we need additional check as below. + // For the following case, it is possible grad_fn exists, but its value is None, + // so we need to continue to search for the first tensor having a non-None grad_fn. + // + // >>> w = torch.randn(5, 6) + // >>> hasattr(w, "grad_fn") + // True + // >>> w.grad_fn is None + // True + // >>> w, ... = CustomFunc.apply(w) # where CustomFunc forward just return w and other tensors. + // + // Then hasattr(w, "grad_fn") is True, but w.grad_fn is None. + + first_tensor_output = t; + break; + } + + return first_tensor_output; +} + +void get_materialize_grads_once(const py::tuple& forward_output_tensors, + bool need_materialize_grads, + CustomFuncOpKernelInfo& kernel_info) { + kernel_info.materialize_grads = need_materialize_grads; + if (need_materialize_grads) { + for (size_t i = 0; i < forward_output_tensors.size(); ++i) { + PyObject* obj = forward_output_tensors[i].ptr(); + if (!THPVariable_Check(obj)) { + continue; + } + at::Tensor t = THPVariable_Unpack(obj); + kernel_info.materialize_grads_config.insert({i, {t.sizes().vec(), t.options()}}); + } + + static std::once_flag log_warning; + std::call_once(log_warning, []() { + std::cerr << "First-time run initialize kernel info including materialize_grads and materialize_grads_config." + << std::endl; + }); + } +} + +py::object finalize_training_mode_forward( + const std::unordered_map& input_tensors_used_for_fw_run, + const py::tuple& forward_output_tensors, + CustomFuncOpKernelInfo& kernel_info) { + std::optional tensor_owning_ctx = try_to_get_tensor_owning_context(forward_output_tensors); + + if (!tensor_owning_ctx.has_value()) { + // ctx being None in training mode means the forward function is not differentiable, so backward is not needed. + return py::none(); + } + + const std::shared_ptr& cdata = tensor_owning_ctx.value().grad_fn(); + auto py_node_fn = dynamic_cast(cdata.get()); + TORCH_CHECK(py_node_fn != nullptr, "cdata is not PyNode type."); + + // ret is THPFunction + THPFunction* py_fn = (THPFunction*)py_node_fn->obj; + py::object ret = py::reinterpret_steal(torch::autograd::functionToPyObject(cdata)); + + TORCH_CHECK(py_fn != nullptr, "cdata is not THPFunction type."); + + // The way we find saved tensor is aligned with + // "THPFunction_saved_tensors" and "unpack_saved_variables" in PyTorch. + std::vector saved_tensors; + int num_saved = py_fn->saved_variables.size(); + auto saved_for = py_fn->cdata.lock(); + TORCH_INTERNAL_ASSERT(saved_for); + + for (const auto i : c10::irange(num_saved)) { + auto unpacked_var = py_fn->saved_variables[i].unpack(saved_for); + if (unpacked_var.defined()) { + // TODO(pengwa): is it possible we do the copy on demand here instead of do blind + // copy and do detection at the first iteration. + saved_tensors.push_back(unpacked_var); + } + } + + if (kernel_info.is_first_run) { + std::cout << "666666666666666666666666. py_fn->materialize_grads:" << py_fn->materialize_grads << std::endl; + get_materialize_grads_once(forward_output_tensors, py_fn->materialize_grads, kernel_info); + + if (kernel_info.safe_run_enabled) { + for (auto& pair : input_tensors_used_for_fw_run) { + auto& tensor = pair.second; + bool found = false; + for (auto& t : saved_tensors) { + if (t.is_same(tensor)) { + found = true; + break; + } + } + kernel_info.tensor_input_indices_to_save_in_ctx[pair.first] = found; + } + + // Check tensors generated by ORT are marked as dirty(for inplace update) or not . + // If yes, save the input index of the tensor in the KernelInfoStore::GetInstance().GetKernelInfoMap(). + std::vector tensors_to_check; + tensors_to_check.reserve(input_tensors_used_for_fw_run.size()); + for (auto& pair : input_tensors_used_for_fw_run) { + tensors_to_check.push_back(pair.second); + } + + std::vector are_dirty = are_tensors_marked_as_dirty(tensor_owning_ctx.value(), tensors_to_check); + size_t index = 0; + for (auto& pair : input_tensors_used_for_fw_run) { + kernel_info.tensor_input_indices_for_mark_dirty[pair.first] = are_dirty[index]; + + index += 1; + } + + static std::once_flag log_warning; + std::call_once(log_warning, []() { + std::cerr << "First time run initialize kernel info including saved_for_forward, and mark_dirty infos." << std::endl; + }); + } + } + + // #FORWARD BACKWARD FUNCTION CONNECTIONS + // #input_1(leaf, constructed by from_dlpack) < -- --reference-- --AccumulateGrad gradient function + // # ↓ ↑ + // #autograd.Function apply()-- -- -- -- -- --> autograd.Function backward() + // # ↓ | ↑ + // #output_1, output_2-- - shared_ptr < PyNode> -- - ↑ + // # ↓ previous gradient function + + // #We remove the edges starting between current autograd.Function's gradient function and + // #it 's input' s gradient function(e.g.AccumulateGrad gradient function), then + // #AccumulateGrad gradient function will be destroyed, releasing the reference to input_1 + // #(https: //github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/functions/accumulate_grad.cpp#L21). + // #The next edges are stored in Node, with which we can get next gradient function. + // #https: // github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L527 + + clear_grad_fns_for_next_edges(tensor_owning_ctx.value(), saved_tensors); + + // This is mainly to hold grad_fn references by registering it into our PyNodeSharedPointerPool. + register_grad_fn_and_remove_from_autograd(ret, tensor_owning_ctx.value()); + + return ret; +} + +static py::object get_mockup_context_class() { + static py::object kclass_obj; + + if (!kclass_obj.ptr()) { + // Load the module object + auto module = + py::reinterpret_steal( + PyImport_ImportModule("onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils.fake_ctx")); + if (!module.ptr()) { + PyErr_Print(); + throw std::runtime_error("Fails to import the module."); + } + + auto python_class = py::reinterpret_steal(PyObject_GetAttrString(module.ptr(), "FakeContext")); + if (!PyCallable_Check(python_class.ptr())) { + throw std::runtime_error("Cannot instantiate the Python class"); + } + + kclass_obj = py::reinterpret_borrow(python_class.ptr()); + } + + return kclass_obj; +} + +std::vector custom_function_forward_runner(const char* func_name_char, + void* callback, + const std::vector& requires_grad_flags, + const std::vector& tensor_type_flags, + const bool is_training_mode, + const std::vector& inplace_map, + const char* kernel_invoke_id_char, + const bool safe_run_mode_enabled, + const std::vector& args) { + try { + pybind11::gil_scoped_acquire gil; + + std::string func_name(func_name_char); + std::string kernel_invoke_id(kernel_invoke_id_char); + bool is_backward = false; + std::string log_prefix = func_name + " -> " + (is_backward ? "Backward " : "Forward "); + +#ifdef NVTX3_ENABLED + nvtxRangePushA(std::string(func_name + ".fw").c_str()); +#endif + + auto it = KernelInfoStore::GetInstance().GetKernelInfoMap().find(kernel_invoke_id); + if (it == KernelInfoStore::GetInstance().GetKernelInfoMap().end()) { + KernelInfoStore::GetInstance().GetKernelInfoMap().emplace( + kernel_invoke_id, + CustomFuncOpKernelInfo(kernel_invoke_id, safe_run_mode_enabled)); + } + + CustomFuncOpKernelInfo& kernel_info = KernelInfoStore::GetInstance().GetKernelInfoMap().at(kernel_invoke_id); + + std::unordered_map raw_input_tensors_used_inplace; + std::unordered_map input_tensors_used_for_fw_run; + + int tensor_input_index = 0; + std::vector raii_call_args; + if (kernel_info.safe_run_enabled) { + raii_call_args.reserve(args.size()); + } else { + auto python_class = get_mockup_context_class(); + // Creates an instance of the class + PyObject* object = PyObject_CallObject(python_class.ptr(), nullptr); + raii_call_args.reserve(args.size() + 1); + raii_call_args.push_back(py::reinterpret_steal(object)); + } + + for (size_t arg_index = 0; arg_index < args.size(); ++arg_index) { + bool is_tensor = (tensor_type_flags[arg_index] == 1); + if (!is_tensor) { + raii_call_args.push_back(py::reinterpret_borrow(args[arg_index])); + continue; + } + + // Assume it's a DLPack tensor and convert it to PyTorch tensor. + TORCH_CHECK(PyCapsule_IsValid(args[arg_index], "dltensor") != 0, "found invalid pycapsule"); + at::Tensor tensor = torch::utils::tensor_fromDLPack(args[arg_index]); + bool requires_grad = requires_grad_flags[arg_index] && is_training_mode; + tensor.requires_grad_(requires_grad); + + if (kernel_info.safe_run_enabled) { + bool is_input_used_inplace = (std::find(inplace_map.begin(), inplace_map.end(), tensor_input_index) != + inplace_map.end()); + if (is_input_used_inplace) { + raw_input_tensors_used_inplace[tensor_input_index] = tensor; + } + + if (kernel_info.is_first_run) { + at::Tensor tensor_clone; + if (is_training_mode) { + at::AutoGradMode enable_grad(true); + tensor_clone = tensor.clone(); + tensor_clone.requires_grad_(requires_grad); + } else { + tensor_clone = tensor; + } + + raii_call_args.push_back(py::reinterpret_steal(THPVariable_Wrap(tensor_clone))); + input_tensors_used_for_fw_run[tensor_input_index] = tensor_clone; + } else { + // Saving tensor for backward only affect the training. + bool is_input_index_saved_in_ctx = + is_training_mode && kernel_info.tensor_input_indices_to_save_in_ctx.at(tensor_input_index); + + bool is_input_index_marked_dirty = + kernel_info.tensor_input_indices_for_mark_dirty.at(tensor_input_index); + + if (is_input_index_saved_in_ctx || is_input_index_marked_dirty) { + at::AutoGradMode enable_grad(is_input_index_marked_dirty); + auto wrapped_arg = tensor.clone(); + wrapped_arg.requires_grad_(requires_grad); + raii_call_args.push_back(py::reinterpret_steal(THPVariable_Wrap(wrapped_arg))); + input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg; + } else { + raii_call_args.push_back(py::reinterpret_steal(THPVariable_Wrap(tensor))); + input_tensors_used_for_fw_run[tensor_input_index] = tensor; + } + } + } else { + raii_call_args.push_back(py::reinterpret_steal(THPVariable_Wrap(tensor))); + } + + tensor_input_index++; + } + + if (kernel_info.safe_run_enabled && kernel_info.is_first_run) { + // Initialize some kernel info for the first run. + for (const auto i : c10::irange(input_tensors_used_for_fw_run.size())) { + kernel_info.tensor_input_indices_to_save_in_ctx.insert({{i, false}}); + kernel_info.tensor_input_indices_for_mark_dirty.insert({{i, false}}); + } + } + +#ifdef NVTX3_ENABLED + nvtxRangePushA(std::string(func_name + ".call_func").c_str()); +#endif + + py::tuple call_args = py::cast(raii_call_args); + PyObject* result_pyobj; + { + at::AutoGradMode enable_grad(is_training_mode && kernel_info.safe_run_enabled); + result_pyobj = PyObject_CallObject(reinterpret_cast(callback), call_args.ptr()); + } + +#ifdef NVTX3_ENABLED + nvtxRangePop(); +#endif + + if (PyErr_Occurred()) { + PyErr_Print(); + } + + if (!result_pyobj) { + throw std::runtime_error("Get null result"); + } + + py::object ret = py::reinterpret_steal(result_pyobj); + + py::tuple forward_outputs; + if (THPVariable_Check(ret.ptr())) { // Don't check be tensor? + forward_outputs = py::make_tuple(ret); + } else { + TORCH_CHECK(PyTuple_Check(ret.ptr()), "Python function must return a tuple."); + forward_outputs = ret.cast(); + } + + py::object ctx; + if (is_training_mode) { +#ifdef NVTX3_ENABLED + std::string tag3 = func_name + ".ctx"; + nvtxRangePushA(tag3.c_str()); +#endif + if (kernel_info.safe_run_enabled) { + ctx = finalize_training_mode_forward(input_tensors_used_for_fw_run, forward_outputs, kernel_info); + if (!ctx.is_none()) { + PyObject_SetAttrString(ctx.ptr(), "fw_kernel_invoke_id", py::cast(kernel_invoke_id).ptr()); + } + } else { + if (kernel_info.is_first_run) { + bool need_materialize_grads = true; + get_materialize_grads_once(forward_outputs, need_materialize_grads, kernel_info); + } + + ctx = call_args[0]; + PyObject_SetAttrString(ctx.ptr(), "fw_kernel_invoke_id", py::cast(kernel_invoke_id).ptr()); + } + +#ifdef NVTX3_ENABLED + nvtxRangePop(); +#endif + } else { + ctx = py::none(); + } + + std::vector all_outputs_of_kernel_run; + all_outputs_of_kernel_run.reserve(forward_outputs.size() + 1); + all_outputs_of_kernel_run.push_back(ctx); + for (size_t i = 0; i < forward_outputs.size(); ++i) { + all_outputs_of_kernel_run.push_back(forward_outputs[i]); + } + + if (kernel_info.safe_run_enabled) { + if (kernel_info.is_first_run) { + // key: tensor data address; + // value: if the tensor is defined it records the tensor input index, otherwise, -1. + std::unordered_map input_tensor_address_to_tensor_input_index_map; + input_tensor_address_to_tensor_input_index_map.reserve(input_tensors_used_for_fw_run.size()); + for (auto& input : input_tensors_used_for_fw_run) { + if (input.second.defined()) { + input_tensor_address_to_tensor_input_index_map.insert( + {{static_cast(reinterpret_cast(input.second.data_ptr())), input.first}}); + } + } + + detect_memory_reuse_once(kernel_info, + input_tensor_address_to_tensor_input_index_map, + all_outputs_of_kernel_run /*all_outputs_of_kernel_run*/, + inplace_map /*all_outputs_to_tensor_inputs_reuse_map*/, + raw_input_tensors_used_inplace, + log_prefix); + } + + process_inplace_outputs(kernel_info, + func_name, + input_tensors_used_for_fw_run, + inplace_map /*all_outputs_to_tensor_inputs_reuse_map*/, + raw_input_tensors_used_inplace, + false /*is_backward*/, + log_prefix, + all_outputs_of_kernel_run /*all_outputs_of_kernel_run*/); + } + +#ifdef NVTX3_ENABLED + nvtxRangePushA(std::string(func_name + ".final").c_str()); +#endif + + std::vector rets; + rets.reserve(all_outputs_of_kernel_run.size()); + for (auto& py_obj : all_outputs_of_kernel_run) { + PyObject* obj = py_obj.ptr(); + + if (!THPVariable_Check(obj)) { + Py_INCREF(obj); + rets.push_back(obj); + continue; + } + + DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(obj)); + rets.push_back(PyCapsule_New(dlMTensor, "dltensor", dlpack_capsule_destructor)); + } + +#ifdef NVTX3_ENABLED + nvtxRangePop(); +#endif + + if (kernel_info.is_first_run) { + kernel_info.is_first_run = false; + } + +#ifdef NVTX3_ENABLED + nvtxRangePop(); +#endif + + return rets; + } catch (const std::exception& e) { + std::cerr << "custom_function_forward_runner failed with " << e.what() << std::endl; + throw; + } +} diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.h b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.h new file mode 100644 index 0000000000000..5a908e4cd4e7f --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_fw.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +std::vector custom_function_forward_runner(const char* func_name_char, + void* callback, + const std::vector& requires_grad_flags, + const std::vector& tensor_type_flags, + const bool is_training_mode, + const std::vector& inplace_map, + const char* kernel_invoke_id_char, + const bool safe_run_mode_enabled, + const std::vector& tensor_args); diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.cc new file mode 100644 index 0000000000000..f7698b74ab462 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.cc @@ -0,0 +1,213 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ctx_pool.h" +#include "custom_function_shared.h" +#include +#include + +/** + * @brief Special handling for in-place reusing in forward or backward. + * @param kernel_info kernel-specific information. + * @param input_tensor_address_to_tensor_input_index_map + * @param all_outputs_of_kernel_run all outputs of the MSDomain::PythonOp/PythonOpGrad. + * @param all_outputs_to_tensor_inputs_reuse_map + * @param raw_input_tensors_used_inplace a dict of raw input tensors marked as inplace in + `all_outputs_to_tensor_inputs_reuse_map`, the key is the tensor input index, value is the raw input tensor. + * @param log_prefix + * + * Detection procedures: + * 1. Detect all outputs to tensor inputs reuse mapping. + * 2. Validate the detected inplace_map with the registered inplace_map in ORT. For the output tensor, + * 2.0 If the reuse mapping value is the same in both inplace_map and detected inplace_map: + * 2.0.1 Most likely, we don't need to do anything, except 2.0.2. + * 2.0.2 Conditions: + * > During forward run, + * > The output tensor is reusing one of input tensors, + * > The raw input tensor to be reused given from ORT is copied to run the forward kernels + * (for two possible reasons: + * a. the first time forward run, all inputs will be copied to detect + * `tensor_input_indices_to_save_in_ctx`; + * b. for every iteration, the input needs to be cloned because it is in + * `tensor_input_indices_to_save_in_ctx`). + * + * In this case, need to copy the output tensor back to the raw input tensor, to make it compatible with + * ORT statistically planned buffer reuse. + * 2.1 If the reuse mapping value is NOT equal in both inplace_map and detected inplace_map: + * 2.1.1 If the detected reuse input index is -1 (e.g. there is NO buffer reuse for this output), + * while user specified reuse input index is NOT -1 (ORT planned the reuse), we raise an error. + * 2.1.2 If the detected reuse input index is NOT -1 (e.g. there is buffer reuse for this output), + * while user specified reuse input index is -1 (ORT did not plan the reuse). We will try to clone the + * output tensor before returning to ORT, to align with ORT's NO Buffer reuse plan; otherwise, once the + * input buffer is released by ORT memory planner, the output tensor read/write will be corrupted. + * Raise a warning to notify users to update inplace_map explicitly for performance consideration. + * 2.1.3 Other cases (for example user gives a wrong mapping index compared with detected ones), raise an + * error. + * 3. Do copies for 2.1.2 cases. + * 4. Do copies for 2.0.2 cases. + */ +void detect_memory_reuse_once( + CustomFuncOpKernelInfo& kernel_info, + const std::unordered_map& input_tensor_address_to_tensor_input_index_map, + const std::vector& all_outputs_of_kernel_run, + const std::vector& all_outputs_to_tensor_inputs_reuse_map, + const std::unordered_map& raw_input_tensors_used_inplace, + const std::string& log_prefix) { + // Procedure 1: Detect all outputs to tensor inputs reuse mapping, according to `all_outputs_of_kernel_run` and + // `input_tensors_of_kernel_run`. + + TORCH_CHECK(all_outputs_to_tensor_inputs_reuse_map.size() == all_outputs_of_kernel_run.size(), + log_prefix + + "all_outputs_to_tensor_inputs_reuse_map and kernel run outputs sizes not expected:" + + std::to_string(all_outputs_to_tensor_inputs_reuse_map.size()) + " vs " + + std::to_string(all_outputs_of_kernel_run.size())); + + // Detect all outputs to tensor inputs reuse mapping. + std::vector detected_reuse_map(all_outputs_of_kernel_run.size(), -1); + for (size_t output_index = 0; output_index < all_outputs_of_kernel_run.size(); ++output_index) { + py::object arg = all_outputs_of_kernel_run[output_index]; + if (!THPVariable_Check(arg.ptr())) { + continue; + } + at::Tensor t = THPVariable_Unpack(arg.ptr()); + size_t t_data_address = static_cast(reinterpret_cast(t.data_ptr())); + if (input_tensor_address_to_tensor_input_index_map.find(t_data_address) != input_tensor_address_to_tensor_input_index_map.end()) { + int tensor_input_index = input_tensor_address_to_tensor_input_index_map.at(t_data_address); + TORCH_CHECK(tensor_input_index != -1, "Reused tensor input index should not be -1"); + detected_reuse_map[output_index] = tensor_input_index; + } + } + + // Procedure 2: Validate the detected inplace_map with the registered inplace_map in ORT. + // collect the output indices that need to be cloned before returned in case 2.1.2. + for (size_t output_index = 0; output_index < all_outputs_of_kernel_run.size(); ++output_index) { + int detected_inplace_index = detected_reuse_map[output_index]; + int inplace_index = all_outputs_to_tensor_inputs_reuse_map[output_index]; + + if (inplace_index == detected_inplace_index) { + continue; + } + + if (raw_input_tensors_used_inplace.count(inplace_index) && + !raw_input_tensors_used_inplace.at(inplace_index).defined()) { + // Use specified inplace input index, but the input tensor is None, which means the input is not + // a tensor, so we don't do further checks. + continue; + } + + // If users register inplace_map (alloc planner will do buffer reuse), + // but detected inplace_map indicates it is NO inplace reusing, we raise an error. + if (inplace_index != -1 && detected_inplace_index == -1) { + throw std::runtime_error( + log_prefix + "Fatal: ONNX Op attribute 'tensor_reuse_map' indicates " + + std::to_string(output_index) + "-th output is reusing input " + + std::to_string(inplace_index) + ", but detected inplace_map indicates it is NOT reusing any input. " + + "Please update inplace_map explicitly to make it consistent " + + "to avoid undefined behavior due to ORT's memory reuse plan. " + + +"detected reused input index: " + std::to_string(detected_inplace_index)); + } + + if (inplace_index == -1 && detected_inplace_index != -1) { + std::cout << log_prefix << "ONNX Op attribute " + << "'tensor_reuse_map' doesn't indicate " << std::to_string(output_index) + << "-th output is reusing any input, " + << "but detected inplace_map indicates it is reusing input index " + << std::to_string(detected_inplace_index) + << ". A clone will be done before returning to ORT, to align with ORT's NO Buffer reuse plan. " + << "Please update inplace_map explicitly to avoid such a copy." << std::endl; + + kernel_info.output_indices_for_clone.push_back(output_index); + continue; + } + + throw std::runtime_error( + log_prefix + "Fatal: ONNX Op attribute 'tensor_reuse_map' indicates " + + std::to_string(output_index) + "-th output is reusing input " + std::to_string(inplace_index) + + " but detected inplace_map indicates it is reusing input index " + + std::to_string(detected_inplace_index) + + ". Please update inplace_map explicitly to avoid undefined behavior due to memory reuse."); + } +} + +void process_inplace_outputs( + const CustomFuncOpKernelInfo& kernel_info, + const std::string& func_name, + const std::unordered_map& input_tensors_used_for_fw_run, + const std::vector& all_outputs_to_tensor_inputs_reuse_map, + const std::unordered_map& raw_input_tensors_used_inplace, + bool is_backward, + const std::string& log_prefix, + std::vector& all_outputs_of_kernel_run) { + // Procedure 3: Do copies for 2.1.2 cases. + for (const size_t& output_index : kernel_info.output_indices_for_clone) { + at::Tensor t = THPVariable_Unpack(all_outputs_of_kernel_run[output_index].ptr()); + auto pp = py::reinterpret_steal(THPVariable_Wrap(t.detach().clone())); + all_outputs_of_kernel_run[output_index] = pp; + } + + // Procedure 4: Do copies for 2.0.2 cases. + if (!is_backward && kernel_info.safe_run_enabled) { + for (auto& pair : raw_input_tensors_used_inplace) { + auto raw_tensor_input_index = pair.first; + auto raw_input_tensor = pair.second; + // raw_input_tensor can be None for backward run, but backward won't go here. + if (!raw_input_tensor.defined()) { + continue; + } + + // We did not do the check with tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty + // because even for those tensor indices not in + // tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty, we still need to do the + // copy for the first-time run. + if (raw_input_tensor.data_ptr() == input_tensors_used_for_fw_run.at(raw_tensor_input_index).data_ptr()) { + // If the raw input tensor is not copied, we don't need this handling. + continue; + } + + // for each tensor, we don't do the copy once. + bool copied = false; + std::vector output_indices_reusing_current_raw_input; + for (size_t output_index = 0; output_index < all_outputs_to_tensor_inputs_reuse_map.size(); ++output_index) { + if (all_outputs_to_tensor_inputs_reuse_map[output_index] == raw_tensor_input_index) { + output_indices_reusing_current_raw_input.push_back(output_index); + } + } + + auto output_tensor_address = + THPVariable_Unpack(all_outputs_of_kernel_run[output_indices_reusing_current_raw_input[0]].ptr()).data_ptr(); + for (size_t& output_index : output_indices_reusing_current_raw_input) { + auto t = THPVariable_Unpack(all_outputs_of_kernel_run[output_index].ptr()); + TORCH_CHECK(output_tensor_address == t.data_ptr(), + "Outputs reusing the same input tensor should have the same address."); + + if (!copied) { + // Only need a copy once. + // Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False. + raw_input_tensor.requires_grad_(false); + raw_input_tensor.copy_(t); + + // Comment below for debugging. + // std::cout << "Copy output tensor " << output_index << " to raw input tensor " << raw_tensor_input_index << "." + // << (!kernel_info.is_first_run + // ? "Provide output to input reuse mapping to avoid the copy overhead." + // : "") + // << std::endl; + copied = true; + } + + all_outputs_of_kernel_run[output_index] = py::reinterpret_steal(THPVariable_Wrap(raw_input_tensor)); + } + } + } +} + +void dlpack_capsule_destructor(PyObject* data) { + if (!PyCapsule_IsValid(data, "dltensor")) { + // early out, see DLPack spec: if a consuming library sets the capsule + // name to something else, they own it and we don't need to do anything + return; + } + DLManagedTensor* dlMTensor = + (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); + dlMTensor->deleter(const_cast(dlMTensor)); +} diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.h b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.h new file mode 100644 index 0000000000000..c1c1930aac4cd --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/custom_function_shared.h @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include + +// Uncomment this line to enable NVTX profiling +// #define NVTX3_ENABLED 1 + +class CustomFuncOpKernelInfo { + public: + CustomFuncOpKernelInfo(const std::string& invoke_id, bool safe_run) { + kernel_invoke_id = invoke_id; + safe_run_enabled = safe_run; + } + + // kernel_invoke_id is a string contains session thread id, op kernel creation time stamp in ms, a random int, + // and address of op_kernel pointer. This can guarantee the uniqueness of the key in case of multiple + // instances of a same named PythonOp/PythonOpGrad in one session, or multiple sessions. + std::string kernel_invoke_id; + + // For the tensors generated from ORT backend, there is special handling here: + // 1. For the first time run for the kernel (the uniqueness of the kernel is defined by kernel_invoke_id), + // all such tensors will be cloned in case they are saved in context (but ORT backend is not aware of the + // reference, may release the content of the tensor before it is needed in backward). Once + // `autograd.Function.apply` completes, by checking the existence of the tensor in the saved_tensors, + // `_GlobalOpKernelInfoMap` is updated to save the input indices that are saved in context. + // 2. For the subsequent runs, if the input index is in `tensor_input_indices_to_save_in_ctx`, the tensor + // will be cloned before fed into `autograd.Function.apply` as input. + std::unordered_map tensor_input_indices_to_save_in_ctx; + + // To align with PyTorch `ctx.set_materialize_grads(False|True)`, default to be true. + // materialize_grads_config is a map from output index to (device, dtype, shape) of the output tensor, used + // for materializing the gradient of the output tensor in backward. + bool materialize_grads{true}; + // key: output index, value: (shape, tensor options including device, layerout, data types, etc) + std::unordered_map, c10::TensorOptions>> materialize_grads_config; + + // For the tensors generated from ORT backend, there is special handling here: + // 1. For the first time run for the kernel (the uniqueness of the kernel is defined by kernel_invoke_id), + // all such tensors will be cloned (with gradient) in case they are marked as dirty (if not cloned, but marked + // as dirty, PyTorch will complain the tensor is a leaf, should not be used for inplace update). Once + // `autograd.Function.apply` completes, by checking the existence of the tensor in the dirty_tensors, + // `_GlobalOpKernelInfoMap` is updated to save the input indices that are marked as dirty. + // 2. For the subsequent runs, if the input index is in `tensor_input_indices_for_mark_dirty`, the tensor + // will be cloned (with gradient) before fed into `autograd.Function.apply` as input. + std::unordered_map tensor_input_indices_for_mark_dirty; + + // A list of output indices that needs to be clone before returned, due to inplace update analysis. + std::vector output_indices_for_clone; + + bool is_first_run{true}; + bool safe_run_enabled{false}; +}; + +void detect_memory_reuse_once( + CustomFuncOpKernelInfo& kernel_info, + const std::unordered_map& input_tensor_address_to_tensor_input_index_map, + const std::vector& all_outputs_of_kernel_run, + const std::vector& all_outputs_to_tensor_inputs_reuse_map, + const std::unordered_map& raw_input_tensors_used_inplace, + const std::string& log_prefix); + +void process_inplace_outputs( + const CustomFuncOpKernelInfo& kernel_info, + const std::string& func_name, + const std::unordered_map& input_tensors_used_for_fw_run, + const std::vector& all_outputs_to_tensor_inputs_reuse_map, + const std::unordered_map& raw_input_tensors_used_inplace, + bool is_backward, + const std::string& log_prefix, + std::vector& all_outputs_of_kernel_run); + +void dlpack_capsule_destructor(PyObject* data); + +class KernelInfoStore { + public: + static KernelInfoStore& GetInstance() { + static KernelInfoStore instance; + return instance; + } + + std::unordered_map& GetKernelInfoMap() { + return kernel_info_map_; + } + + private: + std::unordered_map kernel_info_map_; +}; diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/fake_ctx.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/fake_ctx.py new file mode 100644 index 0000000000000..d295c68c2a155 --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/fake_ctx.py @@ -0,0 +1,13 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + + +class FakeContext: + """A mock up class used to represent ctx in unsfafe mode run. + The reason we need ctx to be Python class is: users could assign any attribute to ctx. + """ + + def __init__(self): + pass diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py index 3b6d6050c4c17..fa72f3b134917 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py @@ -8,13 +8,30 @@ from setuptools import Extension, setup # noqa: F401 from torch.utils import cpp_extension -filename = os.path.join(os.path.dirname(__file__), "torch_interop_utils.cc") +source_filenames = [ + "torch_interop_utils.cc", + "ctx_pool.cc", + "custom_function_bw.cc", + "custom_function_fw.cc", + "custom_function_shared.cc", +] + +cur_file_dir = os.path.dirname(__file__) + +header_filenames = [ + # "/usr/local/cuda/include/", # uncomment this line to build nvtx support, + cur_file_dir, +] + extra_compile_args = {"cxx": ["-O3"]} setup( name="torch_interop_utils", ext_modules=[ cpp_extension.CppExtension( - name="torch_interop_utils", sources=[filename], extra_compile_args=extra_compile_args + name="torch_interop_utils", + sources=[os.path.join(cur_file_dir, filename) for filename in source_filenames], + extra_compile_args=extra_compile_args, + include_dirs=header_filenames, ) ], cmdclass={"build_ext": cpp_extension.BuildExtension}, diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc index d36720100e57a..979c409f08074 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc @@ -1,190 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include -#include -#include -// In PyTorch forward run (e.g. THPFunction_apply), ctx of type THPFunction* (which is also a PyObject*) -// is created (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L673). -// The ctx is used to run user-defined forward function and backward function as the first -// parameter. The same time, a cdata of type std::shared_ptr is created -// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L677), -// cdata is owned by: -// a). forward run output tensors as grad_fn_ property. (The full hierarchy is: Tensor owns -// shared_pointer; TensorImpl owns std::unique_ptr; AutogradMeta -// manages grad_/grad_fn_/grad_accumulator_. Among them, grad_fn_ is std::shared_ptr, -// e.g, the so called gradient function.) -// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/variable.h#L194 -// b). the consumer operator of forward run outputs, will let its own PyNode/Node (gradient function) -// owns the grad_fn_ (of type std::shared_ptr) of all inputs that require grad. -// https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/function.h#L263 -// BUT, if we run torch computation within PythonOp, b) is lost. So for some cases, where forward outputs -// are not used and freed before backward function runs, the grad_fn_ (std::shared_ptr) references -// in a) will be released. Without b)'s reference, grad_fn_ release PyNode as reference count reach 0; -// Then when PythonOpGrad runs, segment fault. -// -// So we add b)'s reference in this Pool when forward run returns; dereference from this Pool when backward -// completes, then ~PyNode() is called, which subsequently calls ~THPFunction() destroying ctx. -class PyNodeSharedPointerPool { - public: - static PyNodeSharedPointerPool& GetInstance() { - static PyNodeSharedPointerPool pool; - return pool; - }; +#include "ctx_pool.h" +#include "custom_function_fw.h" +#include "custom_function_bw.h" - void RegisterGradFuncAndRemoveFromAutoGrad(const size_t& ctx_address, - torch::autograd::AutogradMeta* autograd_meta) { - auto it = grad_fns_.find(ctx_address); - TORCH_CHECK(it == grad_fns_.end(), "should not register grad_fn twice for ctx ", ctx_address); - - // Add new entry if key hasn't been registered. - // After this, the grad_fn_ is removed from torch autograd. - grad_fns_.emplace(ctx_address, std::move(autograd_meta->grad_fn_)); - TORCH_CHECK(autograd_meta->grad_fn_ == nullptr, "fail to remove grad_fn_ from torch autograd for ctx ", - ctx_address); - }; - - void UnRegisterGradFunc(const size_t& ctx_address) { - auto it = grad_fns_.find(ctx_address); - TORCH_CHECK(it != grad_fns_.end(), "fail to find grad_fn for ctx ", ctx_address); - - grad_fns_.erase(ctx_address); - }; - - void ClearAll() { - grad_fns_.clear(); - } - - private: - PyNodeSharedPointerPool(){}; - ~PyNodeSharedPointerPool(){}; - - PyNodeSharedPointerPool(const PyNodeSharedPointerPool&) = delete; - PyNodeSharedPointerPool& operator=(const PyNodeSharedPointerPool&) = delete; - PyNodeSharedPointerPool(PyNodeSharedPointerPool&&) = delete; - PyNodeSharedPointerPool& operator=(PyNodeSharedPointerPool&&) = delete; - - std::unordered_map> grad_fns_; -}; - -void clear_grad_fns_for_next_edges(at::Tensor target, std::vector saved_tensors) { - // For leaf tensor, there will be a AccumulateGrad (gradient function) created, which owns a - // reference to the tensor. - // For any user saved tensors (with save_for_backward), if the tensor is leaf, we put the map - // {AccumulateGrad*, Tensor*} into grad_fn_to_tensor_map. - std::unordered_map grad_fn_to_tensor_map; - for (auto& t : saved_tensors) { - auto grad_fn = t.grad_fn(); - if (!grad_fn) { - grad_fn = torch::autograd::impl::try_get_grad_accumulator(t); - if (grad_fn) { - TORCH_CHECK(grad_fn_to_tensor_map.find(grad_fn.get()) == grad_fn_to_tensor_map.end(), - "found AccumulateGrad* is used by more than one tensors."); - grad_fn_to_tensor_map.insert({grad_fn.get(), &t}); - } - } - } - - const auto& gradient_func_sptr = target.grad_fn(); - for (auto& edge : gradient_func_sptr->next_edges()) { - torch::autograd::Node* node_func = edge.function.get(); - // If we find the next gradient function is AccumulateGrad, we will check whether its owned - // tensors is in ctx.save_tensors or not. If yes, we skip it; otherwise, we clean the edge, which - // will release the AccumulateGrad function. - if (dynamic_cast(node_func)) { - if (grad_fn_to_tensor_map.find(node_func) != grad_fn_to_tensor_map.end()) { - // skip the edges that connect to saved_tensors. Because when unpack ctx.saved_tensors using - // following code in backward: - // input, = ctx.saved_tensors - // there is such a check: if the saved tensor is a leaf and requires grad, it should have grad accumulator. - // If we clean the edge, then an exception "RuntimeError: No grad accumulator for a saved leaf!" will be thrown - continue; - } else { - edge.function.reset(); - } - } - } -} - -void register_grad_fn_and_remove_from_autograd(size_t ctx_address, at::Tensor target) { - torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); - PyNodeSharedPointerPool::GetInstance().RegisterGradFuncAndRemoveFromAutoGrad(ctx_address, autograd_meta); -} - -void unregister_grad_fn(size_t ctx_address) { - PyNodeSharedPointerPool::GetInstance().UnRegisterGradFunc(ctx_address); -} - -// Supposed to be cleared on python program exit to resolve the following issue: -// When training program exits, PyNodeSharedPointerPool destructor is called, if grad_fns_ is not empty, -// PyNode::release_variables() will be called. -// (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/python_function.cpp#L168) -// On The other hand, there is a known issue when acquiring GIL in pybind11 destructors, there will be -// probably a deadlock issue. (https://github.com/pybind/pybind11/issues/1446) -// The resolution here, we remove all maintained states before the program exits. - -// A known existing issue: when forward functions are called repeatedly without corresponding backward calls, -// grad functions keep accumulating without releasing, there might be memory (bound to those gradient functions) leaks. -// Ideally this usually won't happen in real training cases, so it should be fine. - -// We CANNOT explicitly clear grad functions before each forward pass to mitigate the known issue above. -// For example: -// loss1 = forward_run(inputs1) -// loss2 = forward_run(inputs2) -// loss = loss1 + loss2 -// loss.backward() -// If we clear grad functions at the beginning of the second `forward_run`, when `loss.backward()` runs, -// the backward path of `loss1` will fail to run PythonOpGrad ops (if there is any). -void clear_all_grad_fns() { - PyNodeSharedPointerPool::GetInstance().ClearAll(); -} - -bool get_materialize_grads(at::Tensor target) { - torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); - const auto& grad_fn = autograd_meta->grad_fn_; - auto py_node_fn = dynamic_cast(grad_fn.get()); - TORCH_CHECK(py_node_fn != nullptr, "grad_fn is not PyNode type."); - THPFunction* py_fn = (THPFunction*)py_node_fn->obj; - return py_fn->materialize_grads; -} - -std::vector are_tensors_marked_as_dirty(at::Tensor target, std::vector tensors_to_check) { - torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); - const auto& grad_fn = autograd_meta->grad_fn_; - auto py_node_fn = dynamic_cast(grad_fn.get()); - TORCH_CHECK(py_node_fn != nullptr, "grad_fn is not PyNode type."); - THPFunction* py_fn = (THPFunction*)py_node_fn->obj; - std::vector are_tensors_marked_dirty(tensors_to_check.size(), false); - if (!py_fn->dirty_tensors) - return are_tensors_marked_dirty; - - Py_ssize_t num_dirty = PyTuple_GET_SIZE(py_fn->dirty_tensors); - for (const auto j : c10::irange(tensors_to_check.size())) { - bool is_tensor_marked_dirty = false; - for (const auto i : c10::irange(num_dirty)) { - PyObject* obj = PyTuple_GET_ITEM(py_fn->dirty_tensors, i); - const auto& tensor = THPVariable_Unpack(obj); - if (tensor.is_same(tensors_to_check[j])) { - is_tensor_marked_dirty = true; - break; - } - } - - are_tensors_marked_dirty[j] = is_tensor_marked_dirty; - } - - return are_tensors_marked_dirty; -} +size_t get_custom_function_forward_runner() { return reinterpret_cast(&custom_function_forward_runner); } +size_t get_custom_function_backward_runner() { return reinterpret_cast(&custom_function_backward_runner); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("register_grad_fn_and_remove_from_autograd", ®ister_grad_fn_and_remove_from_autograd, - "Increase grad_fn shared pointer reference."); - m.def("unregister_grad_fn", &unregister_grad_fn, "Release grad_fn shared pointer reference."); m.def("clear_all_grad_fns", &clear_all_grad_fns, "Clear all grad_fn shared pointer references."); - m.def("clear_grad_fns_for_next_edges", &clear_grad_fns_for_next_edges, - "Remove reference on next edges' gradient functions."); - m.def("get_materialize_grads", &get_materialize_grads, "Return whether materialize_grads is enabled or not."); - m.def("are_tensors_marked_as_dirty", &are_tensors_marked_as_dirty, "Return whether the tensors are marked dirty or not."); + m.def("get_custom_function_forward_runner", &get_custom_function_forward_runner, "Get custom function forward runner."); + m.def("get_custom_function_backward_runner", &get_custom_function_backward_runner, "Get custom function backward runner."); } diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index 244557c3c1072..b4a518d573998 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. # __init__.py + from onnxruntime.training.utils.ptable import PTable from onnxruntime.training.utils.torch_io_helper import ( ORTModelInputOutputSchemaType, @@ -10,6 +11,11 @@ extract_data_and_schema, unflatten_data_using_schema, ) +from onnxruntime.training.utils.torch_profile_utils import ( + nvtx_function_decorator, + torch_nvtx_range_pop, + torch_nvtx_range_push, +) from onnxruntime.training.utils.torch_type_map import ( onnx_dtype_to_pytorch_dtype, pytorch_scalar_type_to_pytorch_dtype, @@ -22,6 +28,9 @@ "ORTModelInputOutputSchemaType", "extract_data_and_schema", "unflatten_data_using_schema", + "torch_nvtx_range_push", + "torch_nvtx_range_pop", + "nvtx_function_decorator", "pytorch_type_to_onnx_dtype", "onnx_dtype_to_pytorch_dtype", "pytorch_scalar_type_to_pytorch_dtype", diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index 61f3b20224a72..e6004319ef5ea 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -17,7 +17,10 @@ from onnxruntime.training.utils import ( ORTModelInputOutputType, extract_data_and_schema, + nvtx_function_decorator, pytorch_type_to_onnx_dtype, + torch_nvtx_range_pop, + torch_nvtx_range_push, unflatten_data_using_schema, ) @@ -173,6 +176,7 @@ def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir=None, sta raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.") +@nvtx_function_decorator def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.parameter.Parameter]: """Retrieve the parameters for this module. @@ -187,6 +191,7 @@ def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.par return partitioned_params +@nvtx_function_decorator def _get_all_zero_stage3_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: """Retrieve all the parameters that are offloaded.""" from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus @@ -199,6 +204,10 @@ def _get_all_zero_stage3_params(module: torch.nn.Module) -> Dict[str, torch.nn.p return all_offloaed_params +# Used to cache the map avoid repeated loop up (X us) overhead during training. +_ModuleToParametersRefs: Dict[torch.nn.Module, List[torch.nn.parameter.Parameter]] = OrderedDict() + + class ORTZeROOffloadPreForwardFunction(torch.autograd.Function): """This function is a common bridge to call original PyTorch's pre_forward_function""" @@ -227,8 +236,7 @@ def forward( tensor_list: the list of tensors, the first args_tensor_count tensors are args, the next kwargs_tensor_count tensors are kwargs, the rest are the parameters for offload. """ - args_tensors = tensor_list[:args_tensor_count] - kwargs_tensors = tensor_list[args_tensor_count : args_tensor_count + kwargs_tensor_count] + torch_nvtx_range_push("ORTZeROOffloadPreForwardFunction::forward") # For PyTorch runs, the sizes are all 0, it does not need a gradient because # param._detach().requires_grad_(False) is called. @@ -241,41 +249,31 @@ def forward( ctx.dtypes = [p.dtype for p in passed_in_param_tensors] ctx.devices = [p.device for p in passed_in_param_tensors] - args = unflatten_data_using_schema(args_tensors, args_schema) - kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema) - # We will re-retrieve the parameter tensors other than use the one passed in input (of size 0 for # those partitioned params). # This is required for ORT run because in ORT graph, the tensor of size 0 will always be size 0 # (this step is not necessary for PyTorch run, because PyTorch will re-use the same tensor # while .data got updated to full-sized data after pre_forward_with_kwargs_function is called). - partitioned_params = _get_params_for_current_module(module) + if module not in _ModuleToParametersRefs: + _ModuleToParametersRefs[module] = _get_params_for_current_module(module) + partitioned_params = _ModuleToParametersRefs[module] ctx.partitioned_params = partitioned_params - assert len(partitioned_params) == len(passed_in_param_tensors) - - f_ret = pre_forward_with_kwargs_function(module, args, kwargs) - - if f_ret is None: - updated_args, updated_kwargs = args, kwargs - else: - assert isinstance(f_ret, tuple) - updated_args, updated_kwargs = f_ret - + pre_forward_with_kwargs_function(module) ctx.module = module - - updated_args_tensors, _ = extract_data_and_schema(updated_args) - updated_kwargs_tensors, _ = extract_data_and_schema(updated_kwargs) - - rets = tuple(updated_args_tensors + updated_kwargs_tensors) + rets = tuple(tensor_list[: args_tensor_count + kwargs_tensor_count]) rets += tuple([p.detach().requires_grad_(p.requires_grad) for p in partitioned_params]) # PyTorch exporter does not support an empty list of tensors, so we have this check. assert len(rets) != 0 + + torch_nvtx_range_pop() return rets @staticmethod def backward(ctx, *grads): + torch_nvtx_range_push("ORTZeROOffloadPreForwardFunction::backward") + updated_grads = grads input_count = len(updated_grads) - len(ctx.partitioned_params) @@ -302,6 +300,7 @@ def backward(ctx, *grads): zero_grads = updated_grads[:input_count] + tuple(passed_in_param_grad) + torch_nvtx_range_pop() return (None, None, None, None, None, None, *zero_grads) @staticmethod @@ -381,6 +380,8 @@ def forward( output_tensors: the list of tensors. """ + torch_nvtx_range_push("ORTZeROOffloadPostForwardFunction::forward") + outputs = unflatten_data_using_schema(output_tensors, output_schema) # STAGE3WARN#3: _post_forward_module_hook's second argument `input is not used, so we just pass a None here. @@ -394,15 +395,20 @@ def forward( ctx.module = module ctx.pre_backward_function = pre_backward_function rets = [o.detach().requires_grad_(o.requires_grad) for o in updated_output_tensors] + torch_nvtx_range_pop() return tuple(rets) @staticmethod def backward(ctx, *grads): + torch_nvtx_range_push("ORTZeROOffloadPostForwardFunction::backward") + updated_args = grads if ctx.pre_backward_function is not None: ret = ctx.pre_backward_function(ctx.module, grads) if ret is not None: updated_args = ret + + torch_nvtx_range_pop() return (None, None, None, None, *updated_args) @staticmethod @@ -467,6 +473,7 @@ def __init__(self, offloader, one_time_init: _ZeROOffloadOneTimeInitializer, ena self._functions = _ZeROOffloadFunctions(one_time_init, self._offloader) self._enable_debug_info = enable_debug_info + @nvtx_function_decorator def pre_forward_module_apply_impl( self, run_rtx: RuntimeStates, @@ -499,17 +506,14 @@ def pre_forward_module_apply_impl( args_tensor_count = len(args_tensors) kwargs_tensor_count = len(kwargs_tensors) - def _wrap_pre_forward_module_hook(module, args, kwargs): - rets = _pre_forward_module_hook(module, args) - updated_args, updated_kwargs = args, kwargs - if rets is not None: - updated_args = rets + @nvtx_function_decorator + def _wrap_pre_forward_module_hook(module): + empty = [] + _pre_forward_module_hook(module, *empty) # STAGE3WARN#5: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration. module.ds_grads_remaining = 0 - return updated_args, updated_kwargs - # Need to pass the parameters as input to let the exporter trace the related weights for # current ORTZeROOffloadPreForwardFunction partitioned_params = _get_params_for_current_module(module) @@ -545,6 +549,7 @@ def _wrap_pre_forward_module_hook(module, args, kwargs): return updated_args, updated_kwargs + @nvtx_function_decorator def post_forward_module_apply_impl( self, run_rtx: RuntimeStates, @@ -563,6 +568,7 @@ def post_forward_module_apply_impl( _post_forward_module_hook = self._functions.get("_post_forward_module_hook") + @nvtx_function_decorator def _wrap_post_forward_module_hook(module, input, outputs): # STAGE3WARN#6: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here. from deepspeed.runtime.zero.partition_parameters import is_zero_param @@ -580,7 +586,11 @@ def _wrap_post_forward_module_hook(module, input, outputs): self._check_all_tensor(outputs_tensors, module, "post_forward_module_apply_impl input check") updated_outputs_tensors = ORTZeROOffloadPostForwardFunction.apply( - module, _wrap_post_forward_module_hook, None, outputs_schema, *outputs_tensors + module, + _wrap_post_forward_module_hook, + None, + outputs_schema, + *outputs_tensors, ) self._check_all_tensor(updated_outputs_tensors, module, "post_forward_module_apply_impl output check") @@ -598,6 +608,7 @@ def _wrap_post_forward_module_hook(module, input, outputs): return args, updated_outputs + @nvtx_function_decorator def post_forward_outmost_module_apply_impl( self, run_rtx: RuntimeStates, @@ -611,7 +622,11 @@ def post_forward_outmost_module_apply_impl( self._check_all_tensor(outputs_tensors, module, "post_forward_outmost_module_apply_impl input check") updated_outputs_tensors = ORTZeROOffloadPostForwardFunction.apply( - module, _end_of_forward_hook, None, outputs_schema, *outputs_tensors + module, + _end_of_forward_hook, + None, + outputs_schema, + *outputs_tensors, ) self._check_all_tensor(updated_outputs_tensors, module, "post_forward_outmost_module_apply_impl output check") @@ -620,6 +635,7 @@ def post_forward_outmost_module_apply_impl( updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema) return args, updated_outputs + @nvtx_function_decorator def _check_all_tensor(self, tensor_list: Tuple[torch.Tensor], module: torch.nn.Module, name: str): if not self._enable_debug_info: return diff --git a/orttraining/orttraining/python/training/utils/torch_io_helper.py b/orttraining/orttraining/python/training/utils/torch_io_helper.py index 6d7d978e90054..34cc1ca942a8c 100644 --- a/orttraining/orttraining/python/training/utils/torch_io_helper.py +++ b/orttraining/orttraining/python/training/utils/torch_io_helper.py @@ -10,6 +10,8 @@ import torch +from onnxruntime.training.utils.torch_profile_utils import nvtx_function_decorator + class PrimitiveType: """Helper class for Python primitive types.""" @@ -122,6 +124,7 @@ def _warn_of_constant_inputs(data): ) +@nvtx_function_decorator def extract_data_and_schema( data: ORTModelInputOutputType, constant_as_tensor=False, device: Optional[torch.device] = None ) -> Tuple[List[torch.Tensor], ORTModelInputOutputSchemaType]: @@ -230,6 +233,7 @@ def _flatten_from_data(data: ORTModelInputOutputType, prefix_name: str = ""): return flatten_tensor_data, schemas +@nvtx_function_decorator def unflatten_data_using_schema( data: List[torch.Tensor], schema: ORTModelInputOutputSchemaType ) -> ORTModelInputOutputType: diff --git a/orttraining/orttraining/python/training/utils/torch_profile_utils.py b/orttraining/orttraining/python/training/utils/torch_profile_utils.py new file mode 100644 index 0000000000000..382d7dac142fe --- /dev/null +++ b/orttraining/orttraining/python/training/utils/torch_profile_utils.py @@ -0,0 +1,28 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import torch + + +def torch_nvtx_range_push(msg): + if hasattr(torch.cuda.nvtx, "range_push"): + torch.cuda.nvtx.range_push(msg) + + +def torch_nvtx_range_pop(): + if hasattr(torch.cuda.nvtx, "range_pop"): + torch.cuda.nvtx.range_pop() + + +def nvtx_function_decorator(func): + """Function decorator to record the start and end of NVTX range.""" + + def wrapped_fn(*args, **kwargs): + torch_nvtx_range_push(func.__qualname__) + ret_val = func(*args, **kwargs) + torch_nvtx_range_pop() + return ret_val + + return wrapped_fn diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index 958c7d94c4241..bd4fce2cde144 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -1533,9 +1533,8 @@ def _run_step(model, input): import warnings - for index in range(10): - count = 0 - with warnings.catch_warnings(record=True) as w: + for _ in range(10): + with warnings.catch_warnings(record=True): input = torch.randn(output_size, device=device, dtype=torch.float) pt_prediction = _run_step(pt_model, input) ort_prediction = _run_step(ort_model, input) @@ -1543,16 +1542,6 @@ def _run_step(model, input): assert_values_are_close(ort_prediction, pt_prediction, rtol=1e-04, atol=1.0) assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-5) - for i in range(len(w)): - msg = str(w[i].message) - if "Add input index to _GlobalOpKernelInfoMap" in msg: - count += 1 - - if index == 0: - assert count == 2 - else: - assert count == 0 - class DupNamedFunction(torch.autograd.Function): @staticmethod diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc index 41f4a41a7c38a..3c5ac56cb139a 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc @@ -51,6 +51,9 @@ void PythonOpBase::Init(const OpKernelInfo& info) { ORT_THROW_IF_ERROR(info.GetAttr("func_name", &name_)); is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); + + safe_run_mode_enabled_ = static_cast(info.GetAttrOrDefault("safe_run_mode", static_cast(1))); + ORT_THROW_IF_ERROR(info.GetAttr("input_convention", &input_convention_)); input_requires_grads_ = info.GetAttrsOrDefault( @@ -144,7 +147,8 @@ void PythonOpBase::RunForward(OpKernelContext* context, // Invoke Python calls. TorchProxy::GetInstance().Forward( name_, - OrtTorchFunctionPool::GetInstance().GetForwardCore(name_), + safe_run_mode_enabled_ ? OrtTorchFunctionPool::GetInstance().GetForwardCore(name_) + : OrtTorchFunctionPool::GetInstance().GetUnsafeForwardCore(name_), input_requires_grads_, args, arg_positions_, @@ -153,6 +157,7 @@ void PythonOpBase::RunForward(OpKernelContext* context, is_training_mode_, all_output_to_tensor_input_reuse_map_, kernel_invoke_id_, + safe_run_mode_enabled_, diff_ctx, returned_ortvalues); @@ -301,7 +306,8 @@ void PythonOpBase::SetOtherOutputs(OpKernelContext* context, std::vector().DataRaw(); - const void* input_tensor_address = context->Input(all_output_to_tensor_input_reuse_map_[output_index])->DataRaw(); + const void* input_tensor_address = + context->Input(all_output_to_tensor_input_reuse_map_[output_index])->DataRaw(); ORT_ENFORCE(tensor_address == input_tensor_address, "PythonOp inplace tensor address mismatch, output index: ", output_index, ", input index: ", all_output_to_tensor_input_reuse_map_[output_index]); @@ -327,7 +333,7 @@ void PythonOpGradBase::Init(const OpKernelInfo& info) { output_tensor_requires_grads_ = info.GetAttrsOrDefault("output_tensor_requires_grads", std::vector()); ORT_ENFORCE(output_tensor_types_.size() == output_tensor_requires_grads_.size(), "backward tensor output count mismatch"); - + safe_run_mode_enabled_ = static_cast(info.GetAttrOrDefault("safe_run_mode", static_cast(1))); std::vector tensor_output_to_tensor_input_alias_map = info.GetAttrsOrDefault("tensor_reuse_map", std::vector((info.node().OutputDefs().size()), -1)); @@ -371,6 +377,7 @@ void PythonOpGradBase::RunBackward(OpKernelContext* context, const_arg_positions_, all_output_to_tensor_input_reuse_map_, kernel_invoke_id_, + safe_run_mode_enabled_, returned_ortvalues); OrtTorchFunctionPool::GetInstance().UnregisterContext(*context_index_ptr); diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h index d4a53a223abf1..4353859b56735 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h @@ -149,6 +149,8 @@ class PythonOpBase { // Output types of MyReLU.apply(...). std::vector output_tensor_types_; + bool safe_run_mode_enabled_{true}; + private: void AddPrimitiveTypeScalarArgs(); void AddInputTupleArgs(); @@ -193,6 +195,8 @@ class PythonOpGradBase { // Memory reuse map for all outputs. std::vector all_output_to_tensor_input_reuse_map_; + bool safe_run_mode_enabled_{true}; + private: void SetPositions(); diff --git a/setup.py b/setup.py index 44c97937ebe2a..0c2eb19e82c87 100644 --- a/setup.py +++ b/setup.py @@ -488,7 +488,7 @@ def finalize_options(self): ) package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"] - package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"] + package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc", "*.h"] package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"] package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops"] = [ "*.cpp", From fc9ecb59dbf6ac647bb1a70727a45e9267fefa90 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 15 Dec 2023 08:47:52 -0800 Subject: [PATCH 06/11] Add Windows ARM build jobs to post merge pipeline (#18832) ### Description Add Windows ARM build jobs to post merge pipeline to valid our code is still compatible with these build settings. --- .../azure-pipelines/post-merge-jobs.yml | 146 +++++++++++++++++- .../azure-pipelines/templates/win-ci.yml | 4 +- 2 files changed, 144 insertions(+), 6 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index e7138e628a52b..bdce0991d6b86 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -10,9 +10,13 @@ stages: UseWebPoolName: true WebCpuPoolName: 'Onnxruntime-Win-CPU-2022' -# This stage is to test if the combined build works on +# The follow section has 12 different build jobs that can be divided into 3 groups: +# 1. Default CPU build with normal win32 linking, without ORT extension +# 2. Default CPU build with wcos linking(use apiset), without ORT extension +# 3. Default CPU build with normal win32 linking with ORT extension +# Each group has 4 jobs that cover: # o Windows ARM64 -# o Windows ARM64EC +# o Windows ARM # o Windows x64 # o Windows x86 # Now we don't have coverage for ARM64EC yet. Will add it. @@ -24,12 +28,26 @@ stages: buildArch: x86 msbuildPlatform: Win32 packageName: x86 - buildparameter: --use_extensions --enable_onnx_tests + buildparameter: --enable_onnx_tests runTests: true buildJava: false buildNodejs: false ort_build_pool_name: 'onnxruntime-Win-CPU-2022' +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_arm_default + buildArch: x64 + msbuildPlatform: arm + packageName: arm + buildparameter: --arm --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe + runTests: false + buildJava: false + buildNodejs: false + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + - template: templates/win-ci.yml parameters: DoCompliance: false @@ -38,7 +56,7 @@ stages: buildArch: x64 msbuildPlatform: arm64 packageName: arm64 - buildparameter: --build_nodejs --arm64 --use_extensions --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe + buildparameter: --build_nodejs --arm64 --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe runTests: false buildJava: false buildNodejs: true @@ -52,6 +70,126 @@ stages: buildArch: x64 msbuildPlatform: x64 packageName: x64 + buildparameter: --build_java --build_nodejs --enable_onnx_tests + runTests: true + buildJava: true + buildNodejs: true + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_x86_wcos + artifact_name_suffix: '-wcos' + buildArch: x86 + msbuildPlatform: Win32 + packageName: x86 + buildparameter: --enable_onnx_tests --enable_wcos + runTests: true + buildJava: false + buildNodejs: false + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_arm_wcos + artifact_name_suffix: '-wcos' + buildArch: x64 + msbuildPlatform: arm + packageName: arm + buildparameter: --arm --enable_onnx_tests --enable_wcos --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe + runTests: false + buildJava: false + buildNodejs: false + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_arm64_wcos + artifact_name_suffix: '-wcos' + buildArch: x64 + msbuildPlatform: arm64 + packageName: arm64 + buildparameter: --build_nodejs --enable_wcos --arm64 --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe + runTests: false + buildJava: false + buildNodejs: true + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_x64_wcos + artifact_name_suffix: '-wcos' + buildArch: x64 + msbuildPlatform: x64 + packageName: x64 + buildparameter: --build_java --build_nodejs --enable_onnx_tests --enable_wcos + runTests: true + buildJava: true + buildNodejs: true + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_x86_extension + artifact_name_suffix: '-extension' + buildArch: x86 + msbuildPlatform: Win32 + packageName: x86 + buildparameter: --enable_onnx_tests + runTests: true + buildJava: false + buildNodejs: false + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_arm_extension + artifact_name_suffix: '-extension' + buildArch: x64 + msbuildPlatform: arm + packageName: arm + buildparameter: --arm --use_extensions --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe + runTests: false + buildJava: false + buildNodejs: false + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_arm64_extension + artifact_name_suffix: '-extension' + buildArch: x64 + msbuildPlatform: arm64 + packageName: arm64 + buildparameter: --build_nodejs --arm64 --use_extensions --enable_onnx_tests --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe + runTests: false + buildJava: false + buildNodejs: true + ort_build_pool_name: 'onnxruntime-Win-CPU-2022' + +- template: templates/win-ci.yml + parameters: + DoCompliance: false + DoEsrp: false + stage_name_suffix: CPU_x64_extension + artifact_name_suffix: '-extension' + buildArch: x64 + msbuildPlatform: x64 + packageName: x64 buildparameter: --build_java --build_nodejs --use_extensions --enable_onnx_tests runTests: true buildJava: true diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index fd5f61b82a5a8..89c481f267e64 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -193,7 +193,7 @@ stages: - template: nodejs-artifacts-package-and-publish-steps-windows.yml parameters: arch: ${{ parameters.packageName }} - artifactName: 'drop-onnxruntime-nodejs-win-${{ parameters.packageName }}' + artifactName: 'drop-onnxruntime-nodejs-win-${{ parameters.packageName }}${{ parameters.artifact_name_suffix }}' DoEsrp: ${{ parameters.DoEsrp }} #Upload protoc.exe, which will be used in nuget build for generating C# files @@ -260,7 +260,7 @@ stages: displayName: 'Publish Java temp binaries' inputs: pathtoPublish: '$(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}' - artifactName: 'drop-onnxruntime-java-win-${{ parameters.packageName }}' + artifactName: 'drop-onnxruntime-java-win-${{ parameters.packageName }}${{parameters.artifact_name_suffix}}' - ${{ if eq(parameters['DoCompliance'], 'true') }}: - task: CredScan@3 From d795fc636ce92c29d95d85cf0faf506baeadd46b Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 15 Dec 2023 08:48:15 -0800 Subject: [PATCH 07/11] FIX: Our cmake script didn't check googletest's hash (#18826) --- cmake/external/onnxruntime_external_deps.cmake | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 0fa5163dc06bf..78f63227c8392 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -47,8 +47,8 @@ if (onnxruntime_BUILD_UNIT_TESTS) FetchContent_Declare( googletest URL ${DEP_URL_googletest} - FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest URL_HASH SHA1=${DEP_SHA1_googletest} + FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest ) endif() @@ -124,7 +124,7 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) if(protoc_binary_SOURCE_DIR) message("Use prebuilt protoc") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") @@ -140,7 +140,7 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) if(protoc_binary_SOURCE_DIR) message("Use prebuilt protoc") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) - set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() elseif ((CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin") FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal}) @@ -281,7 +281,7 @@ if ((CPUINFO_SUPPORTED OR onnxruntime_USE_XNNPACK) AND NOT ANDROID) pytorch_clog URL ${DEP_URL_pytorch_cpuinfo} URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} - SOURCE_SUBDIR deps/clog + SOURCE_SUBDIR deps/clog ) set(ONNXRUNTIME_CLOG_PROJ pytorch_clog) set(ONNXRUNTIME_CLOG_TARGET_NAME clog) From d111eed726f6009bd9c4bf3355194a3b85aabb9f Mon Sep 17 00:00:00 2001 From: Peishen Yan Date: Sat, 16 Dec 2023 00:57:07 +0800 Subject: [PATCH 08/11] [WebNN EP] Change axis to axes for argMax/argMin (#18838) In the latest spec, the axes option of WebNN's argMax and argMin requires the use of a sequence long type. Replace axis option (long type) with axes (sequence long type) for argMax and argMin. --- .../providers/webnn/builders/impl/argmax_min_op_builder.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index 57a37d92335aa..5f8defe8fcb6b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -41,9 +41,11 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto select_last_index = helper.Get("select_last_index", 0); axis = HandleNegativeAxis(axis, input_rank); + emscripten::val axes = emscripten::val::array(); + axes.call("push", static_cast(axis)); emscripten::val options = emscripten::val::object(); - options.set("axis", static_cast(axis)); + options.set("axes", axes); options.set("keepDimensions", keep_dims == 1); options.set("selectLastIndex", select_last_index == 1); emscripten::val output = emscripten::val::object(); From 81ad1e6ac3149b928ccdaed9f76195a303613804 Mon Sep 17 00:00:00 2001 From: Yang Gu Date: Sat, 16 Dec 2023 00:57:48 +0800 Subject: [PATCH 09/11] [js/webgpu] Fix typo of outputShapes in profiling message (#18837) --- js/web/lib/wasm/jsep/webgpu/program-manager.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index adf0b1b2964b5..ae5bf68483b46 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -115,7 +115,7 @@ export class ProgramManager { inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; }); let outputShapes = ''; - inputTensorViews.forEach((value, i) => { + outputTensorViews.forEach((value, i) => { outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `; }); // eslint-disable-next-line no-console From 89168b830d663647c00fd74536aee52f0671f884 Mon Sep 17 00:00:00 2001 From: wirthual Date: Fri, 15 Dec 2023 09:14:02 -0800 Subject: [PATCH 10/11] Fix CI error: The workflow is not valid. .github/workflows/rust-ci.yml (Line: 27, Col: 7): Unexpected value 'ORT_RUST_STRATEGY=download' (#18836) Use colon for Env variable instead of = --- .github/workflows/rust-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index 6c3f2eb0fbbe1..725c40c2ded53 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -24,7 +24,7 @@ jobs: name: Download prebuilt ONNX Runtime archive from build.rs runs-on: ubuntu-latest env: - ORT_RUST_STRATEGY=download + ORT_RUST_STRATEGY: download steps: - uses: actions/checkout@v4 - uses: ./.github/actions/rust-toolchain-setup From f52668cc68efe80197227da192d9b970fa739132 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 15 Dec 2023 09:17:47 -0800 Subject: [PATCH 11/11] Disable mlas unit test in ARM64EC build (#18747) ### Description Disable mlas unit test in ARM64EC build because the program has some link errors. We will fix the errors later. This PR only impacts Windows ARM64EC build. It has no impact on the existing build pipelines. --- cmake/onnxruntime_unittests.cmake | 95 +++++++++++++++---------------- 1 file changed, 47 insertions(+), 48 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index df62199dc2b42..7c8c70f913dca 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1373,56 +1373,55 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(compare_two_sessions PRIVATE ${GETOPT_LIB_WIDE} tdh Advapi32) endif() - file(GLOB onnxruntime_mlas_test_src CONFIGURE_DEPENDS - "${TEST_SRC_DIR}/mlas/unittest/*.h" - "${TEST_SRC_DIR}/mlas/unittest/*.cpp" - ) - onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src}) - if(MSVC) - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26409>" - "$<$>:/wd26409>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" - "$<$>:/utf-8>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd6326>" - "$<$>:/wd6326>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" - "$<$>:/wd26426>") - endif() - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set_target_properties(onnxruntime_mlas_test PROPERTIES - XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" + if(NOT onnxruntime_target_platform STREQUAL "ARM64EC") + file(GLOB onnxruntime_mlas_test_src CONFIGURE_DEPENDS + "${TEST_SRC_DIR}/mlas/unittest/*.h" + "${TEST_SRC_DIR}/mlas/unittest/*.cpp" ) - endif() - target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT} - ${CMAKE_CURRENT_BINARY_DIR}) - target_link_libraries(onnxruntime_mlas_test PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) - if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo) - endif() - if(NOT WIN32) - target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS}) - endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") - target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs}) - endif() - - if(WIN32) - target_link_libraries(onnxruntime_mlas_test PRIVATE debug Dbghelp Advapi32) - endif() - if (onnxruntime_LINK_LIBATOMIC) - target_link_libraries(onnxruntime_mlas_test PRIVATE atomic) - endif() - target_link_libraries(onnxruntime_mlas_test PRIVATE Threads::Threads) - - set_target_properties(onnxruntime_mlas_test PROPERTIES FOLDER "ONNXRuntimeTest") - if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") - else() - set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") + onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src}) + if(MSVC) + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26409>" + "$<$>:/wd26409>") + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" + "$<$>:/utf-8>") + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd6326>" + "$<$>:/wd6326>") + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" + "$<$>:/wd26426>") endif() - endif() - + if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + set_target_properties(onnxruntime_mlas_test PROPERTIES + XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" + ) + endif() + target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT} + ${CMAKE_CURRENT_BINARY_DIR}) + target_link_libraries(onnxruntime_mlas_test PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) + if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo) + endif() + if(NOT WIN32) + target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS}) + endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") + target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs}) + endif() + if(WIN32) + target_link_libraries(onnxruntime_mlas_test PRIVATE debug Dbghelp Advapi32) + endif() + if (onnxruntime_LINK_LIBATOMIC) + target_link_libraries(onnxruntime_mlas_test PRIVATE atomic) + endif() + target_link_libraries(onnxruntime_mlas_test PRIVATE Threads::Threads) + set_target_properties(onnxruntime_mlas_test PROPERTIES FOLDER "ONNXRuntimeTest") + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") + else() + set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") + endif() + endif() +endif() # Training API Tests # Disabling training_api_test_trainer. CXXOPT generates a ton of warnings because of which nuget pipeline is failing. # TODO(askhade): Fix the warnings.