From 95307fe3d6d5cc5ee1df4741b0f958f3c598bdf8 Mon Sep 17 00:00:00 2001
From: pengwa
Date: Wed, 11 Oct 2023 12:36:45 +0800
Subject: [PATCH] Support inplace update for PythonOp/Grad (#17687)

### Support inplace update for PythonOp/Grad

This PR is built on the branch of https://github.com/microsoft/onnxruntime/pull/17685 to make it easier to review.

With https://github.com/microsoft/onnxruntime/pull/17685, all PythonOp inputs/outputs are by default assumed not to be updated in place. If an inplace update is detected during the run (by comparing each output's data address against all input data addresses), we clone the tensor before setting it as the PythonOp/Grad output. Results are correct in this case, but the implicit copies introduce overhead.

This PR allows users to define an output-to-input reuse map, so that ORT knows which outputs reuse which inputs and can avoid such unnecessary copies.

---
 cmake/onnxruntime_optimizer.cmake             |   3 +
 .../core/framework/allocation_planner.cc      |  32 ++++-
 .../python/tools/symbolic_shape_infer.py      |   4 +-
 .../torch/custom_function_register.cc         |  32 +++++
 .../torch/custom_function_register.h          |  13 ++
 .../core/framework/torch/torch_proxy.cc       |  50 ++++++++
 .../core/framework/torch/torch_proxy.h        |  31 +++++
 .../core/graph/gradient_builder.cc            |   8 ++
 .../core/graph/training_op_defs.cc            |  11 ++
 .../core/optimizer/graph_transformer_utils.cc |   6 +
 .../core/optimizer/pythonop_rewriter.cc       | 114 ++++++++++++++++++
 .../core/optimizer/pythonop_rewriter.h        |  36 ++++++
 .../python/orttraining_pybind_state.cc        |  38 +++++-
 .../_custom_autograd_function_exporter.py     |  79 ++++++------
 .../_custom_autograd_function_runner.py       |  38 ++++--
 .../ortmodule/_graph_execution_manager.py     |   6 +-
 .../ortmodule/_zero_stage3_compatibility.py   |  85 ++++++++++---
 .../utils/hooks/_statistics_subscriber.py     |   7 ++
 .../utils/hooks/_subscriber_manager.py        |  20 +++
 .../utils/hooks/_zero_offload_subscriber.py   |  52 +++++++-
 .../torch_custom_function_kernel_base.cc      |   3 +-
 21 files changed, 590 insertions(+), 78 deletions(-)
 create mode 100644 orttraining/orttraining/core/optimizer/pythonop_rewriter.cc
 create mode 100644 orttraining/orttraining/core/optimizer/pythonop_rewriter.h

diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake
index 3da4198573d54..baea52e84ace2 100644
--- a/cmake/onnxruntime_optimizer.cmake
+++ b/cmake/onnxruntime_optimizer.cmake
@@ -109,6 +109,9 @@ onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxr
 target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT})
 if (onnxruntime_ENABLE_TRAINING)
   target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT})
+  if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
+    onnxruntime_add_include_to_target(onnxruntime_optimizer Python::Module)
+  endif()
 endif()
 if (onnxruntime_ENABLE_TRITON)
   target_link_libraries(onnxruntime_optimizer PRIVATE nlohmann_json::nlohmann_json)
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 0bf27fdf5e5dc..9556e056dedc0 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -320,7 +320,7 @@ class PlannerImpl {
       return false;
     }

-    const auto& alias_map = ci.kernel_def->Alias();
+    const auto alias_map = GetAliasMap(node, ci);
     auto input_args = node.InputDefs();
     for (auto& pair : alias_map) {
       if (pair.second == output_arg_num) {
@@ -829,6 +829,34 @@ class PlannerImpl {
     return p_provider->GetOrtDeviceByMemType(utils::IsInputOnCpu(node, &kernel_create_info,
input_index) ? OrtMemTypeCPUInput : OrtMemTypeDefault);
   }

+  std::vector<std::pair<int, int>> GetAliasMap(const Node& node, const KernelCreateInfo& kernel_create_info) {
+    ORT_ENFORCE(kernel_create_info.kernel_def != nullptr, "KernelDef is null for node: ", node.Name());
+#ifdef ENABLE_TRAINING_TORCH_INTEROP
+    if ((node.OpType().compare("PythonOp") == 0 || node.OpType().compare("PythonOpGrad") == 0) &&
+        node.Domain() == kMSDomain) {
+      const auto& attrs = node.GetAttributes();
+      auto attr_it = attrs.find("tensor_reuse_map");
+      if (attr_it != attrs.end()) {
+        const auto& inplace_map = attr_it->second.ints();
+        std::vector<std::pair<int, int>> alias_map;
+        alias_map.reserve(inplace_map.size());
+        for (int i = 0; i < inplace_map.size(); ++i) {
+          int output_index = i;
+          int input_index = inplace_map[i];
+          if (input_index == -1) {
+            // skip because no reuse for this output
+            continue;
+          }
+          alias_map.emplace_back(std::make_pair(input_index, output_index));
+        }
+        return alias_map;
+      }
+    }
+#endif
+
+    return kernel_create_info.kernel_def->Alias();
+  }
+
   void GeneratePlanForWeightsHelper(const GraphViewer& graph_viewer,
                                     const InitializedTensorSet& weights,
                                     const KernelCreateInfoMap& kernel_create_info_map,
@@ -1084,7 +1112,7 @@ class PlannerImpl {
     }

     bool found_reusable = false;
-    const auto& alias_map = ci.kernel_def->Alias();
+    const auto alias_map = GetAliasMap(*node, ci);
     auto input_args = node->InputDefs();
     for (auto* input_arg : input_args) {
       OrtValueIndex input_idx_global{};
diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index 6b0674d3b1378..6d954bd540718 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -2385,10 +2385,10 @@ def _infer_PythonOp(self, node):  # noqa: N802
         output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
         assert output_tensor_ranks

-        from onnxruntime.training.ortmodule._custom_autograd_function_exporter import PythonOpShapeInferStore
+        from onnxruntime.capi._pybind_state import get_shape_inference_function

         func_name = get_attribute(node, "func_name").decode()
-        shape_inferer = PythonOpShapeInferStore.get_shape_infer(func_name)
+        shape_inferer = get_shape_inference_function(func_name)

         # Set the context output separately.
         # The first output is torch.autograd.Function's context.
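Before the C++ plumbing below, here is a minimal sketch of the user-facing contract this PR enables. The `InplaceRelu` class and its maps are hypothetical; the `infer_shape`/`alias_input` signatures follow the docstrings added in `_custom_autograd_function_exporter.py` and `torch_proxy.h` later in this patch, and both staticmethods are picked up automatically at export time by `register_custom_function_schema_supplementary`:

```python
from typing import List, Optional, Tuple, Union

import torch


class InplaceRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        y = x.relu_()  # in-place: the output shares x's buffer
        ctx.mark_dirty(x)  # tell autograd the input was mutated in place
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (y,) = ctx.saved_tensors
        return grad_output * (y > 0)

    @staticmethod
    def infer_shape(
        node,  # onnx.NodeProto
        tensor_input_shapes: List[Optional[List[Union[int, str]]]],
        tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
    ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
        # The single tensor output mirrors the single tensor input.
        return tensor_input_shapes, tensor_input_dtypes

    @staticmethod
    def alias_input(node_proto_str: str) -> Tuple[List[int], List[int]]:
        fw_alias_map = [0]   # forward output 0 reuses forward input 0
        bw_alias_map = [-1]  # one entry per forward input; x's grad is a fresh tensor
        return fw_alias_map, bw_alias_map
```

With `alias_input` registered, the allocation planner (see `GetAliasMap` above) can plan true buffer reuse instead of falling back to detect-and-clone at kernel run time.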
diff --git a/orttraining/orttraining/core/framework/torch/custom_function_register.cc b/orttraining/orttraining/core/framework/torch/custom_function_register.cc
index 2bf0be1d719c2..1a51da3daa27f 100644
--- a/orttraining/orttraining/core/framework/torch/custom_function_register.cc
+++ b/orttraining/orttraining/core/framework/torch/custom_function_register.cc
@@ -95,6 +95,16 @@ void OrtTorchFunctionPool::RegisterTorchAutogradFunction(
   RegisterEntry(mutex_, key, backward.get(), backward_core_pool_);
 }

+void OrtTorchFunctionPool::RegisterShapeInferenceFunction(const std::string& key,
+                                                          PyObject* obj) {
+  RegisterEntry(mutex_, key, obj, shape_inference_function_pool_);
+}
+
+void OrtTorchFunctionPool::RegisterInputAliasFunction(const std::string& key,
+                                                      PyObject* obj) {
+  RegisterEntry(mutex_, key, obj, input_alias_function_pool_);
+}
+
 static void RegisterEntry(
     std::mutex& mutex,
     PyObject* obj,
@@ -153,6 +163,26 @@ PyObject* OrtTorchFunctionPool::GetBackwardCore(const std::string& key) {
   return iter->second.get();
 }

+std::optional<PyObject*> OrtTorchFunctionPool::TryGettingShapeInferenceFunction(const std::string& key) {
+  ORT_ENFORCE(!key.empty(), "Cannot be empty string.");
+  std::lock_guard<std::mutex> lock(mutex_);
+  auto iter = shape_inference_function_pool_.find(key);
+  if (iter != shape_inference_function_pool_.end()) {
+    return iter->second.get();
+  }
+  return std::nullopt;
+}
+
+std::optional<PyObject*> OrtTorchFunctionPool::TryGettingInputAliasFunction(const std::string& key) {
+  ORT_ENFORCE(!key.empty(), "Cannot be empty string.");
+  std::lock_guard<std::mutex> lock(mutex_);
+  auto iter = input_alias_function_pool_.find(key);
+  if (iter != input_alias_function_pool_.end()) {
+    return iter->second.get();
+  }
+  return std::nullopt;
+}
+
 void OrtTorchFunctionPool::RegisterMiscellaneousConstInput(PyObject* obj) {
   ORT_ENFORCE(obj, "Cannot register NULL reference input.");
   const void* address = static_cast<const void*>(obj);
@@ -205,6 +235,8 @@ void OrtTorchFunctionPool::UnRegisterGlobalFunctions() {
 void OrtTorchFunctionPool::UnRegisterModelSpecificFunctions() {
   forward_core_pool_.clear();
   backward_core_pool_.clear();
+  shape_inference_function_pool_.clear();
+  input_alias_function_pool_.clear();
   miscellaneous_const_input_pool_.clear();
 }

diff --git a/orttraining/orttraining/core/framework/torch/custom_function_register.h b/orttraining/orttraining/core/framework/torch/custom_function_register.h
index 0dea6d036a6bd..d51cc7dadc1af 100644
--- a/orttraining/orttraining/core/framework/torch/custom_function_register.h
+++ b/orttraining/orttraining/core/framework/torch/custom_function_register.h
@@ -34,6 +34,16 @@ class OrtTorchFunctionPool final {
   // 2. Caller of GetBackwardCore should not decrease the reference count of the returned object.
   PyObject* GetBackwardCore(const std::string& key);  // The "key" is the "name" attribute in PythonOpGrad.

+  // Shape inference function is used to infer output shape of a PythonOp.
+  void RegisterShapeInferenceFunction(const std::string& key, PyObject* obj);
+  // Return a borrowed reference to the stored Python function, if it exists; otherwise, return std::nullopt.
+  std::optional<PyObject*> TryGettingShapeInferenceFunction(const std::string& key);
+
+  // Input alias function is used to infer memory reuse map of a PythonOp.
+  void RegisterInputAliasFunction(const std::string& key, PyObject* obj);
+  // Return a borrowed reference to the stored Python function, if it exists; otherwise, return std::nullopt.
+  std::optional<PyObject*> TryGettingInputAliasFunction(const std::string& key);
+
   // Autograd function may take input of "non-tensor && non int/float && non int/float tuple" types.
   // While PythonOp running requires those inputs be there otherwise kernel execution will fail.
   // So during model exporting, we need register those input with this API, then a ref cnt is increased by 1,
@@ -92,6 +102,9 @@ class OrtTorchFunctionPool final {
   std::unordered_map<std::string, PythonObjectPtr> forward_core_pool_;
   std::unordered_map<std::string, PythonObjectPtr> backward_core_pool_;
+  std::unordered_map<std::string, PythonObjectPtr> shape_inference_function_pool_;
+  std::unordered_map<std::string, PythonObjectPtr> input_alias_function_pool_;
+
   std::unordered_map miscellaneous_const_input_pool_;
   std::unordered_map func_context_pool_;

diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.cc b/orttraining/orttraining/core/framework/torch/torch_proxy.cc
index 58e22f4e266ee..f36f913366a37 100644
--- a/orttraining/orttraining/core/framework/torch/torch_proxy.cc
+++ b/orttraining/orttraining/core/framework/torch/torch_proxy.cc
@@ -372,4 +372,54 @@ void TorchProxy::Backward(
       returned_ortvalues);
 }

+void TorchProxy::RunInputAliasFunction(
+    void* input_alias_function,
+    const std::string& node_proto_str,
+    std::vector<int64_t>& fw_output_to_input_alias_map,
+    std::vector<int64_t>& bw_output_to_input_alias_map) {
+  PyObject* input_alias_func = reinterpret_cast<PyObject*>(input_alias_function);
+  ORT_ENFORCE(PyCallable_Check(input_alias_func), "input_alias_func is not callable.");
+
+  // All arguments created for Python call will be destroyed along with PythonObjectPtr.
+  PythonObjectPtr args(Ort_PyTuple_New(1, "input_alias_func_arguments_tuple"), PythonObjectDeleter);
+  PyObject* node_proto_ptr_arg = PyBytes_FromStringAndSize(node_proto_str.c_str(), node_proto_str.size());
+  Ort_PyTuple_SetItem_NoIncref(args.get(), 0, node_proto_ptr_arg, "node_proto_ptr_arg");
+
+  PythonObjectPtr result_ptr(PyObject_CallObject(input_alias_func, args.get()), PythonObjectDeleter);
+  if (PyErr_Occurred()) {
+    PyErr_Print();
+    ORT_THROW("Python function execution fails with the above information.");
+  }
+
+  bool is_tuple = PyTuple_Check(result_ptr.get());
+  bool is_list = PyList_Check(result_ptr.get());
+  ORT_ENFORCE(is_tuple || is_list, "Python function must return a tuple or a list. is_tuple: ",
+              is_tuple, ", is_list: ", is_list);
+  Py_ssize_t ret_tuple_size =
+      is_tuple ? PyTuple_Size(result_ptr.get()) : PyList_Size(result_ptr.get());
+  ORT_ENFORCE(ret_tuple_size == 2, "Input alias function must return a tuple/list of size 2.");
+
+  for (Py_ssize_t tuple_index = 0; tuple_index < ret_tuple_size; ++tuple_index) {
+    PyObject* alias_map = is_tuple ? PyTuple_GetItem(result_ptr.get(), tuple_index)
+                                   : PyList_GetItem(result_ptr.get(), tuple_index);
+
+    std::vector<int64_t>& output_to_input_alias_map =
+        tuple_index == 0 ? fw_output_to_input_alias_map : bw_output_to_input_alias_map;
+
+    bool is_elem_tuple = PyTuple_Check(alias_map);
+    bool is_elem_list = PyList_Check(alias_map);
+
+    ORT_ENFORCE(is_elem_tuple || is_elem_list, "Input alias map must be a tuple or a list. is_elem_list: ",
+                is_elem_list, ", is_elem_tuple: ", is_elem_tuple);
+    Py_ssize_t output_count = is_elem_tuple ? PyTuple_Size(alias_map) : PyList_Size(alias_map);
+    for (Py_ssize_t output_index = 0; output_index < output_count; ++output_index) {
+      PyObject* input_index =
+          is_elem_tuple ?
PyTuple_GetItem(alias_map, output_index) : PyList_GetItem(alias_map, output_index);
+      ORT_ENFORCE(PyLong_Check(input_index), "Alias input index must be an integer.");
+      int64_t alias_index_int = PyLong_AsLongLong(input_index);
+      output_to_input_alias_map.push_back(alias_index_int);
+    }
+  }
+}
+
 }  // namespace onnxruntime::language_interop_ops::torch
diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.h b/orttraining/orttraining/core/framework/torch/torch_proxy.h
index aeb02bab97eea..1d5cc1dd69095 100644
--- a/orttraining/orttraining/core/framework/torch/torch_proxy.h
+++ b/orttraining/orttraining/core/framework/torch/torch_proxy.h
@@ -2,8 +2,11 @@
 // Licensed under the MIT License.

 #pragma once
+
 #include
 #include
+#include
+#include

 #include "orttraining/core/framework/torch/python_common.h"

 #ifndef SHARED_PROVIDER
@@ -61,6 +64,34 @@ class TorchProxy {
       const std::string& invoke_id,
       std::vector<OrtValue>& return_args);

+  /**
+   * @brief Run the given function to get the output-to-input reuse map.
+   *
+   * @param input_alias_func Python function to run.
+   *   The function should take a serialized PythonOp NodeProto string as input, return a tuple of two lists.
+   *   The signature of the function should be:
+   *     def alias_input(node_proto_str: str):
+   *         fw_alias_map = [1, -1, -1]
+   *         bw_alias_map = [-1, 0]
+   *         return fw_alias_map, bw_alias_map
+   * @param node_proto_str The serialized PythonOp NodeProto string.
+   * @param fw_output_to_input_alias_map Used as the return value: the output-to-input alias map for the forward pass.
+   *   For example, if the inputs of the torch.autograd.Function are [non_tensor_a, tensor_b],
+   *   outputs are [tensor_x, tensor_y, tensor_z], and the alias map is [1, -1, -1], this is explained as:
+   *   tensor_x is reusing the input tensor_b, tensor_y and tensor_z are not reusing any input.
+   *   The value of the alias map is a 0-based input index. -1 means the output is not reusing any input.
+   * @param bw_output_to_input_alias_map Used as the return value: the output-to-input alias map for the backward pass.
+   *   For example, if the inputs of the torch.autograd.Function are [tensor_x_grad, None, None],
+   *   outputs are [None, tensor_b_grad], and the alias map is [-1, 0], this is explained as:
+   *   tensor_b_grad is reusing the input tensor_x_grad.
+   *   The value of the alias map is a 0-based grad input index. -1 means the output is not reusing any input.
+   */
+  void RunInputAliasFunction(
+      void* input_alias_func,
+      const std::string& node_proto_str,
+      std::vector<int64_t>& fw_output_to_input_alias_map,
+      std::vector<int64_t>& bw_output_to_input_alias_map);
+
  private:
   TorchProxy(){};
   ~TorchProxy(){};
diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index b3da4f3977ff2..133cab71f2b1c 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -1848,6 +1848,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetPythonOpGradient) {
              "PythonOpGrad requiring gradient output count mismatch.");
   attrs.push_back(MakeAttribute("output_tensor_requires_grads", bw_tensor_output_requires_grads));

+  // Copy bw_tensor_reuse_map attribute from PythonOp to PythonOpGrad if it is present.
+  auto attr_it = src_attrs.find("bw_tensor_reuse_map");
+  if (attr_it != src_attrs.end()) {
+    std::vector<int64_t> tensor_output_to_tensor_input_reuse_map(attr_it->second.ints().begin(),
+                                                                 attr_it->second.ints().end());
+    attrs.push_back(MakeAttribute("tensor_reuse_map", tensor_output_to_tensor_input_reuse_map));
+  }
+
   if (src_attrs.find("comment") != src_attrs.end() && utils::HasString(src_attrs.at("comment"))) {
     attrs.push_back(MakeAttribute("comment", src_attrs.at("comment").s()));
   }
diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc
index 5cd29303c3639..cfc79455c43ed 100644
--- a/orttraining/orttraining/core/graph/training_op_defs.cc
+++ b/orttraining/orttraining/core/graph/training_op_defs.cc
@@ -3918,6 +3918,17 @@ Return true if all elements are true and false otherwise.
           "- the output 2 reuses the input 0.",
           AttributeProto::INTS,
           false)
+      .Attr(
+          "bw_tensor_reuse_map",
+          "Used for backward op only."
+          "An int array indicating whether the output at each index is reusing a specific input or not."
+          "If the given index is -1, it means the output is not reusing any input."
+          "For example, there are 3 inputs (including ctx) and 2 outputs, tensor_reuse_map = [2, 1] means"
+          "- the output 0 reuses the input 2."
+          "- the output 1 reuses the input 1."
+          "Be noted: the input 0 is ctx.",
+          AttributeProto::INTS,
+          false)
       .Attr(
           "training_mode",
           "Indicate if the model is exported in training_mode, by default, False.",
diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
index 6b566ed064aa4..e5c65b2a96d8c 100644
--- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
+++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -68,6 +68,9 @@
 #include "core/optimizer/pre_shape_node_elimination.h"
 #include "orttraining/core/optimizer/compute_optimizer/padding_elimination.h"
 #include "orttraining/core/optimizer/compute_optimizer/sceloss_compute_optimization.h"
+#ifdef ENABLE_TRAINING_TORCH_INTEROP
+#include "orttraining/core/optimizer/pythonop_rewriter.h"
+#endif

 namespace onnxruntime {
 namespace training {
@@ -106,6 +109,9 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
       ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique()));
       ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique()));
       ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique()));
+#ifdef ENABLE_TRAINING_TORCH_INTEROP
+      ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique<PythonOpRewriter>()));
+#endif

       // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
       // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
diff --git a/orttraining/orttraining/core/optimizer/pythonop_rewriter.cc b/orttraining/orttraining/core/optimizer/pythonop_rewriter.cc
new file mode 100644
index 0000000000000..e1cd71958bed1
--- /dev/null
+++ b/orttraining/orttraining/core/optimizer/pythonop_rewriter.cc
@@ -0,0 +1,114 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifdef ENABLE_TRAINING_TORCH_INTEROP
+
+#include
+#include
+#include
+
+#include "orttraining/core/optimizer/pythonop_rewriter.h"
+
+#include "core/graph/graph.h"
+#include "core/graph/graph_utils.h"
+#include "orttraining/core/framework/torch/torch_proxy.h"
+#include "orttraining/core/framework/torch/custom_function_register.h"
+
+namespace onnxruntime {
+
+Status PythonOpRewriter::Apply(Graph&, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
+  bool modified = false;
+  if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "PythonOp", {1}, kMSDomain) &&
+      node.GetAttributes().find("tensor_reuse_map") == node.GetAttributes().end()) {
+    auto func_name = static_cast<std::string>(node.GetAttributes().at("func_name").s());
+    std::optional<PyObject*> input_alias_function =
+        language_interop_ops::torch::OrtTorchFunctionPool::GetInstance().TryGettingInputAliasFunction(func_name);
+    if (input_alias_function.has_value()) {
+      // Serialize node proto to string
+      ONNX_NAMESPACE::NodeProto node_proto;
+      node.ToProto(node_proto);
+      std::string node_proto_str;
+      node_proto.SerializeToString(&node_proto_str);
+
+      // Call input alias function
+      std::vector<int64_t> fw_all_output_to_tensor_input_reuse_map;
+      std::vector<int64_t> bw_all_output_to_tensor_input_reuse_map;
+      language_interop_ops::torch::TorchProxy::GetInstance().RunInputAliasFunction(
+          static_cast<void*>(input_alias_function.value()),
+          node_proto_str,
+          fw_all_output_to_tensor_input_reuse_map,
+          bw_all_output_to_tensor_input_reuse_map);
+
+      auto input_convention = static_cast<std::string>(node.GetAttributes().at("input_convention").s());
+      {
+        // Handle forward input alias map.
+        std::vector<int64_t> fw_tensor_output_to_tensor_input_reuse_map =
+            std::vector<int64_t>((node.OutputDefs().size()), -1);
+
+        // Map input index from `global` input index to `tensor` input index, because node.InputDefs() only contains
+        // tensor inputs.
+        std::unordered_map<size_t, int64_t> position_to_tensor_index;
+        int64_t tensor_index = 0;
+        const size_t all_input_count = input_convention.size();
+        position_to_tensor_index.reserve(all_input_count);
+        for (size_t i = 0; i < all_input_count; ++i) {
+          if (input_convention[i] == 'd') {
+            position_to_tensor_index[i] = tensor_index;
+            ++tensor_index;
+          }
+        }
+
+        for (size_t i = 1; i < fw_tensor_output_to_tensor_input_reuse_map.size(); ++i) {
+          if (fw_all_output_to_tensor_input_reuse_map[i - 1] != -1) {
+            ORT_ENFORCE(fw_all_output_to_tensor_input_reuse_map[i - 1] < static_cast<int64_t>(all_input_count),
+                        "PythonOp input alias function output index out of range. func_name: ", func_name, " ",
+                        fw_all_output_to_tensor_input_reuse_map[i - 1], " >= ", all_input_count);
+            fw_tensor_output_to_tensor_input_reuse_map[i] =
+                position_to_tensor_index.at(fw_all_output_to_tensor_input_reuse_map[i - 1]);
+          }
+        }
+
+        node.AddAttribute("tensor_reuse_map", fw_tensor_output_to_tensor_input_reuse_map);
+      }
+
+      {
+        // Handle backward input alias map.
+        auto& output_convention = input_convention;
+        ORT_ENFORCE(bw_all_output_to_tensor_input_reuse_map.size() == output_convention.size(),
+                    "PythonOpGrad input alias function output count mismatch.
func_name: ", func_name, " ", + bw_all_output_to_tensor_input_reuse_map.size(), " != ", output_convention.size()); + + std::vector bw_tensor_output_to_tensor_input_reuse_map = + std::vector(node.InputDefs().size(), -1); + size_t tensor_output_index = 0; + for (size_t i = 0; i < output_convention.size(); ++i) { + if (output_convention[i] == 'd') { + ORT_ENFORCE(tensor_output_index < bw_tensor_output_to_tensor_input_reuse_map.size(), + "PythonOpGrad input alias function output count mismatch. func_name: ", func_name, " ", + tensor_output_index, " >= ", bw_tensor_output_to_tensor_input_reuse_map.size()); + // input index shift by 1 to skip the context + bw_tensor_output_to_tensor_input_reuse_map[tensor_output_index] = + bw_all_output_to_tensor_input_reuse_map[i] == -1 ? -1 : bw_all_output_to_tensor_input_reuse_map[i] + 1; + ++tensor_output_index; + } + } + node.AddAttribute("bw_tensor_reuse_map", bw_tensor_output_to_tensor_input_reuse_map); + } + + modified = true; + } + } + + if (modified) + rule_effect = RewriteRuleEffect::kUpdatedCurrentNode; + + return Status::OK(); +} + +bool PythonOpRewriter::SatisfyCondition(const Graph&, const Node&, const logging::Logger&) const { + return true; +} + +} // namespace onnxruntime + +#endif diff --git a/orttraining/orttraining/core/optimizer/pythonop_rewriter.h b/orttraining/orttraining/core/optimizer/pythonop_rewriter.h new file mode 100644 index 0000000000000..5534b190979f0 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/pythonop_rewriter.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_TORCH_INTEROP + +#pragma once + +#include +#include +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** +This transformer is to add schema supplementary for PythonOp. + +Currently, add memory reuse output to input map as an attribute, if users registered alias input function +in `OrtTorchFunctionPool`. 
+*/
+
+class PythonOpRewriter : public RewriteRule {
+ public:
+  PythonOpRewriter() noexcept : RewriteRule("PythonOpRewriter") {}
+
+  std::vector<std::string> TargetOpTypes() const noexcept override {
+    return {"PythonOp"};
+  }
+
+ private:
+  bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
+
+  Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
+};
+
+}  // namespace onnxruntime
+#endif
diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc
index 35d9755ba0ba7..a08e8bee99cee 100644
--- a/orttraining/orttraining/python/orttraining_pybind_state.cc
+++ b/orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -533,12 +533,44 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
     ORT_UNUSED_PARAMETER(obj);
 #endif
   });
-  m.def("register_torch_autograd_function", [](std::string key, py::object obj) -> void {
+  m.def("register_torch_autograd_function", [](std::string function_full_qual_name, py::object obj) -> void {
 #ifdef ENABLE_TRAINING_TORCH_INTEROP
     auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
-    pool.RegisterTorchAutogradFunction(key, obj.ptr());
+    pool.RegisterTorchAutogradFunction(function_full_qual_name, obj.ptr());
 #else
-    ORT_UNUSED_PARAMETER(key);
+    ORT_UNUSED_PARAMETER(function_full_qual_name);
+    ORT_UNUSED_PARAMETER(obj);
+#endif
+  });
+  m.def("register_shape_inference_function", [](std::string function_full_qual_name, py::object obj) -> void {
+#ifdef ENABLE_TRAINING_TORCH_INTEROP
+    auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
+    pool.RegisterShapeInferenceFunction(function_full_qual_name, obj.ptr());
+#else
+    ORT_UNUSED_PARAMETER(function_full_qual_name);
+    ORT_UNUSED_PARAMETER(obj);
+#endif
+  });
+  m.def("get_shape_inference_function", [](std::string function_full_qual_name) -> py::object {
+#ifdef ENABLE_TRAINING_TORCH_INTEROP
+    auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
+    auto py_object = pool.TryGettingShapeInferenceFunction(function_full_qual_name);
+    if (py_object.has_value()) {
+      Py_INCREF(py_object.value());
+      return py::reinterpret_steal<py::object>(py_object.value());
+    }
+#else
+    ORT_UNUSED_PARAMETER(function_full_qual_name);
+#endif
+    return py::none();
+  });
+
+  m.def("register_input_alias_function", [](std::string function_full_qual_name, py::object obj) -> void {
+#ifdef ENABLE_TRAINING_TORCH_INTEROP
+    auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance();
+    pool.RegisterInputAliasFunction(function_full_qual_name, obj.ptr());
+#else
+    ORT_UNUSED_PARAMETER(function_full_qual_name);
     ORT_UNUSED_PARAMETER(obj);
 #endif
   });
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
index c6edaf7cd3a2c..8c5469740d9bd 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py
@@ -4,7 +4,6 @@
 # --------------------------------------------------------------------------

 import sys
-from typing import Callable, ClassVar, Dict, Optional

 import torch
 import torch.utils.checkpoint
 from packaging import version
 from
torch.onnx import symbolic_helper -from onnxruntime.capi._pybind_state import register_miscellaneous_const_input, register_torch_autograd_function +from onnxruntime.capi._pybind_state import ( + register_input_alias_function, + register_miscellaneous_const_input, + register_shape_inference_function, + register_torch_autograd_function, +) from onnxruntime.training import ortmodule from onnxruntime.training.utils import pytorch_dtype_to_onnx @@ -21,45 +25,46 @@ from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version -class PythonOpShapeInferStore: - """A class to store shape inference functions for torch.autograd.Function.""" +def register_custom_function_schema_supplementary(kclass: torch.autograd.Function) -> None: + """Register a shape inference function for a torch.autograd.Function if there is staticmethod + "infer_shape" defined. - _CLASS_MAP: ClassVar[Dict[str, Callable]] = {} + The signature of the shape inference function should be: + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + tensor_output_shapes = [] + tensor_output_dtypes = [] + ... + return tensor_output_shapes, tensor_output_dtypes - @classmethod - def register(cls, kclass: torch.autograd.Function) -> None: - """Register a shape inference function for a torch.autograd.Function if there is staticmethod - "infer_shape" defined. + The tensor_input_shapes and tensor_input_dtypes are lists of shapes and dtypes of the input tensors. + The tensor_output_shapes and tensor_output_dtypes are lists of shapes and dtypes of the output tensors. + Be noted: we only pass in tensor inputs, and return tensor outputs, non-tensor inputs/outputs are ignored. - The signature of the shape inference function should be: - @staticmethod - def infer_shape( - node: onnx.NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: - tensor_output_shapes = [] - tensor_output_dtypes = [] - ... - return tensor_output_shapes, tensor_output_dtypes - The tensor_input_shapes and tensor_input_dtypes are lists of shapes and dtypes of the input tensors. - The tensor_output_shapes and tensor_output_dtypes are lists of shapes and dtypes of the output tensors. - Be noted: we only pass in tensor inputs, and return tensor outputs, non-tensor inputs/outputs are ignored. + The signature of the alias input function should be: + @staticmethod + def alias_input(node_proto_str: str) -> Tuple[List[int], List[int]]: + fw_alias_map = [1, -1, -1] + bw_alias_map = [-1, 0] + return fw_alias_map, bw_alias_map - """ - kclass_name = get_fully_qualified_class_name(kclass) - if hasattr(kclass, "infer_shape") and kclass_name not in cls._CLASS_MAP: - cls._CLASS_MAP[kclass_name] = kclass.infer_shape + The alias input function should return a tuple of two lists: + - The first list is the forward alias map, its length is equal to the number of all outputs of the node. + - The second list is the backward alias map, its length is equal to the number of all inputs + (tensor and non-tensor) of the node. 
-    @classmethod
-    def register_func(cls, name: str, func: Callable) -> None:
-        """Register a shape inference function for a torch.autograd.Function by name."""
-        cls._CLASS_MAP[name] = func
+    """
+    kclass_name = get_fully_qualified_class_name(kclass)
+    if hasattr(kclass, "infer_shape"):
+        register_shape_inference_function(kclass_name, kclass.infer_shape)

-    @classmethod
-    def get_shape_infer(cls, name: str) -> Optional[Callable]:
-        return cls._CLASS_MAP.get(name, None)
+    if hasattr(kclass, "alias_input"):
+        register_input_alias_function(kclass_name, kclass.alias_input)

 """
@@ -299,8 +304,8 @@ def _export_pt_1_10(g, n, *args, **kwargs):
             # Register function with class names.
             register_torch_autograd_function(func_full_qual_name, func_class)

-            # Register shape inference function.
-            PythonOpShapeInferStore.register(func_class)
+            register_custom_function_schema_supplementary(func_class)
+
             return returned_args
     except Exception as e:
         sys.stdout.flush()
@@ -327,7 +332,7 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model
                     op_name_prefix = kclass_name
                     break

-            node.name = f"{op_name_prefix}_id_{index}"
+        node.name = f"{op_name_prefix}_id_{index}"
         index += 1

     return exported_model
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py
index fd791f21b4d22..b9318033a3d53 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py
@@ -76,10 +76,10 @@ def __init__(self, kernel_invoke_id: str):
 def _process_inplace_outputs(
     kernel_info: CustomFuncOpKernelInfo,
     func_name: str,
-    input_tensors_of_kernel_run: List[torch.Tensor],
+    input_tensors_of_kernel_run: Dict[int, Union[torch.Tensor, None]],
     all_outputs_of_kernel_run: List[Union[torch.Tensor, any]],
     all_outputs_to_tensor_inputs_reuse_map: List[int],
-    raw_input_tensors_used_inplace: Dict[int, torch.Tensor],
+    raw_input_tensors_used_inplace: Dict[int, Union[torch.Tensor, None]],
     is_backward=False,
 ):
     """Special handling for in-place reusing in forward or backward.
@@ -87,12 +87,12 @@ def _process_inplace_outputs(
     Args:
         kernel_info: kernel-specific information.
         func_name: name of the autograd.Function.
-        input_tensors_of_kernel_run: input tensors used to run the autograd.Function forward/backward.
+        input_tensors_of_kernel_run: all tensor inputs used to run the autograd.Function forward/backward.
         all_outputs_of_kernel_run: all outputs of the autograd.Function forward/backward.
         all_outputs_to_tensor_inputs_reuse_map: a list of the same length of kernel outputs, each element representing
             which input index it is reusing. If there is no reuse, the value is -1.
        raw_input_tensors_used_inplace: a dict of raw input tensors marked as inplace in
-            `all_outputs_to_tensor_inputs_reuse_map`, the key is the input index, value is the raw input tensor.
+            `all_outputs_to_tensor_inputs_reuse_map`, the key is the tensor input index, value is the raw input tensor.
         is_backward: indicates if this is backward or forward.
    Procedures:
@@ -127,7 +127,9 @@ def _process_inplace_outputs(
     """
     log_prefix = f"{func_name}->{'Backward' if is_backward else 'Forward'}: "
-    input_tensor_address_list = [t.data_ptr() for t in input_tensors_of_kernel_run]
+    input_tensor_address_list = [
+        t.data_ptr() if isinstance(t, torch.Tensor) else -1 for t in input_tensors_of_kernel_run.values()
+    ]
     if is_backward:
         input_tensor_address_list = [-1, *input_tensor_address_list]  # skip the context input
@@ -161,6 +163,14 @@ def _process_inplace_outputs(
             if inplace_index == detected_inplace_index:
                 continue

+            if (
+                inplace_index in raw_input_tensors_used_inplace
+                and raw_input_tensors_used_inplace[inplace_index] is None
+            ):
+                # Use specified inplace input index, but the input tensor is None, which means the input is not
+                # a tensor, so we don't do further checks.
+                continue
+
             # If users register inplace_map (alloc planner will do buffer reuse),
             # but detected inplace_map indicates it is NO inplace reusing, we raise an error.
             if inplace_index != -1 and detected_inplace_index == -1:
@@ -210,7 +220,8 @@ def _process_inplace_outputs(
         ):
             for raw_tensor_input_index, raw_input_tensor in raw_input_tensors_used_inplace.items():
                 # raw_input_tensor can be None for backward run, but backward won't go here.
-                assert isinstance(raw_input_tensor, torch.Tensor)
+                if not isinstance(raw_input_tensor, torch.Tensor):
+                    continue

                 # We did not do the check with tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty
                 # because even for those tensor indices not in
@@ -236,8 +247,8 @@ def _process_inplace_outputs(
                         # Only need a copy once.
                         raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index])
                         _log_warning(
-                            f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}."
-                            "Provide output to input reuse mapping to avoid the copy overhead."
+                            f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}. "
+                            f"{'Provide output to input reuse mapping to avoid the copy overhead.' if not is_first_time_init else ''}"
                         )
                         copied = True
@@ -531,7 +542,7 @@ def call_python_forward_function(
             _process_inplace_outputs(
                 kernel_info,
                 func_name,
-                input_tensors_used_for_fw_run.values(),
+                input_tensors_used_for_fw_run,
                 final_rets,
                 inplace_map,
                 raw_input_tensors_used_inplace,
@@ -624,6 +635,10 @@ def wrap_all_outputs(result):
                         wrapped_arg = torch.zeros(shape, device=device, dtype=dtype)
                     else:
                         wrapped_arg = arg
+
+                    if grad_input_index in inplace_map:
+                        raw_input_tensors_used_inplace[tensor_input_index] = arg
+
                 else:
                     # Assume it's a DLPack tensor and convert it to a PyTorch tensor.
                     wrapped_arg = from_dlpack(arg)
@@ -631,7 +646,8 @@ def wrap_all_outputs(result):
                     if grad_input_index in inplace_map:
                         raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg

-                input_tensors_used_for_bw_run[tensor_input_index] = wrapped_arg
+                # This may include None values.
+ input_tensors_used_for_bw_run[tensor_input_index] = wrapped_arg if wrapped_arg is not None: # Only requires gradient when running under training mode @@ -662,7 +678,7 @@ def wrap_all_outputs(result): _process_inplace_outputs( kernel_info, func_name, - input_tensors_used_for_bw_run.values(), + input_tensors_used_for_bw_run, result, inplace_map, raw_input_tensors_used_inplace, diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index dfaac5f0fa836..5e8805bfddb4d 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -143,9 +143,6 @@ def __init__( self._zero_stage3_param_map = {} if self._runtime_options.enable_zero_stage3_support: # Cannot toggle feature enabling/disabling after the first time enabled. - from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_zero_stage3_params - - self._zero_stage3_param_map = _get_all_zero_stage3_params(self._flattened_module) configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True) @@ -420,8 +417,11 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu exported_model = post_process_enabling_autograd_function(exported_model) if self._runtime_options.enable_zero_stage3_support: + from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_zero_stage3_params + from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat + self._zero_stage3_param_map = _get_all_zero_stage3_params(self._flattened_module) exported_model = post_processing_enable_zero_stage3_compat( exported_model, self._zero_stage3_param_map, diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index 301071f6de44c..d0dea66fda2a7 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -9,10 +9,14 @@ import torch from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper -from onnxruntime.capi._pybind_state import register_torch_autograd_function +from onnxruntime.capi._pybind_state import ( + register_input_alias_function, + register_shape_inference_function, + register_torch_autograd_function, +) from onnxruntime.training.utils import pytorch_dtype_to_onnx -from ._custom_autograd_function_exporter import PythonOpShapeInferStore +from ._custom_autograd_function_exporter import register_custom_function_schema_supplementary from ._utils import get_fully_qualified_class_name STAGE3_PULL_WEIGHT_TRIGGER_NAME = "pull_weight_trigger" @@ -36,6 +40,8 @@ def post_processing_enable_zero_stage3_compat( # Register symbolic shape inference functions for PythonOp used in DeepSpeed ZeRO stage3. _register_symbolic_shape_infer_functions() + _register_alias_input_functions() + # Create weight retrieving function using zero_stage3_named_params. 
func_full_qual_name = _create_weight_retrieval_function(zero_stage3_named_params) @@ -69,7 +75,7 @@ def _get_func_name(node: NodeProto) -> Optional[str]: from onnxruntime.training.utils.hooks._zero_offload_subscriber import ORTZeROOffloadPreForwardFunction - prefowrad_function_name = get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction) + pre_forward_function_name = get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction) # Connect weight consumers to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction. for graph_input in exported_model.graph.input: @@ -87,7 +93,7 @@ def _get_func_name(node: NodeProto) -> Optional[str]: continue func_name = _get_func_name(c) - if func_name == prefowrad_function_name: + if func_name == pre_forward_function_name: assert ( pre_forward_pythonop_node is None ), "Multiple ORTZeROOffloadPreForwardFunction nodes found, it should not happen" @@ -98,6 +104,7 @@ def _get_func_name(node: NodeProto) -> Optional[str]: "Fail to find ORTZeROOffloadPreForwardFunction for partitioned param: " + graph_input.name ) + pull_weight_trigger_input_name = _get_param_pull_trigger_name(graph_input.name) index_offset_on_python_op_input = [] for i, input_name in enumerate(pre_forward_pythonop_node.input): if input_name == graph_input.name: @@ -105,21 +112,32 @@ def _get_func_name(node: NodeProto) -> Optional[str]: assert ( len(index_offset_on_python_op_input) == 1 - ), f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input}" + ), f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input} for node {pre_forward_pythonop_node.name}, input {graph_input.name}, {pre_forward_pythonop_node.input}" reverse_index_among_inputs = index_offset_on_python_op_input[0] - len(pre_forward_pythonop_node.input) - new_input_name = _get_param_pull_trigger_name(graph_input.name) - pre_forward_pythonop_node.input[index_offset_on_python_op_input[0]] = new_input_name + + pre_forward_pythonop_node.input[index_offset_on_python_op_input[0]] = pull_weight_trigger_input_name _update_python_op_input_related_attributes( pre_forward_pythonop_node, - new_input_name, + pull_weight_trigger_input_name, len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE), # new rank STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, # new data type ) output_index = reverse_index_among_inputs + len(pre_forward_pythonop_node.output) - pre_forward_pythonop_node.output[output_index] = graph_input.name + + ready_weight_name = f"ready_{graph_input.name}" + pre_forward_pythonop_node.output[output_index] = ready_weight_name + + # Update consumer's input to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction. 
+        for c in consumers:
+            new_inputs = [c_input for c_input in c.input]
+            for c_input_index in range(len(c.input)):
+                if c.input[c_input_index] == graph_input.name:
+                    new_inputs[c_input_index] = ready_weight_name
+            del c.input[:]
+            c.input.extend(new_inputs)

         # If the consumer of original `graph_input.name` is PythonOp, we need also update its attributes because now
         # `graph_input.name` as output of pre_forward_pythonop_node, is full-sized parameter, the rank might differ
@@ -186,11 +204,13 @@ def infer_shape(
             tensor_output_dtypes = [
                 tensor_input_dtypes[0],
             ] * param_count
+
             return tensor_output_shapes, tensor_output_dtypes

     func_full_qual_name = get_fully_qualified_class_name(WeightRetrievalFunction)
     register_torch_autograd_function(func_full_qual_name, WeightRetrievalFunction)
-    PythonOpShapeInferStore.register(WeightRetrievalFunction)
+
+    register_custom_function_schema_supplementary(WeightRetrievalFunction)

     return func_full_qual_name

@@ -206,10 +226,10 @@ def _simple_pass_through_infer_shape(
     ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
         return tensor_input_shapes, tensor_input_dtypes

-    PythonOpShapeInferStore.register_func(
+    register_shape_inference_function(
         "deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _simple_pass_through_infer_shape
     )
-    PythonOpShapeInferStore.register_func(
+    register_shape_inference_function(
         "deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _simple_pass_through_infer_shape
     )

@@ -225,9 +245,36 @@ def _linear_infer_shape(
         output_shape[-1] = shape2[-2]
         return [output_shape], [tensor_input_dtypes[0]]

-    PythonOpShapeInferStore.register_func(
-        "deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3", _linear_infer_shape
-    )
+    register_shape_inference_function("deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3", _linear_infer_shape)
+
+
+def _register_alias_input_functions():
+    """This function is used to register input alias functions for PythonOps used in
+    DeepSpeed ZeRO stage3."""
+
+    def _alias_input(node_proto_str: str):
+        node: NodeProto = NodeProto()
+        node.ParseFromString(node_proto_str)
+        non_tensor_fw_input_count = 2
+
+        fw_output_count = len(node.output) - 1  # exclude the first output appended in ONNX
+        fw_alias_map = [-1] * fw_output_count
+        bw_alias_map = [-1] * (non_tensor_fw_input_count + len(node.input))
+
+        for i in range(fw_output_count):
+            fw_alias_map[i] = i + non_tensor_fw_input_count
+
+        tensor_input_index = 0
+        for i in range(len(bw_alias_map)):
+            if i < non_tensor_fw_input_count:
+                continue
+            bw_alias_map[i] = tensor_input_index
+            tensor_input_index += 1
+
+        return fw_alias_map, bw_alias_map
+
+    register_input_alias_function("deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _alias_input)
+    register_input_alias_function("deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _alias_input)


 def _create_weight_retrieval_pythonop(
@@ -276,7 +323,9 @@ def _create_weight_retrieval_pythonop(
     return new_input, weight_pull_node


-def _update_python_op_input_related_attributes(node: NodeProto, input_name: str, new_rank: int, new_dtype: int):
+def _update_python_op_input_related_attributes(
+    node: NodeProto, input_name: str, new_rank: int, new_dtype: torch.onnx.TensorProtoDataType
+):
     """This function is used to update PythonOp's input related attributes, e.g.
        input_tensor_ranks and input_tensor_types.
@@ -284,7 +333,7 @@ def _update_python_op_input_related_attributes(node: NodeProto, input_name: str, node (NodeProto): The PythonOp node. input_name (str): The input name to be updated. new_rank (int): The new rank of the input, to be used in input_tensor_ranks. - new_dtype (int): The new data type of the input, to be used in input_tensor_types. + new_dtype (torch.onnx.TensorProtoDataType): The new data type of the input, to be used in input_tensor_types. """ input_tensor_ranks = None input_tensor_dtypes = None @@ -304,7 +353,7 @@ def _update_python_op_input_related_attributes(node: NodeProto, input_name: str, for index, node_input_name in enumerate(node.input): if node_input_name == input_name: input_tensor_ranks[index] = new_rank - input_tensor_dtypes[index] = new_dtype + input_tensor_dtypes[index] = int(new_dtype) node.attribute.remove(rank_attr) node.attribute.remove(dtype_attr) diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py index db1c69cf95ba4..c5be17236ac06 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py @@ -93,6 +93,13 @@ def infer_shape( ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes + @staticmethod + def alias_input(node_proto_str: str): + fw_alias_map = [3] + bw_alias_map = [-1] * 6 + bw_alias_map[3] = 0 + return fw_alias_map, bw_alias_map + class StatisticsSubscriber(SubscriberBase): """ diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py index b2bc64be42fc1..c9c06dabab4de 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py @@ -68,6 +68,26 @@ def infer_shape( ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes + @staticmethod + def alias_input(node_proto_str: str): + node = onnx.NodeProto() + node.ParseFromString(node_proto_str) + non_tensor_fw_input_count = 1 + fw_output_count = len(node.output) - 1 # exclude the first output appended in ONNX + fw_alias_map = [-1] * fw_output_count + bw_alias_map = [-1] * (non_tensor_fw_input_count + len(node.input)) + + for i in range(fw_output_count): + fw_alias_map[i] = i + non_tensor_fw_input_count + + tensor_input_index = 0 + for i in range(len(bw_alias_map)): + if i < non_tensor_fw_input_count: + continue + bw_alias_map[i] = tensor_input_index + tensor_input_index += 1 + return fw_alias_map, bw_alias_map + class SubscriberManager: """This class is used to manage all the subscribers and register subscribers' custom actions as PyTorch hooks diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index 2b01547a675ec..b1cb5c19e8979 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -324,12 +324,41 @@ def infer_shape( start_offset = len(tensor_input_shapes) - len(partitioned_params) for index, param in enumerate(partitioned_params): 
            tensor_output_shapes[start_offset + index] = list(param.ds_shape)
-            tensor_output_dtypes[start_offset + index] = pytorch_dtype_to_onnx(param.dtype)
+            tensor_output_dtypes[start_offset + index] = int(pytorch_dtype_to_onnx(param.dtype))

         assert len(tensor_output_shapes) == len(tensor_input_shapes)
         assert len(tensor_output_dtypes) == len(tensor_input_dtypes)

         return tensor_output_shapes, tensor_output_dtypes

+    @staticmethod
+    def alias_input(node_proto_str: str):
+        node = onnx.NodeProto()
+        node.ParseFromString(node_proto_str)
+        input_pointer_scalars_attr_name = "input_pointer_scalars"
+        found = [attr for attr in node.attribute if attr.name == input_pointer_scalars_attr_name]
+        assert len(found) == 1
+        input_pointer_scalars = found[0].ints
+        # Restore the nn.Module from the pointer.
+        module = ctypes.cast(input_pointer_scalars[0], ctypes.py_object).value
+        partitioned_params = _get_params_for_current_module(module)
+
+        non_tensor_fw_input_count = 6
+        fw_output_count = len(node.output) - 1  # exclude the first output appended in ONNX
+        fw_alias_map = [-1] * fw_output_count
+        bw_alias_map = [-1] * (non_tensor_fw_input_count + len(node.input))
+
+        for i in range(fw_output_count - len(partitioned_params)):
+            fw_alias_map[i] = i + non_tensor_fw_input_count
+
+        tensor_input_index = 0
+        for i in range(len(bw_alias_map) - len(partitioned_params)):
+            if i < non_tensor_fw_input_count:
+                continue
+            bw_alias_map[i] = tensor_input_index
+            tensor_input_index += 1
+
+        return fw_alias_map, bw_alias_map
+

 class ORTZeROOffloadPostForwardFunction(torch.autograd.Function):
     @staticmethod
@@ -384,6 +413,27 @@ def infer_shape(
     ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
         return tensor_input_shapes, tensor_input_dtypes

+    @staticmethod
+    def alias_input(node_proto_str: str):
+        node = onnx.NodeProto()
+        node.ParseFromString(node_proto_str)
+        non_tensor_fw_input_count = 4
+        fw_output_count = len(node.output) - 1  # exclude the first output appended in ONNX
+        fw_alias_map = [-1] * fw_output_count
+        bw_alias_map = [-1] * (non_tensor_fw_input_count + len(node.input))
+
+        for i in range(fw_output_count):
+            fw_alias_map[i] = i + non_tensor_fw_input_count
+
+        tensor_input_index = 0
+        for i in range(len(bw_alias_map)):
+            if i < non_tensor_fw_input_count:
+                continue
+            bw_alias_map[i] = tensor_input_index
+            tensor_input_index += 1
+
+        return fw_alias_map, bw_alias_map
+

 class _ZeROOffloadFunctions:
     def __init__(self, one_time_init: _ZeROOffloadOneTimeInitializer, offloader) -> None:
diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc
index a31fa5d850e59..41f4a41a7c38a 100644
--- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc
+++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc
@@ -294,7 +294,8 @@ void PythonOpBase::SetContextOutput(OpKernelContext* context, void* diff_ctx) co
 void PythonOpBase::SetOtherOutputs(OpKernelContext* context, std::vector<OrtValue>& returned_ortvalues) const {
   auto* ctx_internal = reinterpret_cast<OpKernelContextInternal*>(context);
-  ORT_ENFORCE(returned_ortvalues.size() == all_output_to_tensor_input_reuse_map_.size() - 1, "PythonOp output count mismatch inplace map count.",
+  ORT_ENFORCE(returned_ortvalues.size() == all_output_to_tensor_input_reuse_map_.size() - 1,
+              "PythonOp output count mismatch inplace map count.",
               returned_ortvalues.size(), " != ",
all_output_to_tensor_input_reuse_map_.size() - 1); for (size_t i = 0; i < returned_ortvalues.size(); ++i) { size_t output_index = i + 1;
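To close the loop on the index bookkeeping done in `pythonop_rewriter.cc`, here is a small standalone sketch (hypothetical `translate_fw_alias_map` helper) of how a user-level forward alias map becomes the node-level `tensor_reuse_map` attribute, assuming, per the rewriter above, that `input_convention` marks tensor inputs with 'd' and that node output 0 is the autograd context:

```python
from typing import List


def translate_fw_alias_map(input_convention: str, fw_alias_map: List[int]) -> List[int]:
    # Map a global input position to its tensor-input index ('d' marks a tensor input,
    # since node.InputDefs() only contains tensor inputs).
    position_to_tensor_index = {}
    tensor_index = 0
    for position, kind in enumerate(input_convention):
        if kind == "d":
            position_to_tensor_index[position] = tensor_index
            tensor_index += 1

    # Node output 0 is the context and never reuses an input; user output i
    # becomes node output i + 1. -1 still means "no reuse".
    tensor_reuse_map = [-1]
    for global_input_index in fw_alias_map:
        if global_input_index == -1:
            tensor_reuse_map.append(-1)
        else:
            tensor_reuse_map.append(position_to_tensor_index[global_input_index])
    return tensor_reuse_map


# Inputs [non_tensor_a, tensor_b] -> convention "cd"; the single user output
# reuses global input 1 (tensor_b), i.e. tensor input 0.
print(translate_fw_alias_map("cd", [1]))  # [-1, 0]
```

The backward map takes the simpler path shown in the rewriter: one entry per forward input, with each non -1 grad-input index shifted by 1 to skip the context input of PythonOpGrad.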