diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index 3da4198573d54..baea52e84ace2 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -109,6 +109,9 @@ onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxr target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT}) if (onnxruntime_ENABLE_TRAINING) target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT}) + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + onnxruntime_add_include_to_target(onnxruntime_optimizer Python::Module) + endif() endif() if (onnxruntime_ENABLE_TRITON) target_link_libraries(onnxruntime_optimizer PRIVATE nlohmann_json::nlohmann_json) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 0bf27fdf5e5dc..9556e056dedc0 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -320,7 +320,7 @@ class PlannerImpl { return false; } - const auto& alias_map = ci.kernel_def->Alias(); + const auto alias_map = GetAliasMap(node, ci); auto input_args = node.InputDefs(); for (auto& pair : alias_map) { if (pair.second == output_arg_num) { @@ -829,6 +829,34 @@ class PlannerImpl { return p_provider->GetOrtDeviceByMemType(utils::IsInputOnCpu(node, &kernel_create_info, input_index) ? OrtMemTypeCPUInput : OrtMemTypeDefault); } + std::vector> GetAliasMap(const Node& node, const KernelCreateInfo& kernel_create_info) { + ORT_ENFORCE(kernel_create_info.kernel_def != nullptr, "KernelDef is null for node: ", node.Name()); +#ifdef ENABLE_TRAINING_TORCH_INTEROP + if ((node.OpType().compare("PythonOp") == 0 || node.OpType().compare("PythonOpGrad") == 0) && + node.Domain() == kMSDomain) { + const auto& attrs = node.GetAttributes(); + auto attr_it = attrs.find("tensor_reuse_map"); + if (attr_it != attrs.end()) { + const auto& inplace_map = attr_it->second.ints(); + std::vector> alias_map; + alias_map.reserve(inplace_map.size()); + for (int i = 0; i < inplace_map.size(); ++i) { + int output_index = i; + int input_index = inplace_map[i]; + if (input_index == -1) { + // skip because no reuse for this output + continue; + } + alias_map.emplace_back(std::make_pair(input_index, output_index)); + } + return alias_map; + } + } +#endif + + return kernel_create_info.kernel_def->Alias(); + } + void GeneratePlanForWeightsHelper(const GraphViewer& graph_viewer, const InitializedTensorSet& weights, const KernelCreateInfoMap& kernel_create_info_map, @@ -1084,7 +1112,7 @@ class PlannerImpl { } bool found_reusable = false; - const auto& alias_map = ci.kernel_def->Alias(); + const auto alias_map = GetAliasMap(*node, ci); auto input_args = node->InputDefs(); for (auto* input_arg : input_args) { OrtValueIndex input_idx_global{}; diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 6b0674d3b1378..6d954bd540718 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -2385,10 +2385,10 @@ def _infer_PythonOp(self, node): # noqa: N802 output_tensor_ranks = get_attribute(node, "output_tensor_ranks") assert output_tensor_ranks - from onnxruntime.training.ortmodule._custom_autograd_function_exporter import PythonOpShapeInferStore + from onnxruntime.capi._pybind_state import get_shape_inference_function func_name = get_attribute(node, "func_name").decode() - shape_inferer = PythonOpShapeInferStore.get_shape_infer(func_name) + shape_inferer = get_shape_inference_function(func_name) # Set the context output separately. # The first output is torch.autograd.Function''s context. diff --git a/orttraining/orttraining/core/framework/torch/custom_function_register.cc b/orttraining/orttraining/core/framework/torch/custom_function_register.cc index 2bf0be1d719c2..1a51da3daa27f 100644 --- a/orttraining/orttraining/core/framework/torch/custom_function_register.cc +++ b/orttraining/orttraining/core/framework/torch/custom_function_register.cc @@ -95,6 +95,16 @@ void OrtTorchFunctionPool::RegisterTorchAutogradFunction( RegisterEntry(mutex_, key, backward.get(), backward_core_pool_); } +void OrtTorchFunctionPool::RegisterShapeInferenceFunction(const std::string& key, + PyObject* obj) { + RegisterEntry(mutex_, key, obj, shape_inference_function_pool_); +} + +void OrtTorchFunctionPool::RegisterInputAliasFunction(const std::string& key, + PyObject* obj) { + RegisterEntry(mutex_, key, obj, input_alias_function_pool_); +} + static void RegisterEntry( std::mutex& mutex, PyObject* obj, @@ -153,6 +163,26 @@ PyObject* OrtTorchFunctionPool::GetBackwardCore(const std::string& key) { return iter->second.get(); } +std::optional OrtTorchFunctionPool::TryGettingShapeInferenceFunction(const std::string& key) { + ORT_ENFORCE(!key.empty(), "Cannot be empty string."); + std::lock_guard lock(mutex_); + auto iter = shape_inference_function_pool_.find(key); + if (iter != shape_inference_function_pool_.end()) { + return iter->second.get(); + } + return std::nullopt; +} + +std::optional OrtTorchFunctionPool::TryGettingInputAliasFunction(const std::string& key) { + ORT_ENFORCE(!key.empty(), "Cannot be empty string."); + std::lock_guard lock(mutex_); + auto iter = input_alias_function_pool_.find(key); + if (iter != input_alias_function_pool_.end()) { + return iter->second.get(); + } + return std::nullopt; +} + void OrtTorchFunctionPool::RegisterMiscellaneousConstInput(PyObject* obj) { ORT_ENFORCE(obj, "Cannot register NULL reference input."); const void* address = static_cast(obj); @@ -205,6 +235,8 @@ void OrtTorchFunctionPool::UnRegisterGlobalFunctions() { void OrtTorchFunctionPool::UnRegisterModelSpecificFunctions() { forward_core_pool_.clear(); backward_core_pool_.clear(); + shape_inference_function_pool_.clear(); + input_alias_function_pool_.clear(); miscellaneous_const_input_pool_.clear(); } diff --git a/orttraining/orttraining/core/framework/torch/custom_function_register.h b/orttraining/orttraining/core/framework/torch/custom_function_register.h index 0dea6d036a6bd..d51cc7dadc1af 100644 --- a/orttraining/orttraining/core/framework/torch/custom_function_register.h +++ b/orttraining/orttraining/core/framework/torch/custom_function_register.h @@ -34,6 +34,16 @@ class OrtTorchFunctionPool final { // 2. Caller of GetBackwardCore should not decrease the reference count of the returned object. PyObject* GetBackwardCore(const std::string& key); // The "key" is the "name" attribute in PythonOpGrad. + // Shape inference function is used to infer output shape of a PythonOp. + void RegisterShapeInferenceFunction(const std::string& key, PyObject* obj); + // Return a borrowed reference to the stored Python function, if it exists; otherwise, return nullptr. + std::optional TryGettingShapeInferenceFunction(const std::string& key); + + // Input alias function is used to infer memory reuse map of a PythonOp. + void RegisterInputAliasFunction(const std::string& key, PyObject* obj); + // Return a borrowed reference to the stored Python function, if it exists; otherwise, return nullptr. + std::optional TryGettingInputAliasFunction(const std::string& key); + // Autograd function may take input of "non-tensor && non int/float && non int/float tuple" types. // While PythonOp running requires those inputs be there otherwise kernel execution will fail. // So during model exporting, we need register those input with this API, then a ref cnt is increased by 1, @@ -92,6 +102,9 @@ class OrtTorchFunctionPool final { std::unordered_map forward_core_pool_; std::unordered_map backward_core_pool_; + std::unordered_map shape_inference_function_pool_; + std::unordered_map input_alias_function_pool_; + std::unordered_map miscellaneous_const_input_pool_; std::unordered_map func_context_pool_; diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.cc b/orttraining/orttraining/core/framework/torch/torch_proxy.cc index 58e22f4e266ee..f36f913366a37 100644 --- a/orttraining/orttraining/core/framework/torch/torch_proxy.cc +++ b/orttraining/orttraining/core/framework/torch/torch_proxy.cc @@ -372,4 +372,54 @@ void TorchProxy::Backward( returned_ortvalues); } +void TorchProxy::RunInputAliasFunction( + void* input_alias_function, + const std::string& node_proto_str, + std::vector& fw_output_to_input_alias_map, + std::vector& bw_output_to_input_alias_map) { + PyObject* input_alias_func = reinterpret_cast(input_alias_function); + ORT_ENFORCE(PyCallable_Check(input_alias_func), "input_alias_func is not callable."); + + // All arguments created for Python call will be destroyed along with PythonObjectPtr. + PythonObjectPtr args(Ort_PyTuple_New(1, "input_alias_func_arguments_tuple"), PythonObjectDeleter); + PyObject* node_proto_ptr_arg = PyBytes_FromStringAndSize(node_proto_str.c_str(), node_proto_str.size()); + Ort_PyTuple_SetItem_NoIncref(args.get(), 0, node_proto_ptr_arg, "node_proto_ptr_arg"); + + PythonObjectPtr result_ptr(PyObject_CallObject(input_alias_func, args.get()), PythonObjectDeleter); + if (PyErr_Occurred()) { + PyErr_Print(); + ORT_THROW("Python function execution fails with the above information."); + } + + bool is_tuple = PyTuple_Check(result_ptr.get()); + bool is_list = PyList_Check(result_ptr.get()); + ORT_ENFORCE(is_tuple || is_list, "Python function must return a tuple or a list. is_tuple: ", + is_tuple, ", is_list: ", is_list); + Py_ssize_t ret_tuple_size = + is_tuple ? PyTuple_Size(result_ptr.get()) : PyList_Size(result_ptr.get()); + ORT_ENFORCE(ret_tuple_size == 2, "Input alias function must return a tuple/list of size 2."); + + for (Py_ssize_t tuple_index = 0; tuple_index < ret_tuple_size; ++tuple_index) { + PyObject* alias_map = is_tuple ? PyTuple_GetItem(result_ptr.get(), tuple_index) + : PyList_GetItem(result_ptr.get(), tuple_index); + + std::vector& output_to_input_alias_map = + tuple_index == 0 ? fw_output_to_input_alias_map : bw_output_to_input_alias_map; + + bool is_elem_tuple = PyTuple_Check(alias_map); + bool is_elem_list = PyList_Check(alias_map); + + ORT_ENFORCE(is_elem_tuple || is_elem_list, "Input alias map must be a tuple or a list. is_elem_list: ", + is_elem_list, ", is_elem_tuple: ", is_elem_tuple); + Py_ssize_t output_count = is_elem_tuple ? PyTuple_Size(alias_map) : PyList_Size(alias_map); + for (Py_ssize_t output_index = 0; output_index < output_count; ++output_index) { + PyObject* input_index = + is_elem_tuple ? PyTuple_GetItem(alias_map, output_index) : PyList_GetItem(alias_map, output_index); + ORT_ENFORCE(PyLong_Check(input_index), "Alias input index must be an integer."); + int64_t alias_index_int = PyLong_AsLongLong(input_index); + output_to_input_alias_map.push_back(alias_index_int); + } + } +} + } // namespace onnxruntime::language_interop_ops::torch diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.h b/orttraining/orttraining/core/framework/torch/torch_proxy.h index aeb02bab97eea..1d5cc1dd69095 100644 --- a/orttraining/orttraining/core/framework/torch/torch_proxy.h +++ b/orttraining/orttraining/core/framework/torch/torch_proxy.h @@ -2,8 +2,11 @@ // Licensed under the MIT License. #pragma once + #include #include +#include +#include #include "orttraining/core/framework/torch/python_common.h" #ifndef SHARED_PROVIDER @@ -61,6 +64,34 @@ class TorchProxy { const std::string& invoke_id, std::vector& return_args); + /** + * @brief Run given function to get output to input reuse map. + * + * @param input_alias_func Python function to run. + * The function should take a serialized PythonOp NodeProto string as input, return a tuple of two lists. + * The signature of the function should be: + * def alias_input(node_proto_str: str): + * fw_alias_map = [1, -1, -1] + * bw_alias_map = [-1, 0] + * return fw_alias_map, bw_alias_map + * @param node_proto_str The serialized PythonOp NodeProto string. + * @param fw_output_to_input_alias_map Used as returned value, return the output to input alias map for forward pass. + * For example, if the inputs of the torch.autograd.Function are [non_tensor_a, tensor_b], + * outputs are [tensor_x, tensor_y, tensor_z], and the alias map is [1, -1, -1], this is explained as: + * tensor_x is reusing the input tensor_b, tensor_y and tensor_z are not reusing any input. + * The value of alias map is 0 based input index. -1 means the output is not reusing any input. + * @param bw_output_to_input_alias_map Used as returned value, return the output to input alias map for backward pass. + * For example, if the inputs of the torch.autograd.Function are [tensor_x_grad, None, None], + * outputs are [None, tensor_b_grad], and the alias map is [-1, 0], this is explained as: + * tensor_b_grad is reusing the input tensor_x_grad. + * The value of alias map is 0 based grad input index. -1 means the output is not reusing any input. + */ + void RunInputAliasFunction( + void* input_alias_func, + const std::string& node_proto_str, + std::vector& fw_output_to_input_alias_map, + std::vector& bw_output_to_input_alias_map); + private: TorchProxy(){}; ~TorchProxy(){}; diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index b3da4f3977ff2..133cab71f2b1c 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1848,6 +1848,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetPythonOpGradient) { "PythonOpGrad requiring gradient output count mismatch."); attrs.push_back(MakeAttribute("output_tensor_requires_grads", bw_tensor_output_requires_grads)); + // Copy bw_tensor_reuse_map attribute from PythonOp to PythonOpGrad if it is present. + auto attr_it = src_attrs.find("bw_tensor_reuse_map"); + if (attr_it != src_attrs.end()) { + std::vector tensor_output_to_tensor_input_reuse_map(attr_it->second.ints().begin(), + attr_it->second.ints().end()); + attrs.push_back(MakeAttribute("tensor_reuse_map", tensor_output_to_tensor_input_reuse_map)); + } + if (src_attrs.find("comment") != src_attrs.end() && utils::HasString(src_attrs.at("comment"))) { attrs.push_back(MakeAttribute("comment", src_attrs.at("comment").s())); } diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 5cd29303c3639..cfc79455c43ed 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -3918,6 +3918,17 @@ Return true if all elements are true and false otherwise. "- the output 2 reuses the input 0.", AttributeProto::INTS, false) + .Attr( + "bw_tensor_reuse_map", + "Used for backward op only." + "A int array indicating whether output at each index is reusing specific input or now." + "If the given index is -1, it means the output is not reusing any input." + "For example, there are 3 inputs (including ctx) and 2 outputs, tensor_reuse_map = [2, 1] means" + "- the output 0 reuses the input 2." + "- the output 1 reuses the input 1." + "Be noted: the input 0 is ctx.", + AttributeProto::INTS, + false) .Attr( "training_mode", "Indicate if the model is exported in training_mode, by default, False.", diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 6b566ed064aa4..e5c65b2a96d8c 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -68,6 +68,9 @@ #include "core/optimizer/pre_shape_node_elimination.h" #include "orttraining/core/optimizer/compute_optimizer/padding_elimination.h" #include "orttraining/core/optimizer/compute_optimizer/sceloss_compute_optimization.h" +#ifdef ENABLE_TRAINING_TORCH_INTEROP +#include "orttraining/core/optimizer/pythonop_rewriter.h" +#endif namespace onnxruntime { namespace training { @@ -106,6 +109,9 @@ std::vector> GeneratePreTrainingTransformers( ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique())); ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique())); ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique())); +#ifdef ENABLE_TRAINING_TORCH_INTEROP + ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique())); +#endif // Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by diff --git a/orttraining/orttraining/core/optimizer/pythonop_rewriter.cc b/orttraining/orttraining/core/optimizer/pythonop_rewriter.cc new file mode 100644 index 0000000000000..e1cd71958bed1 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/pythonop_rewriter.cc @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_TORCH_INTEROP + +#include +#include +#include + +#include "orttraining/core/optimizer/pythonop_rewriter.h" + +#include "core/graph/graph.h" +#include "core/graph/graph_utils.h" +#include "orttraining/core/framework/torch/torch_proxy.h" +#include "orttraining/core/framework/torch/custom_function_register.h" + +namespace onnxruntime { + +Status PythonOpRewriter::Apply(Graph&, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + bool modified = false; + if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "PythonOp", {1}, kMSDomain) && + node.GetAttributes().find("tensor_reuse_map") == node.GetAttributes().end()) { + auto func_name = static_cast(node.GetAttributes().at("func_name").s()); + std::optional input_alias_function = + language_interop_ops::torch::OrtTorchFunctionPool::GetInstance().TryGettingInputAliasFunction(func_name); + if (input_alias_function.has_value()) { + // Serialize node proto to string + ONNX_NAMESPACE::NodeProto node_proto; + node.ToProto(node_proto); + std::string node_proto_str; + node_proto.SerializeToString(&node_proto_str); + + // Call input alias function + std::vector fw_all_output_to_tensor_input_reuse_map; + std::vector bw_all_output_to_tensor_input_reuse_map; + language_interop_ops::torch::TorchProxy::GetInstance().RunInputAliasFunction( + static_cast(input_alias_function.value()), + node_proto_str, + fw_all_output_to_tensor_input_reuse_map, + bw_all_output_to_tensor_input_reuse_map); + + auto input_convention = static_cast(node.GetAttributes().at("input_convention").s()); + { + // Handle forward input alias map. + std::vector fw_tensor_output_to_tensor_input_reuse_map = + std::vector((node.OutputDefs().size()), -1); + + // Map input index from `global` input index to `tensor` input index, because node.InputDefs() only contains + // tensor inputs. + std::unordered_map position_to_tensor_index; + int64_t tensor_index = 0; + const size_t all_input_count = input_convention.size(); + position_to_tensor_index.reserve(all_input_count); + for (size_t i = 0; i < all_input_count; ++i) { + if (input_convention[i] == 'd') { + position_to_tensor_index[i] = tensor_index; + ++tensor_index; + } + } + + for (size_t i = 1; i < fw_tensor_output_to_tensor_input_reuse_map.size(); ++i) { + if (fw_all_output_to_tensor_input_reuse_map[i - 1] != -1) { + ORT_ENFORCE(fw_all_output_to_tensor_input_reuse_map[i - 1] < static_cast(all_input_count), + "PythonOp input alias function output index out of range. func_name: ", func_name, " ", + fw_all_output_to_tensor_input_reuse_map[i - 1], " >= ", all_input_count); + fw_tensor_output_to_tensor_input_reuse_map[i] = + position_to_tensor_index.at(fw_all_output_to_tensor_input_reuse_map[i - 1]); + } + } + + node.AddAttribute("tensor_reuse_map", fw_tensor_output_to_tensor_input_reuse_map); + } + + { + // Handle backward input alias map. + auto& output_convention = input_convention; + ORT_ENFORCE(bw_all_output_to_tensor_input_reuse_map.size() == output_convention.size(), + "PythonOpGrad input alias function output count mismatch. func_name: ", func_name, " ", + bw_all_output_to_tensor_input_reuse_map.size(), " != ", output_convention.size()); + + std::vector bw_tensor_output_to_tensor_input_reuse_map = + std::vector(node.InputDefs().size(), -1); + size_t tensor_output_index = 0; + for (size_t i = 0; i < output_convention.size(); ++i) { + if (output_convention[i] == 'd') { + ORT_ENFORCE(tensor_output_index < bw_tensor_output_to_tensor_input_reuse_map.size(), + "PythonOpGrad input alias function output count mismatch. func_name: ", func_name, " ", + tensor_output_index, " >= ", bw_tensor_output_to_tensor_input_reuse_map.size()); + // input index shift by 1 to skip the context + bw_tensor_output_to_tensor_input_reuse_map[tensor_output_index] = + bw_all_output_to_tensor_input_reuse_map[i] == -1 ? -1 : bw_all_output_to_tensor_input_reuse_map[i] + 1; + ++tensor_output_index; + } + } + node.AddAttribute("bw_tensor_reuse_map", bw_tensor_output_to_tensor_input_reuse_map); + } + + modified = true; + } + } + + if (modified) + rule_effect = RewriteRuleEffect::kUpdatedCurrentNode; + + return Status::OK(); +} + +bool PythonOpRewriter::SatisfyCondition(const Graph&, const Node&, const logging::Logger&) const { + return true; +} + +} // namespace onnxruntime + +#endif diff --git a/orttraining/orttraining/core/optimizer/pythonop_rewriter.h b/orttraining/orttraining/core/optimizer/pythonop_rewriter.h new file mode 100644 index 0000000000000..5534b190979f0 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/pythonop_rewriter.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef ENABLE_TRAINING_TORCH_INTEROP + +#pragma once + +#include +#include +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** +This transformer is to add schema supplementary for PythonOp. + +Currently, add memory reuse output to input map as an attribute, if users registered alias input function +in `OrtTorchFunctionPool`. +*/ + +class PythonOpRewriter : public RewriteRule { + public: + PythonOpRewriter() noexcept : RewriteRule("PythonOpRewriter") {} + + std::vector TargetOpTypes() const noexcept override { + return {"PythonOp"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime +#endif diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 35d9755ba0ba7..a08e8bee99cee 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -533,12 +533,44 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ORT_UNUSED_PARAMETER(obj); #endif }); - m.def("register_torch_autograd_function", [](std::string key, py::object obj) -> void { + m.def("register_torch_autograd_function", [](std::string function_full_qual_name, py::object obj) -> void { #ifdef ENABLE_TRAINING_TORCH_INTEROP auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance(); - pool.RegisterTorchAutogradFunction(key, obj.ptr()); + pool.RegisterTorchAutogradFunction(function_full_qual_name, obj.ptr()); #else - ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(function_full_qual_name); + ORT_UNUSED_PARAMETER(obj); +#endif + }); + m.def("register_shape_inference_function", [](std::string function_full_qual_name, py::object obj) -> void { +#ifdef ENABLE_TRAINING_TORCH_INTEROP + auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance(); + pool.RegisterShapeInferenceFunction(function_full_qual_name, obj.ptr()); +#else + ORT_UNUSED_PARAMETER(function_full_qual_name); + ORT_UNUSED_PARAMETER(obj); +#endif + }); + m.def("get_shape_inference_function", [](std::string function_full_qual_name) -> py::object { +#ifdef ENABLE_TRAINING_TORCH_INTEROP + auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance(); + auto py_object = pool.TryGettingShapeInferenceFunction(function_full_qual_name); + if (py_object.has_value()) { + Py_INCREF(py_object.value()); + return py::reinterpret_steal(py_object.value()); + } +#else + ORT_UNUSED_PARAMETER(function_full_qual_name); +#endif + return py::none(); + }); + + m.def("register_input_alias_function", [](std::string function_full_qual_name, py::object obj) -> void { +#ifdef ENABLE_TRAINING_TORCH_INTEROP + auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance(); + pool.RegisterInputAliasFunction(function_full_qual_name, obj.ptr()); +#else + ORT_UNUSED_PARAMETER(function_full_qual_name); ORT_UNUSED_PARAMETER(obj); #endif }); diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index c6edaf7cd3a2c..8c5469740d9bd 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import sys -from typing import Callable, ClassVar, Dict, Optional import torch import torch.utils.checkpoint @@ -12,7 +11,12 @@ from packaging import version from torch.onnx import symbolic_helper -from onnxruntime.capi._pybind_state import register_miscellaneous_const_input, register_torch_autograd_function +from onnxruntime.capi._pybind_state import ( + register_input_alias_function, + register_miscellaneous_const_input, + register_shape_inference_function, + register_torch_autograd_function, +) from onnxruntime.training import ortmodule from onnxruntime.training.utils import pytorch_dtype_to_onnx @@ -21,45 +25,46 @@ from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version -class PythonOpShapeInferStore: - """A class to store shape inference functions for torch.autograd.Function.""" +def register_custom_function_schema_supplementary(kclass: torch.autograd.Function) -> None: + """Register a shape inference function for a torch.autograd.Function if there is staticmethod + "infer_shape" defined. - _CLASS_MAP: ClassVar[Dict[str, Callable]] = {} + The signature of the shape inference function should be: + @staticmethod + def infer_shape( + node: onnx.NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + tensor_output_shapes = [] + tensor_output_dtypes = [] + ... + return tensor_output_shapes, tensor_output_dtypes - @classmethod - def register(cls, kclass: torch.autograd.Function) -> None: - """Register a shape inference function for a torch.autograd.Function if there is staticmethod - "infer_shape" defined. + The tensor_input_shapes and tensor_input_dtypes are lists of shapes and dtypes of the input tensors. + The tensor_output_shapes and tensor_output_dtypes are lists of shapes and dtypes of the output tensors. + Be noted: we only pass in tensor inputs, and return tensor outputs, non-tensor inputs/outputs are ignored. - The signature of the shape inference function should be: - @staticmethod - def infer_shape( - node: onnx.NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: - tensor_output_shapes = [] - tensor_output_dtypes = [] - ... - return tensor_output_shapes, tensor_output_dtypes - The tensor_input_shapes and tensor_input_dtypes are lists of shapes and dtypes of the input tensors. - The tensor_output_shapes and tensor_output_dtypes are lists of shapes and dtypes of the output tensors. - Be noted: we only pass in tensor inputs, and return tensor outputs, non-tensor inputs/outputs are ignored. + The signature of the alias input function should be: + @staticmethod + def alias_input(node_proto_str: str) -> Tuple[List[int], List[int]]: + fw_alias_map = [1, -1, -1] + bw_alias_map = [-1, 0] + return fw_alias_map, bw_alias_map - """ - kclass_name = get_fully_qualified_class_name(kclass) - if hasattr(kclass, "infer_shape") and kclass_name not in cls._CLASS_MAP: - cls._CLASS_MAP[kclass_name] = kclass.infer_shape + The alias input function should return a tuple of two lists: + - The first list is the forward alias map, its length is equal to the number of all outputs of the node. + - The second list is the backward alias map, its length is equal to the number of all inputs + (tensor and non-tensor) of the node. - @classmethod - def register_func(cls, name: str, func: Callable) -> None: - """Register a shape inference function for a torch.autograd.Function by name.""" - cls._CLASS_MAP[name] = func + """ + kclass_name = get_fully_qualified_class_name(kclass) + if hasattr(kclass, "infer_shape"): + register_shape_inference_function(kclass_name, kclass.infer_shape) - @classmethod - def get_shape_infer(cls, name: str) -> Optional[Callable]: - return cls._CLASS_MAP.get(name, None) + if hasattr(kclass, "alias_input"): + register_input_alias_function(kclass_name, kclass.alias_input) """ @@ -299,8 +304,8 @@ def _export_pt_1_10(g, n, *args, **kwargs): # Register function with class names. register_torch_autograd_function(func_full_qual_name, func_class) - # Register shape inference function. - PythonOpShapeInferStore.register(func_class) + register_custom_function_schema_supplementary(func_class) + return returned_args except Exception as e: sys.stdout.flush() @@ -327,7 +332,7 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model op_name_prefix = kclass_name break - node.name = f"{op_name_prefix}_id_{index}" + node.name = f"{op_name_prefix}_id_{index}" index += 1 return exported_model diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index fd791f21b4d22..b9318033a3d53 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -76,10 +76,10 @@ def __init__(self, kernel_invoke_id: str): def _process_inplace_outputs( kernel_info: CustomFuncOpKernelInfo, func_name: str, - input_tensors_of_kernel_run: List[torch.Tensor], + input_tensors_of_kernel_run: Dict[int, Union[torch.Tensor, None]], all_outputs_of_kernel_run: List[Union[torch.Tensor, any]], all_outputs_to_tensor_inputs_reuse_map: List[int], - raw_input_tensors_used_inplace: Dict[int, torch.Tensor], + raw_input_tensors_used_inplace: Dict[int, Union[torch.Tensor, None]], is_backward=False, ): """Special handling for in-place reusing in forward or backward. @@ -87,12 +87,12 @@ def _process_inplace_outputs( Args: kernel_info: kernel-specific information. func_name: name of the autograd.Function. - input_tensors_of_kernel_run: input tensors used to run the autograd.Function forward/backward. + input_tensors_of_kernel_run: all tensor input tensors used to run the autograd.Function forward/backward. all_outputs_of_kernel_run: all outputs of the autograd.Function forward/backward. all_outputs_to_tensor_inputs_reuse_map: a list of the same length of kernel outputs, each element representing which input index it is reusing. If there is no reuse, the value is -1. raw_input_tensors_used_inplace: a dict of raw input tensors marked as inplace in - `all_outputs_to_tensor_inputs_reuse_map`, the key is the input index, value is the raw input tensor. + `all_outputs_to_tensor_inputs_reuse_map`, the key is the tensor input index, value is the raw input tensor. is_backward: indicates if this is backward or forward. Procedures: @@ -127,7 +127,9 @@ def _process_inplace_outputs( """ log_prefix = f"{func_name}->{'Backward' if is_backward else 'Forward'}: " - input_tensor_address_list = [t.data_ptr() for t in input_tensors_of_kernel_run] + input_tensor_address_list = [ + t.data_ptr() if isinstance(t, torch.Tensor) else -1 for t in input_tensors_of_kernel_run.values() + ] if is_backward: input_tensor_address_list = [-1, *input_tensor_address_list] # skip the context input @@ -161,6 +163,14 @@ def _process_inplace_outputs( if inplace_index == detected_inplace_index: continue + if ( + inplace_index in raw_input_tensors_used_inplace + and raw_input_tensors_used_inplace[inplace_index] is None + ): + # Use specified inplace input index, but the input tensor is None, which means the input is not + # a tensor, so we don't do further checks. + continue + # If users register inplace_map (alloc planner will do buffer reuse), # but detected inplace_map indicates it is NO inplace reusing, we raise an error. if inplace_index != -1 and detected_inplace_index == -1: @@ -210,7 +220,8 @@ def _process_inplace_outputs( ): for raw_tensor_input_index, raw_input_tensor in raw_input_tensors_used_inplace.items(): # raw_input_tensor can be None for backward run, but backward won't go here. - assert isinstance(raw_input_tensor, torch.Tensor) + if not isinstance(raw_input_tensor, torch.Tensor): + continue # We did not do the check with tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty # because even for those tensor indices not in @@ -236,8 +247,8 @@ def _process_inplace_outputs( # Only need a copy once. raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index]) _log_warning( - f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}." - "Provide output to input reuse mapping to avoid the copy overhead." + f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}. " + f"{'Provide output to input reuse mapping to avoid the copy overhead.' if not is_first_time_init else ''}" ) copied = True @@ -531,7 +542,7 @@ def call_python_forward_function( _process_inplace_outputs( kernel_info, func_name, - input_tensors_used_for_fw_run.values(), + input_tensors_used_for_fw_run, final_rets, inplace_map, raw_input_tensors_used_inplace, @@ -624,6 +635,10 @@ def wrap_all_outputs(result): wrapped_arg = torch.zeros(shape, device=device, dtype=dtype) else: wrapped_arg = arg + + if grad_input_index in inplace_map: + raw_input_tensors_used_inplace[tensor_input_index] = arg + else: # Assume it's a DLPack tensor# and convert it to PyTorch tensor. wrapped_arg = from_dlpack(arg) @@ -631,7 +646,8 @@ def wrap_all_outputs(result): if grad_input_index in inplace_map: raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg - input_tensors_used_for_bw_run[tensor_input_index] = wrapped_arg + # This may include None values. + input_tensors_used_for_bw_run[tensor_input_index] = wrapped_arg if wrapped_arg is not None: # Only requires gradient when running under training mode @@ -662,7 +678,7 @@ def wrap_all_outputs(result): _process_inplace_outputs( kernel_info, func_name, - input_tensors_used_for_bw_run.values(), + input_tensors_used_for_bw_run, result, inplace_map, raw_input_tensors_used_inplace, diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index dfaac5f0fa836..5e8805bfddb4d 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -143,9 +143,6 @@ def __init__( self._zero_stage3_param_map = {} if self._runtime_options.enable_zero_stage3_support: # Cannot toggle feature enabling/disabling after the first time enabled. - from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_zero_stage3_params - - self._zero_stage3_param_map = _get_all_zero_stage3_params(self._flattened_module) configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True) @@ -420,8 +417,11 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu exported_model = post_process_enabling_autograd_function(exported_model) if self._runtime_options.enable_zero_stage3_support: + from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_zero_stage3_params + from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat + self._zero_stage3_param_map = _get_all_zero_stage3_params(self._flattened_module) exported_model = post_processing_enable_zero_stage3_compat( exported_model, self._zero_stage3_param_map, diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index 301071f6de44c..d0dea66fda2a7 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -9,10 +9,14 @@ import torch from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper -from onnxruntime.capi._pybind_state import register_torch_autograd_function +from onnxruntime.capi._pybind_state import ( + register_input_alias_function, + register_shape_inference_function, + register_torch_autograd_function, +) from onnxruntime.training.utils import pytorch_dtype_to_onnx -from ._custom_autograd_function_exporter import PythonOpShapeInferStore +from ._custom_autograd_function_exporter import register_custom_function_schema_supplementary from ._utils import get_fully_qualified_class_name STAGE3_PULL_WEIGHT_TRIGGER_NAME = "pull_weight_trigger" @@ -36,6 +40,8 @@ def post_processing_enable_zero_stage3_compat( # Register symbolic shape inference functions for PythonOp used in DeepSpeed ZeRO stage3. _register_symbolic_shape_infer_functions() + _register_alias_input_functions() + # Create weight retrieving function using zero_stage3_named_params. func_full_qual_name = _create_weight_retrieval_function(zero_stage3_named_params) @@ -69,7 +75,7 @@ def _get_func_name(node: NodeProto) -> Optional[str]: from onnxruntime.training.utils.hooks._zero_offload_subscriber import ORTZeROOffloadPreForwardFunction - prefowrad_function_name = get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction) + pre_forward_function_name = get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction) # Connect weight consumers to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction. for graph_input in exported_model.graph.input: @@ -87,7 +93,7 @@ def _get_func_name(node: NodeProto) -> Optional[str]: continue func_name = _get_func_name(c) - if func_name == prefowrad_function_name: + if func_name == pre_forward_function_name: assert ( pre_forward_pythonop_node is None ), "Multiple ORTZeROOffloadPreForwardFunction nodes found, it should not happen" @@ -98,6 +104,7 @@ def _get_func_name(node: NodeProto) -> Optional[str]: "Fail to find ORTZeROOffloadPreForwardFunction for partitioned param: " + graph_input.name ) + pull_weight_trigger_input_name = _get_param_pull_trigger_name(graph_input.name) index_offset_on_python_op_input = [] for i, input_name in enumerate(pre_forward_pythonop_node.input): if input_name == graph_input.name: @@ -105,21 +112,32 @@ def _get_func_name(node: NodeProto) -> Optional[str]: assert ( len(index_offset_on_python_op_input) == 1 - ), f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input}" + ), f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input} for node {pre_forward_pythonop_node.name}, input {graph_input.name}, {pre_forward_pythonop_node.input}" reverse_index_among_inputs = index_offset_on_python_op_input[0] - len(pre_forward_pythonop_node.input) - new_input_name = _get_param_pull_trigger_name(graph_input.name) - pre_forward_pythonop_node.input[index_offset_on_python_op_input[0]] = new_input_name + + pre_forward_pythonop_node.input[index_offset_on_python_op_input[0]] = pull_weight_trigger_input_name _update_python_op_input_related_attributes( pre_forward_pythonop_node, - new_input_name, + pull_weight_trigger_input_name, len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE), # new rank STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, # new data type ) output_index = reverse_index_among_inputs + len(pre_forward_pythonop_node.output) - pre_forward_pythonop_node.output[output_index] = graph_input.name + + ready_weight_name = f"ready_{graph_input.name}" + pre_forward_pythonop_node.output[output_index] = ready_weight_name + + # Update consumer's input to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction. + for c in consumers: + new_inputs = [c_input for c_input in c.input] + for c_input_index in range(len(c.input)): + if c.input[c_input_index] == graph_input.name: + new_inputs[c_input_index] = ready_weight_name + del c.input[:] + c.input.extend(new_inputs) # If the consumer of original `graph_input.name` is PythonOp, we need also update its attributes because now # `graph_input.name` as output of pre_forward_pythonop_node, is full-sized parameter, the rank might differ @@ -186,11 +204,13 @@ def infer_shape( tensor_output_dtypes = [ tensor_input_dtypes[0], ] * param_count + return tensor_output_shapes, tensor_output_dtypes func_full_qual_name = get_fully_qualified_class_name(WeightRetrievalFunction) register_torch_autograd_function(func_full_qual_name, WeightRetrievalFunction) - PythonOpShapeInferStore.register(WeightRetrievalFunction) + + register_custom_function_schema_supplementary(WeightRetrievalFunction) return func_full_qual_name @@ -206,10 +226,10 @@ def _simple_pass_through_infer_shape( ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes - PythonOpShapeInferStore.register_func( + register_shape_inference_function( "deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _simple_pass_through_infer_shape ) - PythonOpShapeInferStore.register_func( + register_shape_inference_function( "deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _simple_pass_through_infer_shape ) @@ -225,9 +245,36 @@ def _linear_infer_shape( output_shape[-1] = shape2[-2] return [output_shape], [tensor_input_dtypes[0]] - PythonOpShapeInferStore.register_func( - "deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3", _linear_infer_shape - ) + register_shape_inference_function("deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3", _linear_infer_shape) + + +def _register_alias_input_functions(): + """This function is used to register symbolic shape inference functions for PythonOp used in + DeepSpeed ZeRO stage3.""" + + def _alias_input(node_proto_str: str): + node: NodeProto = NodeProto() + node.ParseFromString(node_proto_str) + non_tensor_fw_input_count = 2 + + fw_output_count = len(node.output) - 1 # exclude the first output appended in ONNX + fw_alias_map = [-1] * fw_output_count + bw_alias_map = [-1] * (non_tensor_fw_input_count + len(node.input)) + + for i in range(fw_output_count): + fw_alias_map[i] = i + non_tensor_fw_input_count + + tensor_input_index = 0 + for i in range(len(bw_alias_map)): + if i < non_tensor_fw_input_count: + continue + bw_alias_map[i] = tensor_input_index + tensor_input_index += 1 + + return fw_alias_map, bw_alias_map + + register_input_alias_function("deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _alias_input) + register_input_alias_function("deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _alias_input) def _create_weight_retrieval_pythonop( @@ -276,7 +323,9 @@ def _create_weight_retrieval_pythonop( return new_input, weight_pull_node -def _update_python_op_input_related_attributes(node: NodeProto, input_name: str, new_rank: int, new_dtype: int): +def _update_python_op_input_related_attributes( + node: NodeProto, input_name: str, new_rank: int, new_dtype: torch.onnx.TensorProtoDataType +): """This function is used to update PythonOp's input related attributes, e.g. input_tensor_ranks and input_tensor_types. @@ -284,7 +333,7 @@ def _update_python_op_input_related_attributes(node: NodeProto, input_name: str, node (NodeProto): The PythonOp node. input_name (str): The input name to be updated. new_rank (int): The new rank of the input, to be used in input_tensor_ranks. - new_dtype (int): The new data type of the input, to be used in input_tensor_types. + new_dtype (torch.onnx.TensorProtoDataType): The new data type of the input, to be used in input_tensor_types. """ input_tensor_ranks = None input_tensor_dtypes = None @@ -304,7 +353,7 @@ def _update_python_op_input_related_attributes(node: NodeProto, input_name: str, for index, node_input_name in enumerate(node.input): if node_input_name == input_name: input_tensor_ranks[index] = new_rank - input_tensor_dtypes[index] = new_dtype + input_tensor_dtypes[index] = int(new_dtype) node.attribute.remove(rank_attr) node.attribute.remove(dtype_attr) diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py index db1c69cf95ba4..c5be17236ac06 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py @@ -93,6 +93,13 @@ def infer_shape( ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes + @staticmethod + def alias_input(node_proto_str: str): + fw_alias_map = [3] + bw_alias_map = [-1] * 6 + bw_alias_map[3] = 0 + return fw_alias_map, bw_alias_map + class StatisticsSubscriber(SubscriberBase): """ diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py index b2bc64be42fc1..c9c06dabab4de 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py @@ -68,6 +68,26 @@ def infer_shape( ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes + @staticmethod + def alias_input(node_proto_str: str): + node = onnx.NodeProto() + node.ParseFromString(node_proto_str) + non_tensor_fw_input_count = 1 + fw_output_count = len(node.output) - 1 # exclude the first output appended in ONNX + fw_alias_map = [-1] * fw_output_count + bw_alias_map = [-1] * (non_tensor_fw_input_count + len(node.input)) + + for i in range(fw_output_count): + fw_alias_map[i] = i + non_tensor_fw_input_count + + tensor_input_index = 0 + for i in range(len(bw_alias_map)): + if i < non_tensor_fw_input_count: + continue + bw_alias_map[i] = tensor_input_index + tensor_input_index += 1 + return fw_alias_map, bw_alias_map + class SubscriberManager: """This class is used to manage all the subscribers and register subscribers' custom actions as PyTorch hooks diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index ad1297962db71..373633fa85ee2 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -272,12 +272,41 @@ def infer_shape( start_offset = len(tensor_input_shapes) - len(partitioned_params) for index, param in enumerate(partitioned_params): tensor_output_shapes[start_offset + index] = list(param.ds_shape) - tensor_output_dtypes[start_offset + index] = pytorch_dtype_to_onnx(param.dtype) + tensor_output_dtypes[start_offset + index] = int(pytorch_dtype_to_onnx(param.dtype)) assert len(tensor_output_shapes) == len(tensor_input_shapes) assert len(tensor_output_dtypes) == len(tensor_input_dtypes) return tensor_output_shapes, tensor_output_dtypes + @staticmethod + def alias_input(node_proto_str: str): + node = onnx.NodeProto() + node.ParseFromString(node_proto_str) + input_pointer_scalars_attr_name = "input_pointer_scalars" + found = [attr for attr in node.attribute if attr.name == input_pointer_scalars_attr_name] + assert len(found) == 1 + input_pointer_scalars = found[0].ints + # Restore the nn.Module from the pointer. + module = ctypes.cast(input_pointer_scalars[0], ctypes.py_object).value + partitioned_params = _get_params_for_current_module(module) + + non_tensor_fw_input_count = 6 + fw_output_count = len(node.output) - 1 # exclude the first output appended in ONNX + fw_alias_map = [-1] * fw_output_count + bw_alias_map = [-1] * (non_tensor_fw_input_count + len(node.input)) + + for i in range(fw_output_count - len(partitioned_params)): + fw_alias_map[i] = i + non_tensor_fw_input_count + + tensor_input_index = 0 + for i in range(len(bw_alias_map) - len(partitioned_params)): + if i < non_tensor_fw_input_count: + continue + bw_alias_map[i] = tensor_input_index + tensor_input_index += 1 + + return fw_alias_map, bw_alias_map + class ORTZeROOffloadPostForwardFunction(torch.autograd.Function): @staticmethod @@ -332,6 +361,27 @@ def infer_shape( ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: return tensor_input_shapes, tensor_input_dtypes + @staticmethod + def alias_input(node_proto_str: str): + node = onnx.NodeProto() + node.ParseFromString(node_proto_str) + non_tensor_fw_input_count = 4 + fw_output_count = len(node.output) - 1 # exclude the first output appended in ONNX + fw_alias_map = [-1] * fw_output_count + bw_alias_map = [-1] * (non_tensor_fw_input_count + len(node.input)) + + for i in range(fw_output_count): + fw_alias_map[i] = i + non_tensor_fw_input_count + + tensor_input_index = 0 + for i in range(len(bw_alias_map)): + if i < non_tensor_fw_input_count: + continue + bw_alias_map[i] = tensor_input_index + tensor_input_index += 1 + + return fw_alias_map, bw_alias_map + class _ZeROOffloadFunctions: def __init__(self, one_time_init: _ZeROOffloadOneTimeInitializer, offloader) -> None: diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc index a31fa5d850e59..41f4a41a7c38a 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc @@ -294,7 +294,8 @@ void PythonOpBase::SetContextOutput(OpKernelContext* context, void* diff_ctx) co void PythonOpBase::SetOtherOutputs(OpKernelContext* context, std::vector& returned_ortvalues) const { auto* ctx_internal = reinterpret_cast(context); - ORT_ENFORCE(returned_ortvalues.size() == all_output_to_tensor_input_reuse_map_.size() - 1, "PythonOp output count mismatch inplace map count.", + ORT_ENFORCE(returned_ortvalues.size() == all_output_to_tensor_input_reuse_map_.size() - 1, + "PythonOp output count mismatch inplace map count.", returned_ortvalues.size(), " != ", all_output_to_tensor_input_reuse_map_.size() - 1); for (size_t i = 0; i < returned_ortvalues.size(); ++i) { size_t output_index = i + 1;