From b36efed6b7eb4b39197158a6297af386bb843831 Mon Sep 17 00:00:00 2001 From: pengwa Date: Sat, 7 Oct 2023 08:40:19 +0800 Subject: [PATCH] Fix convergence for dolly+stage3 training (#17685) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Fix convergence for dolly+stage3 training In [ZeROOffloadSubscriber](https://github.com/microsoft/onnxruntime/blob/216214b7d302cb504d1e5a647f65b6fe49c22dbb/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py#L359C7-L359C28), we defined some PythonOp, taking input and returning it inplace, for example: https://github.com/microsoft/onnxruntime/blob/216214b7d302cb504d1e5a647f65b6fe49c22dbb/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py#L223C20-L223C20. While it is possible, when ORT runs such a PythonOp, once it completes, it will release the input OrtValue, triggered the data erasing or overridden. But the PythonOp's returned value OrtValue are still pointing to that address, reading or writting on that may introduce a wrong result or even undefined behaviors. ``` /bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_custom_autograd_function_runner.py:28: UserWarning: .rank-0: onnxruntime.training.utils.hooks._zero_offload_subscriber.ORTZeROOffloadPreForwardFunction->Backward: ONNX Op attribute 'tensor_reuse_map' doesn't indicate 8-th output is reusing any input, but detected inplace_map indicates it is reusing some input index. A clone will be done before returning to ORT, to align with ORT's NO Buffer reuse plan. Please update inplace_map explicitly to avoid such a copy. warnings.warn(f".rank-{get_rank()}: {message}") 0%|▏ | 1/1000 [00:04<1:15:08, 4.51s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,023 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 14.1406, 'learning_rate': 0, 'epoch': 0.0} 0%|▏ | 1/1000 [00:04<1:15:08, 4.51s/it]Invalidate trace cache @ step 5: expected module 6, but got module 7 0%|▍ | 2/1000 [00:04<31:53, 1.92s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,124 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 0%|▋ | 3/1000 [00:04<18:05, 1.09s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,227 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 0%|▋ | 3/1000 [00:04<18:05, 1.09s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,326 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 0%|█▏ | 5/1000 [00:04<08:44, 1.90it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,419 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 0%|█▏ | 5/1000 [00:04<08:44, 1.90it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,505 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 1%|█▋ | 7/1000 [00:05<05:28, 3.02it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,597 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 1%|█▋ | 7/1000 [00:05<05:28, 3.02it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,690 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 1%|██▏ | 9/1000 [00:05<03:57, 4.17it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,791 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 1%|██▏ | 9/1000 [00:05<03:57, 4.17it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,889 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 1%|██▋ | 11/1000 [00:05<03:06, 5.32it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,981 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0} 1%|██▋ | 11/1000 [00:05<03:06, 5.32it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,073 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01} 1%|███▏ | 13/1000 [00:05<02:33, 6.42it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,166 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01} 1%|███▏ | 13/1000 [00:05<02:33, 6.42it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,256 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01} 2%|███▌ | 15/1000 [00:05<02:12, 7.43it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,348 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01} 2%|███▌ | 15/1000 [00:05<02:12, 7.43it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,439 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01} 2%|████ | 17/1000 [00:06<01:59, 8.22it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,535 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0 {'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01} 2%|████ | 17/1000 [00:06<01:59, 8.22it/s]Traceback (most recent call last): File "examples/onnxruntime/training/language-modeling/run_clm.py", line 600, in main() File "examples/onnxruntime/training/language-modeling/run_clm.py", line 548, in main train_result = trainer.train(resume_from_checkpoint=checkpoint) File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 457, in train return inner_training_loop( File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 781, in _inner_training_loop self.deepspeed.step() File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/engine.py", line 2084, in step self._take_model_step(lr_kwargs) File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/engine.py", line 1990, in _take_model_step self.optimizer.step() File "/bert_ort/pengwa/deepspeed/deepspeed/utils/nvtx.py", line 15, in wrapped_fn ret_val = func(*args, **kwargs) File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/stage3.py", line 1854, in step if self._overflow_check_and_loss_scale_update(): File "/bert_ort/pengwa/deepspeed/deepspeed/utils/nvtx.py", line 15, in wrapped_fn ret_val = func(*args, **kwargs) File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/stage3.py", line 1788, in _overflow_check_and_loss_scale_update self._update_scale(self.overflow) File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/stage3.py", line 2132, in _update_scale self.loss_scaler.update_scale(has_overflow) File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/fp16/loss_scaler.py", line 175, in update_scale raise Exception( Exception: Current loss scale already at minimum - cannot decrease scale anymore. Exiting run. 2%|████ | 17/1000 [00:06<06:07, 2.67it/s] [2023-09-25 08:30:51,075] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 1065120) of binary: /bert_ort/pengwa/py38/bin/python Traceback (most recent call last): File "/bert_ort/pengwa/py38/bin/torchrun", line 8, in sys.exit(main()) File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper return f(*args, **kwargs) File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/run.py", line 806, in main run(args) File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/run.py", line 797, in run elastic_launch( File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__ return launch_agent(self._config, self._entrypoint, list(args)) File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent raise ChildFailedError( torch.distributed.elastic.multiprocessing.errors.ChildFailedError: ============================================================ examples/onnxruntime/training/language-modeling/run_clm.py FAILED ------------------------------------------------------------ Failures: ------------------------------------------------------------ Root Cause (first observed failure): [0]: time : 2023-09-25_08:30:51 host : orttrainingdev10.internal.cloudapp.net rank : 0 (local_rank: 0) exitcode : 1 (pid: 1065120) error_file: traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html ============================================================ (/bert_ort/pengwa/py38) pengwa@microsoft.com@orttrainingdev10:/bert_ort/pengwa/optim ``` ## The Fix For those output that are reusing input, but ORT is not aware of, we detected on the fly (the first iteration, by checking the output tensor addresses with input tensor addresses) , then do implicit copy before set it as PythonOp's output tensors. With this fix: (left: PyTorch, right: ORT) ![image](https://github.com/microsoft/onnxruntime/assets/10530022/0d72f431-2abd-4e52-af99-19974b85edde) --- .../core/framework/torch/torch_proxy.cc | 85 ++-- .../core/framework/torch/torch_proxy.h | 16 +- .../core/graph/gradient_builder.cc | 1 - .../core/graph/training_op_defs.cc | 29 +- .../_custom_autograd_function_exporter.py | 2 - .../_custom_autograd_function_runner.py | 441 ++++++++++++++---- .../ortmodule/_zero_stage3_compatibility.py | 16 +- .../torch_interop_utils.cc | 29 ++ .../orttraining_test_ortmodule_autograd.py | 96 +++- .../torch_custom_function_kernel_base.cc | 90 +++- .../torch/torch_custom_function_kernel_base.h | 7 +- 11 files changed, 657 insertions(+), 155 deletions(-) diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.cc b/orttraining/orttraining/core/framework/torch/torch_proxy.cc index 377f564a00337..58e22f4e266ee 100644 --- a/orttraining/orttraining/core/framework/torch/torch_proxy.cc +++ b/orttraining/orttraining/core/framework/torch/torch_proxy.cc @@ -10,9 +10,7 @@ #include "orttraining/core/framework/torch/gil.h" #include "core/platform/env.h" -namespace onnxruntime { -namespace language_interop_ops { -namespace torch { +namespace onnxruntime::language_interop_ops::torch { void PythonObjectDeleter(PyObject* ptr) { Py_XDECREF(ptr); }; @@ -130,6 +128,18 @@ PyObject* CreateRequiresGradFlags( return flags; } +PyObject* CreateInplaceMap( + const std::vector& inplace_map) { + PyObject* inplace_map_obj = Ort_PyList_New(inplace_map.size(), "inplace_map"); + + for (size_t output_index = 0; output_index < inplace_map.size(); ++output_index) { + PyObject* input_index = PyLong_FromLong(inplace_map[output_index]); + Ort_PyList_SetItem_NoIncref(inplace_map_obj, output_index, input_index, std::to_string(__LINE__)); + } + + return inplace_map_obj; +} + void InvokeRunner( PyObject* callback_runner, PyObject* args, @@ -197,14 +207,15 @@ PythonObjectPtr CreatePythonCallArguments( const std::vector& obj_args, const std::vector& obj_indices, const bool is_training_mode, - const bool is_inplace, - const std::string& invoke_id) { + const std::vector& inplace_map, + const std::string& invoke_id, + const std::string& func_name) { ORT_ENFORCE(PyCallable_Check(callback), "Forward callback is not callable."); // The number of variables before those of // autograd.Function.apply and autograd.Function.backward. // The extra variables are used to configure the launch // forward and backward runners. - constexpr int64_t num_control_args = 6; + constexpr int64_t num_control_args = 7; // All arguments created for Python call will be destroyed along with PythonObjectPtr. PythonObjectPtr args(Ort_PyTuple_New(num_control_args + len, "forward_arguments_tuple"), PythonObjectDeleter); @@ -216,11 +227,16 @@ PythonObjectPtr CreatePythonCallArguments( Ort_PyTuple_SetItem_NoIncref(args.get(), 2, tensor_flags, "tensor_flags"); PyObject* is_training_mode_arg = is_training_mode ? Py_True : Py_False; Ort_PyTuple_SetItem_Incref(args.get(), 3, is_training_mode_arg, "is_training_mode"); - PyObject* is_inplace_arg = is_inplace ? Py_True : Py_False; - Ort_PyTuple_SetItem_Incref(args.get(), 4, is_inplace_arg, "is_inplace_mode"); + + PyObject* inplace_map_arg = CreateInplaceMap(inplace_map); + Ort_PyTuple_SetItem_NoIncref(args.get(), 4, inplace_map_arg, "inplace_map"); + PyObject* kernel_invoke_id_arg = PyBytes_FromStringAndSize(invoke_id.c_str(), invoke_id.size()); Ort_PyTuple_SetItem_NoIncref(args.get(), 5, kernel_invoke_id_arg, "kernel_invoke_id_arg"); + PyObject* func_name_arg = PyBytes_FromStringAndSize(func_name.c_str(), func_name.size()); + Ort_PyTuple_SetItem_NoIncref(args.get(), 6, func_name_arg, "func_name_arg"); + // Tensor inputs to call autograd.Function.apply or autograd.Function.backward. for (size_t i = 0; i < tensor_args.size(); ++i) { if (!tensor_args[i].has_value()) { @@ -246,6 +262,7 @@ PythonObjectPtr CreatePythonCallArguments( } void Invoke( + const std::string& func_name, PyObject* runner, PyObject* callback, const std::vector& requires_grads, @@ -253,11 +270,11 @@ void Invoke( const std::vector& tensor_indices, const std::vector& obj_args, const std::vector& obj_indices, - void** diff_ctx, - std::vector& returned_ortvalues, const bool is_training_mode, - const bool is_inplace, - const std::string& invoke_id) { + const std::vector& inplace_map, + const std::string& invoke_id, + void** diff_ctx, + std::vector& returned_ortvalues) { const auto len = tensor_args.size() + obj_args.size(); CheckArguments(len, requires_grads, tensor_args, tensor_indices, obj_args, obj_indices); RefCountTracker::GetInstance().Reset(); @@ -271,8 +288,9 @@ void Invoke( obj_args, obj_indices, is_training_mode, - is_inplace, - invoke_id); + inplace_map, + invoke_id, + func_name); RefCountTracker::GetInstance().DumpDetails("Before Invoke Python Call"); InvokeRunner(runner, args.get(), is_training_mode, diff_ctx, returned_ortvalues); @@ -282,17 +300,18 @@ void Invoke( } void TorchProxy::Forward( + const std::string& func_name, void* callback, const std::vector& requires_grads, const std::vector>& tensor_args, const std::vector& tensor_indices, const std::vector& obj_args, const std::vector& obj_indices, - void** diff_ctx, - std::vector& returned_ortvalues, const bool is_training_mode, - const bool is_inplace, - const std::string& invoke_id) { + const std::vector& inplace_map, + const std::string& invoke_id, + void** diff_ctx, + std::vector& returned_ortvalues) { // Semantically, this lock uniquely takes the ownership of TorchProxy // so that there will be only one of TorchProxy::Forward TorchProxy::Backward // can be run at one time. @@ -301,6 +320,7 @@ void TorchProxy::Forward( GilGuard guard; auto runner = OrtTorchFunctionPool::GetInstance().GetForwardRunner(); Invoke( + func_name, runner, reinterpret_cast(callback), requires_grads, @@ -308,22 +328,23 @@ void TorchProxy::Forward( tensor_indices, obj_args, obj_indices, - diff_ctx, - returned_ortvalues, is_training_mode, - is_inplace, - invoke_id); + inplace_map, + invoke_id, + diff_ctx, + returned_ortvalues); } void TorchProxy::Backward( + const std::string& func_name, void* callback, const std::vector>& tensor_args, const std::vector& tensor_indices, const std::vector& obj_args, const std::vector& obj_indices, - std::vector& returned_ortvalues, - const bool is_inplace, - const std::string& invoke_id) { + const std::vector& inplace_map, + const std::string& invoke_id, + std::vector& returned_ortvalues) { // Semantically, this lock uniquely takes the ownership of TorchProxy // so that there will be only one of TorchProxy::Forward TorchProxy::Backward // can be run at one time. @@ -336,6 +357,7 @@ void TorchProxy::Backward( const auto all_input_count = tensor_args.size() + obj_args.size(); const std::vector requires_grads(all_input_count, 0); Invoke( + func_name, runner, reinterpret_cast(callback), requires_grads, @@ -343,12 +365,11 @@ void TorchProxy::Backward( tensor_indices, obj_args, obj_indices, - nullptr /* context to store */, - returned_ortvalues, true /* is_training_mode */, - is_inplace, - invoke_id); + inplace_map, + invoke_id, + nullptr /* context to store */, + returned_ortvalues); } -} // namespace torch -} // namespace language_interop_ops -} // namespace onnxruntime + +} // namespace onnxruntime::language_interop_ops::torch diff --git a/orttraining/orttraining/core/framework/torch/torch_proxy.h b/orttraining/orttraining/core/framework/torch/torch_proxy.h index 189efc772a62c..aeb02bab97eea 100644 --- a/orttraining/orttraining/core/framework/torch/torch_proxy.h +++ b/orttraining/orttraining/core/framework/torch/torch_proxy.h @@ -37,27 +37,29 @@ class TorchProxy { }; void Forward( + const std::string& func_name, void* callback, const std::vector& requires_grads, const std::vector>& tensor_args, const std::vector& tensor_indices, const std::vector& obj_args, const std::vector& obj_indices, - void** diff_ctx, - std::vector& returned_ortvalues, const bool is_training_mode, - const bool is_inplace, - const std::string& invoke_id); + const std::vector& inplace_map, + const std::string& invoke_id, + void** diff_ctx, + std::vector& returned_ortvalues); void Backward( + const std::string& func_name, void* callback, const std::vector>& tensor_args, const std::vector& tensor_indices, const std::vector& obj_args, const std::vector& obj_indices, - std::vector& return_args, - const bool is_inplace, - const std::string& invoke_id); + const std::vector& inplace_map, + const std::string& invoke_id, + std::vector& return_args); private: TorchProxy(){}; diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index a14f849958fa7..b3da4f3977ff2 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1765,7 +1765,6 @@ IMPLEMENT_GRADIENT_BUILDER(GetPythonOpGradient) { ORT_ENFORCE(utils::HasString(src_attrs.at("func_name"))); attrs.push_back(MakeAttribute("func_name", src_attrs.at("func_name").s())); attrs.push_back(MakeAttribute("output_convention", src_attrs.at("input_convention").s())); - attrs.push_back(MakeAttribute("inplace", src_attrs.at("inplace").i())); // input_tensor_types[i] store the type of autograd.Function.apply's ith output. // Note that PythonOpGrad's 0-th input is the Python context generated by PythonOp. diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 86d3cdee9ba98..5cd29303c3639 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -3908,10 +3908,16 @@ Return true if all elements are true and false otherwise. AttributeProto::INTS) // Other attributes. .Attr( - "inplace", - "Indicate if the output should reuse input memory.", - AttributeProto::INT, - static_cast(0)) + "tensor_reuse_map", + "A int array indicating whether output at each index is reusing specific input or not." + "If the given index is -1, it means the output is not reusing any input." + "For example, there are 2 tensor inputs and 3 tensor outputs (including ctx), " + "tensor_reuse_map = [-1, 1, 0] means" + "- the output 0 (ctx) don't reuse any input buffer." + "- the output 1 reuses the input 1." + "- the output 2 reuses the input 0.", + AttributeProto::INTS, + false) .Attr( "training_mode", "Indicate if the model is exported in training_mode, by default, False.", @@ -4033,11 +4039,6 @@ Return true if all elements are true and false otherwise. "func_name", "Name of custom class.", AttributeProto::STRING) - .Attr( - "inplace", - "Indicate if the output should reuse input memory. Todo(pengwa): do we need it?", - AttributeProto::INT, - static_cast(0)) .Attr( "input_tensor_types", "Input types of autograd.Function.backward (including only tensor inputs)." @@ -4069,6 +4070,16 @@ Return true if all elements are true and false otherwise. "A string inidicating autograd.Function.backward outputs's type." "value 'c' - non-tensor output; value 'd' - tensor output.", AttributeProto::STRING) + .Attr( + "tensor_reuse_map", + "A int array indicating whether output at each index is reusing specific input or not." + "If the given index is -1, it means the output is not reusing any input." + "For example, there are 3 inputs (including ctx) and 2 outputs, tensor_reuse_map = [2, 1] means" + "- the output 0 reuses the input 2." + "- the output 1 reuses the input 1." + "Be noted: the input 0 is ctx.", + AttributeProto::INTS, + false) .Attr( "comment", "comment only for debugging purposes.", diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index f75d553a5f460..c6edaf7cd3a2c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -118,7 +118,6 @@ def _export_pt_1_10(g, n, *args, **kwargs): "wrap exportable sub-nn.Module's as ORTModule." ) - inplace = kwargs["inplace"] # TODO move to public API once the exporter team exposes that training_mode = None if get_runtime_pytorch_version() >= version.parse("1.12"): @@ -260,7 +259,6 @@ def _export_pt_1_10(g, n, *args, **kwargs): attrs = { "func_name_s": func_full_qual_name, - "inplace_i": inplace, "input_convention_s": cconv, "outputs": n.outputsSize(), "input_tensor_types_i": input_tensor_types, diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index a5b96c4e37140..fd791f21b4d22 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- + import sys import warnings from collections import OrderedDict @@ -14,10 +15,21 @@ from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils from ._fallback import ORTModuleFallbackException, ORTModuleIOError, _FallbackManager, wrap_exception # noqa: F401 +from ._utils import get_rank + + +def _log_warning(message: str): + """Configure the logger for PythonOp runner according to following rules. + 1. If multiple processes are used, the rank will be appended + to the logger name. + 2. The logger will be disabled for non-zero ranks. + """ + if get_rank() == 0: + warnings.warn(f"[rank-{get_rank()}] {message}") class CustomFuncOpKernelInfo: - """Store the kernel specific information retrieved with the first-time run.""" + """Store the kernel-specific information retrieved with the first-time run.""" def __init__(self, kernel_invoke_id: str): # kernel_invoke_id is a string contains session thread id, op kernel creation time stamp in ms, a random int, @@ -31,9 +43,9 @@ def __init__(self, kernel_invoke_id: str): # reference, may release the content of the tensor before it is needed in backward). Once # `autograd.Function.apply` completes, by checking the existence of the tensor in the saved_tensors, # `_GlobalOpKernelInfoMap` is updated to save the input indices that are saved in context. - # 2. For the subsequent runs, if the input index is in `input_indices_to_save_in_ctx`, the tensor + # 2. For the subsequent runs, if the input index is in `tensor_input_indices_to_save_in_ctx`, the tensor # will be cloned before fed into `autograd.Function.apply` as input. - self.input_indices_to_save_in_ctx: List[int] = [] + self.tensor_input_indices_to_save_in_ctx: Optional[List[int]] = None # To align with PyTorch `ctx.set_materialize_grads(False|True)`` # materialize_grads_config is a map from output index to (device, dtype, shape) of the output tensor, used @@ -41,27 +53,211 @@ def __init__(self, kernel_invoke_id: str): self.materialize_grads: bool = False self.materialize_grads_config: Optional[Dict[int, Tuple[torch.device, torch.dtype, torch.shape]]] = None + # For the tensors generated from ORT backend, there is special handling here: + # 1. For the first time run for the kernel (the uniqueness of the kernel is defined by kernel_invoke_id), + # all such tensors will be cloned (with gradient) in case they are marked as dirty (if not cloned, but marked + # as dirty, PyTorch will complain the tensor is a leaf, should not be used for inplace update). Once + # `autograd.Function.apply` completes, by checking the existence of the tensor in the dirty_tensors, + # `_GlobalOpKernelInfoMap` is updated to save the input indices that are marked as dirty. + # 2. For the subsequent runs, if the input index is in `tensor_input_indices_for_mark_dirty`, the tensor + # will be cloned (with gradient) before fed into `autograd.Function.apply` as input. + self.tensor_input_indices_for_mark_dirty: Optional[List[int]] = None + + # A list of output indices that needs to be clone before returned, due to inplace update analysis. + self.output_indices_for_clone: Optional[List[int]] = None + -# Store the kernel specific information that cannot be retrieved and saved by PyTorch exporter. -# For those infos that can only be retrieved with real run, we try to collect them in the first time run. +# Store the kernel-specific information that cannot be retrieved and saved by PyTorch exporter. +# For the infos that can only be retrieved with real run, we try to collect them in the first time run. # key: kernel_invoke_id, value: CustomFuncOpKernelInfo. _GlobalOpKernelInfoMap: Dict[str, CustomFuncOpKernelInfo] = {} +def _process_inplace_outputs( + kernel_info: CustomFuncOpKernelInfo, + func_name: str, + input_tensors_of_kernel_run: List[torch.Tensor], + all_outputs_of_kernel_run: List[Union[torch.Tensor, any]], + all_outputs_to_tensor_inputs_reuse_map: List[int], + raw_input_tensors_used_inplace: Dict[int, torch.Tensor], + is_backward=False, +): + """Special handling for in-place reusing in forward or backward. + + Args: + kernel_info: kernel-specific information. + func_name: name of the autograd.Function. + input_tensors_of_kernel_run: input tensors used to run the autograd.Function forward/backward. + all_outputs_of_kernel_run: all outputs of the autograd.Function forward/backward. + all_outputs_to_tensor_inputs_reuse_map: a list of the same length of kernel outputs, each element representing + which input index it is reusing. If there is no reuse, the value is -1. + raw_input_tensors_used_inplace: a dict of raw input tensors marked as inplace in + `all_outputs_to_tensor_inputs_reuse_map`, the key is the input index, value is the raw input tensor. + is_backward: indicates if this is backward or forward. + + Procedures: + 1. Detect all outputs to tensor inputs reuse mapping. + 2. Validate the detected inplace_map with the registered inplace_map in ORT. For the output tensor, + 2.0 If the reuse mapping value is the same in both inplace_map and detected inplace_map: + 2.0.1 Most likely, we don't need to do anything, except 2.0.2. + 2.0.2 Conditions: + > During forward run, + > The output tensor is reusing one of input tensors, + > The raw input tensor to be reused given from ORT is copied to run the forward kernels + (for two possible reasons: + a. the first time forward run, all inputs will be copied to detect + `tensor_input_indices_to_save_in_ctx`; + b. for every iteration, the input needs to be cloned because it is in + `tensor_input_indices_to_save_in_ctx`). + + In this case, need to copy the output tensor back to the raw input tensor, to make it compatible with + ORT statistically planned buffer reuse. + 2.1 If the reuse mapping value is NOT equal in both inplace_map and detected inplace_map: + 2.1.1 If the detected reuse input index is -1 (e.g. there is NO buffer reuse for this output), + while user specified reuse input index is NOT -1 (ORT planned the reuse), we raise an error. + 2.1.2 If the detected reuse input index is NOT -1 (e.g. there is buffer reuse for this output), + while user specified reuse input index is -1 (ORT did not plan the reuse). We will try to clone the + output tensor before returning to ORT, to align with ORT's NO Buffer reuse plan; otherwise, once the + input buffer is released by ORT memory planner, the output tensor read/write will be corrupted. + Raise a warning to notify users to update inplace_map explicitly for performance consideration. + 2.1.3 Other cases (for example user gives a wrong mapping index compared with detected ones), raise an + error. + 3. Do copies for 2.1.2 cases. + 4. Do copies for 2.0.2 cases. + """ + + log_prefix = f"{func_name}->{'Backward' if is_backward else 'Forward'}: " + input_tensor_address_list = [t.data_ptr() for t in input_tensors_of_kernel_run] + if is_backward: + input_tensor_address_list = [-1, *input_tensor_address_list] # skip the context input + + is_first_time_init = kernel_info.output_indices_for_clone is None + # If this is the first time run, collect runtime tensor reuse mapping. + if is_first_time_init: + # Procedure 1: Detect all outputs to tensor inputs reuse mapping, according to `all_outputs_of_kernel_run` and + # `input_tensors_of_kernel_run`. + assert len(all_outputs_to_tensor_inputs_reuse_map) == len(all_outputs_of_kernel_run), ( + f"{log_prefix}all_outputs_to_tensor_inputs_reuse_map and kernel run outputs should have the same length." + f"all_outputs_to_tensor_inputs_reuse_map: {all_outputs_to_tensor_inputs_reuse_map}, " + f"kernel run outputs: {all_outputs_of_kernel_run}" + ) + + # Detect all outputs to tensor inputs reuse mapping. + detected_reuse_map = [-1] * (len(all_outputs_of_kernel_run)) + for output_index, arg in enumerate(all_outputs_of_kernel_run): + if not isinstance(arg, torch.Tensor): + continue + if arg.data_ptr() in input_tensor_address_list: + input_index = input_tensor_address_list.index(arg.data_ptr()) + detected_reuse_map[output_index] = input_index + + # Procedure 2: Validate the detected inplace_map with the registered inplace_map in ORT. + output_indices_for_clone = ( + [] + ) # collect the output indices that need to be cloned before returned in case 2.1.2. + for output_index, (detected_inplace_index, inplace_index) in enumerate( + zip(detected_reuse_map, all_outputs_to_tensor_inputs_reuse_map) + ): + if inplace_index == detected_inplace_index: + continue + + # If users register inplace_map (alloc planner will do buffer reuse), + # but detected inplace_map indicates it is NO inplace reusing, we raise an error. + if inplace_index != -1 and detected_inplace_index == -1: + raise RuntimeError( + f"{log_prefix}Fatal: " + f"ONNX Op attribute 'tensor_reuse_map' indicates {output_index}-th output is reusing input " + f"{inplace_index}, but detected inplace_map indicates it is NOT reusing any input. " + "Please update inplace_map explicitly to make it consistent " + f"to avoid undefined behavior due to ORT's memory reuse plan. " + f"inplace_map: {all_outputs_to_tensor_inputs_reuse_map}, " + f"detected inplace_map: {detected_reuse_map}" + ) + + if inplace_index == -1 and detected_inplace_index != -1: + output_indices_for_clone.append(output_index) + continue + + raise RuntimeError( + f"{log_prefix}Fatal: " + f"ONNX Op attribute 'inplace_map' indicates {inplace_index}-th output is reusing " + f"input index {detected_inplace_index}, but detected inplace_map indicates it is reusing " + f"input index {inplace_index}. Please update inplace_map explicitly to avoid undefined behavior " + f"due to memory reuse. inplace_map: {all_outputs_to_tensor_inputs_reuse_map}, " + f"detected inplace_map: {detected_reuse_map}" + ) + + kernel_info.output_indices_for_clone = output_indices_for_clone + + assert kernel_info.output_indices_for_clone is not None + + # Procedure 3: Do copies for 2.1.2 cases. + for output_index in kernel_info.output_indices_for_clone: + _log_warning( + f"{log_prefix}ONNX Op attribute " + f"'tensor_reuse_map' doesn't indicate {output_index}-th output is reusing any input, " + f"but detected inplace_map indicates it is reusing some input index. " + "A clone will be done before returning to ORT, to align with ORT's NO Buffer reuse plan. " + "Please update inplace_map explicitly to avoid such a copy." + ) + all_outputs_of_kernel_run[output_index] = all_outputs_of_kernel_run[output_index].detach().clone() + + # Procedure 4: Do copies for 2.0.2 cases. + if is_backward is False and ( + is_first_time_init + or kernel_info.tensor_input_indices_to_save_in_ctx + or kernel_info.tensor_input_indices_for_mark_dirty + ): + for raw_tensor_input_index, raw_input_tensor in raw_input_tensors_used_inplace.items(): + # raw_input_tensor can be None for backward run, but backward won't go here. + assert isinstance(raw_input_tensor, torch.Tensor) + + # We did not do the check with tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty + # because even for those tensor indices not in + # tensor_input_indices_to_save_in_ctx/tensor_input_indices_for_mark_dirty, we still need to do the + # copy for the first-time run. + if raw_input_tensor.data_ptr() == input_tensor_address_list[raw_tensor_input_index]: + # If the raw input tensor is not copied, we don't need this handling. + continue + + copied = False # for each tensor, we don't do the copy once. + output_indices_reusing_current_raw_input = [ + output_index + for output_index, input_index in enumerate(all_outputs_to_tensor_inputs_reuse_map) + if input_index == raw_tensor_input_index + ] + output_tensor_address = all_outputs_of_kernel_run[output_indices_reusing_current_raw_input[0]].data_ptr() + for output_index in output_indices_reusing_current_raw_input: + assert ( + output_tensor_address == all_outputs_of_kernel_run[output_index].data_ptr() + ), "Outputs reusing the same input tensor should have the same address." + + if not copied: + # Only need a copy once. + raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index]) + _log_warning( + f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}." + "Provide output to input reuse mapping to avoid the copy overhead." + ) + copied = True + + all_outputs_of_kernel_run[output_index] = raw_input_tensor + + def _get_context(forward_tensor_outputs: List[torch.Tensor]) -> Tuple[any, Optional[torch.Tensor]]: """Search for context among all outputs. - Note1: All forward outputs of torch.autograd.Function shared the same gradient function pointer, + Note 1: All forward outputs of torch.autograd.Function shared the same gradient function pointer, so here we just get the first tensor having grad_fn attribute. (https://github.com/PyTorch/PyTorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/custom_function.cpp#L267) - Note2: Context can be None because NOT all torch.autograd.Function's are differentiable. The function + Note 2: Context can be None because NOT all torch.autograd.Function's are differentiable. The function https://github.com/PyTorch/PyTorch/blob/d701357d921ef167d42c125e65b6f7da6be3ad0f/torch/csrc/autograd/custom_function.cpp#L209? - means if all output of forward function is not differentiable, then grad_fn will be None (not be set). + means if all output of the forward function is not differentiable, then grad_fn will be None (not be set). For example, class Bar(torch.autograd.Function): - # A non-differentiable autograd Function whose forard output + # A non-differentiable autograd Function whose forward output # doesn't have grad_fn attribute. @staticmethod def forward(ctx, x): @@ -85,7 +281,7 @@ def backward(ctx, dy): continue if arg.grad_fn is None: - # For following case, it is possible grad_fn exist, but its value is None, + # For the following case, it is possible grad_fn exists, but its value is None, # so we need to continue to search for the first tensor having a non-None grad_fn. # # >>> w = torch.randn(5, 6) @@ -106,9 +302,10 @@ def backward(ctx, dy): return (ctx, first_tensor_output) -def _finalize_traing_mode_forward( +def _finalize_training_mode_forward( kernel_invoke_id: str, - input_tensors_from_ort: Dict[int, torch.Tensor], + func_name: str, + input_tensors_used_for_fw_run: Dict[int, torch.Tensor], forward_output_tensors: List[Union[torch.Tensor, None]], ): """Complete the epilogue of forward runner for training mode. @@ -120,16 +317,25 @@ def _finalize_traing_mode_forward( Things to do: 1. Try to get context from forward output tensors. - 2. Remove the gradient functions between current autograd.Function and its input's gradient function, because + 2. Remove the gradient functions between the current autograd.Function and its input's gradient function, because in ORT we don't depend on PyTorch's autograd engine. 3. Register the current autograd.Function's gradient function into our PyNodeSharedPointerPool. - 4. Save kernel specific information into _GlobalOpKernelInfoMap in the first-time kernel run. + 4. Save kernel-specific information into _GlobalOpKernelInfoMap in the first-time kernel run. """ ctx, tensor_owning_ctx = _get_context(forward_output_tensors) + kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id] + # ctx being None in training mode means the forward function is not differentiable, so backward is not needed. if ctx is None: + # If this is the first time run, collect kernel-specific information. + if kernel_info.tensor_input_indices_to_save_in_ctx is None: + kernel_info.tensor_input_indices_to_save_in_ctx = [] + + if kernel_info.tensor_input_indices_for_mark_dirty is None: + kernel_info.tensor_input_indices_for_mark_dirty = [] + return None # Filter out the None in the saved_tensors. @@ -137,19 +343,20 @@ def _finalize_traing_mode_forward( ctx.fw_kernel_invoke_id = kernel_invoke_id - # If this is the first time run, collect kernel specific information. - if kernel_invoke_id not in _GlobalOpKernelInfoMap: - kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id) - _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info + # If this is the first time run, collect kernel-specific information. + if kernel_info.tensor_input_indices_to_save_in_ctx is None: + kernel_info.tensor_input_indices_to_save_in_ctx = [] if len(saved_tensors): - # Check tensors generated by ORT is in the saved_tensors or not. + # Check tensors generated by ORT are in the saved_tensors or not. # If yes, save the input index of the tensor in the _GlobalOpKernelInfoMap. - kernel_info.input_indices_to_save_in_ctx = [ - arg_index - for arg_index, tensor in input_tensors_from_ort.items() + kernel_info.tensor_input_indices_to_save_in_ctx = [ + tensor_input_index + for tensor_input_index, tensor in input_tensors_used_for_fw_run.items() if any(tensor is saved_tensor for saved_tensor in saved_tensors) ] - warnings.warn("Add input index to _GlobalOpKernelInfoMap, to avoid extra copy in every iteration.") + _log_warning( + f"{func_name}: Add input index to _GlobalOpKernelInfoMap, to avoid extra copy in every iteration." + ) kernel_info.materialize_grads = torch_interop_utils.get_materialize_grads(tensor_owning_ctx) kernel_info.materialize_grads_config = OrderedDict() if kernel_info.materialize_grads: @@ -161,6 +368,22 @@ def _finalize_traing_mode_forward( tensor.shape, ) + if kernel_info.tensor_input_indices_for_mark_dirty is None: + kernel_info.tensor_input_indices_for_mark_dirty = [] + # Check tensors generated by ORT are marked as dirty(for inplace update) or not. + # If yes, save the input index of the tensor in the _GlobalOpKernelInfoMap. + are_tensors_marked_as_dirty = torch_interop_utils.are_tensors_marked_as_dirty( + tensor_owning_ctx, [t for t in input_tensors_used_for_fw_run.values()] + ) + kernel_info.tensor_input_indices_for_mark_dirty = [ + tensor_input_index + for is_dirty, (tensor_input_index, tensor) in zip( + are_tensors_marked_as_dirty, input_tensors_used_for_fw_run.items() + ) + if is_dirty is True + ] + _log_warning(f"{func_name}: Add input index to _GlobalOpKernelInfoMap, to support leaf node do inplace update.") + # FORWARD BACKWARD FUNCTION CONNECTIONS # input_1 (leaf, constructed by from_dlpack) <----reference---- AccumulateGrad gradient function # ↓ ↑ @@ -188,8 +411,9 @@ def call_python_forward_function( requires_grad_flags: List[bool], tensor_type_flags: List[int], is_training_mode: bool, - inplace: bool, + inplace_map: List[int], kernel_invoke_id: str, + func_name: Union[bytes, str], *args, ): """ @@ -206,93 +430,119 @@ def call_python_forward_function( requires_grad_flags: requires_grad_flags[i] indicates if the i-th arg needs gradient. tensor_type_flags: tensor_type_flags[i] indicates the type of the i-th arg, 0 - non-tensor, 1 - tensor. is_training_mode: indicates if this model is running under training mode. - inplace: indicates if args can be modified inside the custom function. + inplace_map: a list of the same length of kernel outputs, each element represents which input index + it is reusing. If there is no reuse, the value is -1. args: inputs to "backward_function". """ - def generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode, is_inplace): - if is_training_mode and tensor_flag and grad_flag and is_inplace: - # "multiply one" helps change the torch tensor's is_leaf to False. - # This is required when the torch tensor is updated in-place during forward pass. - # We cannot use view here, because PyTorch handles grad_fn for view differently. - non_leaf_arg = arg * 1 - return non_leaf_arg - else: - return arg - try: - wrapped_args = [] - tensor_input_args_map = OrderedDict() + func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name + # If this is the first time run, collect runtime tensor reuse mapping. + if kernel_invoke_id not in _GlobalOpKernelInfoMap: + kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id) + _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info + + kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id] - # Be noted: in inference mode, we won't insert any information into _GlobalOpKernelInfoMap, because ctx - # will always be None in the first time run. - input_indices_to_save_in_ctx = None # Uninitialized - if kernel_invoke_id in _GlobalOpKernelInfoMap: - input_indices_to_save_in_ctx = _GlobalOpKernelInfoMap[kernel_invoke_id].input_indices_to_save_in_ctx + tensor_input_indices_to_save_in_ctx = kernel_info.tensor_input_indices_to_save_in_ctx + tensor_input_indices_for_mark_dirty = kernel_info.tensor_input_indices_for_mark_dirty - for arg_index, (grad_flag, tensor_flag, arg) in enumerate(zip(requires_grad_flags, tensor_type_flags, args)): + # Collect the tensor address for all inputs used for run forward, used for reuse detection. + tensor_input_index = 0 + # If the input is reused, we need to save the raw input tensor for special handling. + raw_input_tensors_used_inplace = OrderedDict() # Orders matter here. + input_tensors_used_for_fw_run = OrderedDict() # Orders matter here. + + wrapped_args = [] + for _, (grad_flag, tensor_flag, arg) in enumerate(zip(requires_grad_flags, tensor_type_flags, args)): if tensor_flag: # Assume it's a DLPack tensor and convert it to PyTorch tensor. + wrapped_arg = from_dlpack(arg) + + if tensor_input_index in inplace_map: + raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg + # Note1: - # If it's first-time kernel invocation, input_indices_to_save_in_ctx is None, we do the - # copy for all tensor. Otherwise, we only copy the tensors whose indices are in - # input_indices_to_save_in_ctx. - # + # If it's first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, we do the + # copy for all tensors. Otherwise, we only copy the tensors whose indices are in + # tensor_input_indices_to_save_in_ctx. # Note2: - # For inference mode, we don't need do the copy because ctx will be None, + # For inference mode, we don't need to do the copy because ctx will be None, # so nothing will be saved for ctx. if is_training_mode and ( - input_indices_to_save_in_ctx is None or arg_index in input_indices_to_save_in_ctx + tensor_input_indices_to_save_in_ctx is None + or tensor_input_index in tensor_input_indices_to_save_in_ctx ): - wrapped_arg = from_dlpack(arg).detach().clone() - else: - wrapped_arg = from_dlpack(arg) + wrapped_arg = wrapped_arg.detach().clone() # Only requires gradient when running under training mode # and the associated tensor has grad_flag=True (i.e., # "requires_grad=True" in the original PyTorch script). wrapped_arg.requires_grad = is_training_mode and grad_flag + + # Note3: + # If it's not first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the + # mul for all tensors. Otherwise, we only mul by one for the tensors whose indices are in + # tensor_input_indices_for_mark_dirty. + if is_training_mode and ( + tensor_input_indices_for_mark_dirty is None + or tensor_input_index in tensor_input_indices_for_mark_dirty + ): + # To fix this issue: + # "a leaf Variable that requires grad has been used in an in-place operation." + with torch.set_grad_enabled(True): + wrapped_arg = wrapped_arg.clone() + wrapped_args.append(wrapped_arg) - tensor_input_args_map[arg_index] = wrapped_arg + input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg + tensor_input_index += 1 else: # Use non-tensor as is. It's a PyObject*. wrapped_args.append(arg) with torch.set_grad_enabled(is_training_mode): - # Another level of wrap to avoid requires_grad=True for leaf variables. - new_wrapped_args = list( - generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode, inplace) - for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, wrapped_args) - ) - # Run autograd.Function.apply(...). - # TODO(pengwa): looks we are assuming all outputs will be either Tensor or None. + # TODO(pengwa): looks like we are assuming all outputs will be either Tensor or None. # We should revisit if it is possible to support other types of output, for example int, or, etc. - # But that might also requires some work in backend. - result = forward_function(*new_wrapped_args) + # But that might also require some work in backend. + result = forward_function(*wrapped_args) - # Extract results as DLPack tensors plus autograd context. Also skips all None values. + results = [] if isinstance(result, torch.Tensor): - ctx = None - if is_training_mode: - ctx = _finalize_traing_mode_forward(kernel_invoke_id, tensor_input_args_map, [result]) - unwrapped_values = [ctx, to_dlpack(result)] + results = [result] elif isinstance(result, (tuple, list)): - ctx = None - if is_training_mode: - ctx = _finalize_traing_mode_forward(kernel_invoke_id, tensor_input_args_map, result) - wrapped = [ctx] - wrapped.extend(list(to_dlpack(value) if value is not None else None for value in result)) - # Inside the returned list, first element is context and the rest - # are DLPack tensors. - unwrapped_values = wrapped + results = [r for r in result] else: raise wrap_exception( ORTModuleIOError, TypeError(f"ORTModule does not support the following model output type {type(result)}."), ) - return tuple(unwrapped_values) + + ctx = None + if is_training_mode: + ctx = _finalize_training_mode_forward( + kernel_invoke_id, func_name, input_tensors_used_for_fw_run, results + ) + + final_rets = [ctx] + final_rets.extend(results) + + _process_inplace_outputs( + kernel_info, + func_name, + input_tensors_used_for_fw_run.values(), + final_rets, + inplace_map, + raw_input_tensors_used_inplace, + ) + + dlpacks = [final_rets[0]] + dlpacks.extend(list(to_dlpack(value) if value is not None else None for value in final_rets[1:])) + + # Inside the returned list, the first element is context and the rest + # are DLPack tensors. + return tuple(dlpacks) except Exception as e: # Flush buffers. Otherwise, calling this from C++ may lose them. print("Exception happens when running ", forward_function) @@ -306,8 +556,9 @@ def call_python_backward_function( requires_grad_flags: List[bool], tensor_type_flags: List[int], is_training_mode: bool, - inplace: bool, + inplace_map: List[int], kernel_invoke_id: str, + func_name: Union[bytes, str], *args, ): """ @@ -319,11 +570,13 @@ def call_python_backward_function( Args: backward_function: pointer to autograd.Function.backward (e.g., MyReLU.backward). requires_grad_flags: requires_grad_flags[i] indicates if the i-th arg needs gradient. - tensor_type_flags: tensor_type_flagsi] indicates the type of the i-th arg. + tensor_type_flags: tensor_type_flags[i] indicates the type of the i-th arg. is_training_mode: indicates if this model is running under training mode. - inplace: indicates if args can be modified inside the custom function. + inplace_map: a list of the same length of kernel outputs, each element represents which input index + it is reusing. If there is no reuse, the value is -1. args: inputs to "backward_function". """ + func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name with torch.no_grad(): def wrap_all_outputs(result): @@ -338,6 +591,13 @@ def wrap_all_outputs(result): ) try: + # If this is the first time run, collect runtime tensor reuse mapping. + if kernel_invoke_id not in _GlobalOpKernelInfoMap: + kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id) + _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info + + kernel_info = _GlobalOpKernelInfoMap[kernel_invoke_id] + # Backward inputs should not require gradients. assert all(grad_flag == 0 for grad_flag in requires_grad_flags) @@ -345,6 +605,12 @@ def wrap_all_outputs(result): ctx = args[0] fw_kernel_invoke_id = ctx.fw_kernel_invoke_id wrapped_args = [] + + # Collect the tensor address for all inputs used for run backward, used for reuse detection. + tensor_input_index = 1 # skip the context input + # If input is reused, we need to save the raw input tensor for special handling. + raw_input_tensors_used_inplace = OrderedDict() # Orders matter here. + input_tensors_used_for_bw_run = OrderedDict() # Orders matter here. for grad_input_index, (grad_flag, tensor_flag, arg) in enumerate( zip(requires_grad_flags, tensor_type_flags, args) ): @@ -362,12 +628,19 @@ def wrap_all_outputs(result): # Assume it's a DLPack tensor# and convert it to PyTorch tensor. wrapped_arg = from_dlpack(arg) + if grad_input_index in inplace_map: + raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg + + input_tensors_used_for_bw_run[tensor_input_index] = wrapped_arg + if wrapped_arg is not None: # Only requires gradient when running under training mode # and the associated tensor has grad_flag=True (i.e., # "requires_grad=True" in the original PyTorch script). wrapped_arg.requires_grad = is_training_mode and grad_flag + wrapped_args.append(wrapped_arg) + tensor_input_index += 1 else: # Use non-tensor as is. It's a PyObject*. wrapped_args.append(arg) @@ -386,6 +659,16 @@ def wrap_all_outputs(result): TypeError(f"ORTModule does not support the following model output type {type(result)}."), ) + _process_inplace_outputs( + kernel_info, + func_name, + input_tensors_used_for_bw_run.values(), + result, + inplace_map, + raw_input_tensors_used_inplace, + is_backward=True, + ) + wrapped_returned_args = wrap_all_outputs(result) torch_interop_utils.unregister_grad_fn(id(ctx)) diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py index 17756600d601e..301071f6de44c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- + from typing import Dict, List, Optional, Tuple, Union import torch @@ -234,16 +235,16 @@ def _create_weight_retrieval_pythonop( func_full_qual_name: str, input_name: str, output_names: List[str], - STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, - STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE: List[int], + pull_weight_trigger_output_dtype: int, + pull_weight_trigger_output_shape: List[int], ) -> Tuple[ValueInfoProto, NodeProto]: """This function is used to create a weight retrieving PythonOp.""" offload_param_count = 0 if zero_stage3_named_params is None else len(zero_stage3_named_params) new_input = helper.make_tensor_value_info( - input_name, STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE + input_name, pull_weight_trigger_output_dtype, pull_weight_trigger_output_shape ) - output_rank_for_pull_weight_trigger = len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE) - output_dtype_for_pull_weight_trigger = STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE + output_rank_for_pull_weight_trigger = len(pull_weight_trigger_output_shape) + output_dtype_for_pull_weight_trigger = pull_weight_trigger_output_dtype output_tensor_ranks = [ output_rank_for_pull_weight_trigger, ] * offload_param_count @@ -253,10 +254,9 @@ def _create_weight_retrieval_pythonop( node_attributes = { "comment": "", - "inplace": 0, "input_convention": "d", - "input_tensor_ranks": [len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE)], - "input_tensor_types": [STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE], + "input_tensor_ranks": [len(pull_weight_trigger_output_shape)], + "input_tensor_types": [pull_weight_trigger_output_dtype], "output_tensor_ranks": output_tensor_ranks, "output_tensor_types": output_tensor_types, "training_mode": 1, diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc index e55aacb2334b2..d36720100e57a 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/torch_interop_utils.cc @@ -150,6 +150,34 @@ bool get_materialize_grads(at::Tensor target) { return py_fn->materialize_grads; } +std::vector are_tensors_marked_as_dirty(at::Tensor target, std::vector tensors_to_check) { + torch::autograd::AutogradMeta* autograd_meta = torch::autograd::impl::get_autograd_meta(target); + const auto& grad_fn = autograd_meta->grad_fn_; + auto py_node_fn = dynamic_cast(grad_fn.get()); + TORCH_CHECK(py_node_fn != nullptr, "grad_fn is not PyNode type."); + THPFunction* py_fn = (THPFunction*)py_node_fn->obj; + std::vector are_tensors_marked_dirty(tensors_to_check.size(), false); + if (!py_fn->dirty_tensors) + return are_tensors_marked_dirty; + + Py_ssize_t num_dirty = PyTuple_GET_SIZE(py_fn->dirty_tensors); + for (const auto j : c10::irange(tensors_to_check.size())) { + bool is_tensor_marked_dirty = false; + for (const auto i : c10::irange(num_dirty)) { + PyObject* obj = PyTuple_GET_ITEM(py_fn->dirty_tensors, i); + const auto& tensor = THPVariable_Unpack(obj); + if (tensor.is_same(tensors_to_check[j])) { + is_tensor_marked_dirty = true; + break; + } + } + + are_tensors_marked_dirty[j] = is_tensor_marked_dirty; + } + + return are_tensors_marked_dirty; +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("register_grad_fn_and_remove_from_autograd", ®ister_grad_fn_and_remove_from_autograd, "Increase grad_fn shared pointer reference."); @@ -158,4 +186,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("clear_grad_fns_for_next_edges", &clear_grad_fns_for_next_edges, "Remove reference on next edges' gradient functions."); m.def("get_materialize_grads", &get_materialize_grads, "Return whether materialize_grads is enabled or not."); + m.def("are_tensors_marked_as_dirty", &are_tensors_marked_as_dirty, "Return whether the tensors are marked dirty or not."); } diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index ae9bc4328cb26..958c7d94c4241 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -1549,7 +1549,7 @@ def _run_step(model, input): count += 1 if index == 0: - assert count == 1 + assert count == 2 else: assert count == 0 @@ -1717,3 +1717,97 @@ def forward(self, model_input): ).train() _ = ortmodule(torch.randn(output_size, dtype=torch.float)) _check_pythonop_shape(ortmodule) + + +def test_python_op_return_persistent_param_as_value(): + """Some PythonOp return values that are still used by PyTorch computation. This test makes sure that ORTModule + will not release/erase the storage of those return values during tear down OrtValue of the corresponding PythonOp + return values. + """ + + class SimplePassThrough(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.detach() + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + class GeluWithExternalOutput(torch.autograd.Function): + @staticmethod + def forward(ctx, x, bias_param): + ctx.save_for_backward(x) + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))), bias_param.detach() + + @staticmethod + def backward(ctx, *grad_outputs): + (x,) = ctx.saved_tensors + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + g = ff * grad_outputs[0] + return g, grad_outputs[1] + + class TestLayer(torch.nn.Module): + def __init__(self, output_size): + super().__init__() + self.relu = GeluWithExternalOutput.apply + self._output_size = output_size + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) + self.w = Parameter( + torch.empty(output_size, output_size, device=torch.cuda.current_device(), dtype=torch.float) + ) + with torch.no_grad(): + self.bias.uniform_() + self.w.uniform_() + + def forward(self, model_input): + activation0 = torch.add(model_input, 0.4) + activation1 = activation0.view(self._output_size, -1) + + # Returned detached_bias_param Tensor shares the same storage with self.bias + # We are testing to make sure ORT will not erase the storage of self.bias during tear down OrtValue as + # the returned value of the SimplePassThrough PythonOp. + detached_bias_param = SimplePassThrough.apply(self.bias) + relu_out, detached_bias_param = self.relu(activation1, detached_bias_param) + activation2 = torch.add(relu_out, self.bias) + activation3 = torch.add(activation2, detached_bias_param) + activation3 = torch.matmul(self.w, activation3) + activation4 = torch.div(activation3, 1000) + return activation4 + + class TestModule(torch.nn.Module): + def __init__(self, output_size) -> None: + super().__init__() + self.layers = torch.nn.ModuleList([TestLayer(output_size) for i in range(6)]) + + def forward(self, x): + # ModuleList can act as an iterable, or be indexed using ints + for layer in self.layers: + x = x.view(-1) + x = torch.nn.functional.relu(layer(x)) + return x + + device = "cuda" + output_size = 1024 + pt_model = TestModule(output_size).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def _run_step(model, input): + loss = model(input).sum() + loss.backward() + return loss + + for _ in range(5): + input = torch.randn(output_size, device=device, dtype=torch.float) + _run_step(pt_model, input) + _run_step(ort_model, input) + + pt_params = {n: p for n, p in pt_model.named_parameters()} + for name, param in ort_model.named_parameters(): + assert_values_are_close(param, pt_params[name], rtol=1e-04, atol=1e-3) + if param.grad is not None: + assert pt_params[name].grad is not None, f"pt param.grad is None for {name}" + assert_values_are_close(param.grad, pt_params[name].grad, rtol=1e-04, atol=1e-3) + else: + assert pt_params[name].grad is None diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc index e1d4be24861f5..a31fa5d850e59 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc @@ -49,7 +49,6 @@ std::vector> CreateOrtValueArgs(OpKernelContext* context void PythonOpBase::Init(const OpKernelInfo& info) { ORT_THROW_IF_ERROR(info.GetAttr("func_name", &name_)); - ORT_THROW_IF_ERROR(info.GetAttr("inplace", &inplace_)); is_training_mode_ = static_cast(info.GetAttrOrDefault("training_mode", static_cast(0))); ORT_THROW_IF_ERROR(info.GetAttr("input_convention", &input_convention_)); @@ -117,6 +116,9 @@ void PythonOpBase::Init(const OpKernelInfo& info) { // Output tensors. ORT_THROW_IF_ERROR(info.GetAttrs("output_tensor_types", output_tensor_types_)); + all_output_to_tensor_input_reuse_map_ = + info.GetAttrsOrDefault("tensor_reuse_map", std::vector((info.node().OutputDefs().size()), -1)); + CreateConstArgs(); CreateArgPositions(); @@ -141,17 +143,18 @@ void PythonOpBase::RunForward(OpKernelContext* context, std::vector> args = CreateOrtValueArgs(context, 0, context->InputCount()); // Invoke Python calls. TorchProxy::GetInstance().Forward( + name_, OrtTorchFunctionPool::GetInstance().GetForwardCore(name_), input_requires_grads_, args, arg_positions_, const_arg_set_.GetDataPtrs(), const_arg_set_.GetPositions(), - diff_ctx, - returned_ortvalues, is_training_mode_, - inplace_ != 0, - kernel_invoke_id_); + all_output_to_tensor_input_reuse_map_, + kernel_invoke_id_, + diff_ctx, + returned_ortvalues); const size_t returned_output_count = 1 + returned_ortvalues.size(); const size_t kernel_output_count = static_cast(context->OutputCount()); @@ -291,14 +294,32 @@ void PythonOpBase::SetContextOutput(OpKernelContext* context, void* diff_ctx) co void PythonOpBase::SetOtherOutputs(OpKernelContext* context, std::vector& returned_ortvalues) const { auto* ctx_internal = reinterpret_cast(context); + ORT_ENFORCE(returned_ortvalues.size() == all_output_to_tensor_input_reuse_map_.size() - 1, "PythonOp output count mismatch inplace map count.", + returned_ortvalues.size(), " != ", all_output_to_tensor_input_reuse_map_.size() - 1); for (size_t i = 0; i < returned_ortvalues.size(); ++i) { + size_t output_index = i + 1; + if (all_output_to_tensor_input_reuse_map_[output_index] != -1) { + const void* tensor_address = returned_ortvalues[i].Get().DataRaw(); + const void* input_tensor_address = context->Input(all_output_to_tensor_input_reuse_map_[output_index])->DataRaw(); + ORT_ENFORCE(tensor_address == input_tensor_address, + "PythonOp inplace tensor address mismatch, output index: ", output_index, ", input index: ", + all_output_to_tensor_input_reuse_map_[output_index]); + } + + // Notes: if the buffer is created, managed by PyTorch, converted to OrtValue through dlpack here, + // but also be used outside ORT later, we don't need to be concerned about + // "when the buffer of returned_ortvalues[i] is erased by ORT during releasing that OrtValue causing + // the PyTorch code still using that buffer will be failed". + // In this case, the created OrtValue's destructor will not release the buffer, + // instead it will release a tensor pointing to that buffer, where PyTorch will decide whether to release + // the buffer or not, if the tensor storage is not used by any other tensors + // (https://github.com/PyTorch/PyTorch/blob/ac603bc2f8ffac8fc061cfb99e77537464da4b18/aten/src/ATen/DLConvertor.cpp#L257C25-L257C29). ORT_THROW_IF_ERROR(ctx_internal->SetOutputMLValue(static_cast(i + 1), returned_ortvalues[i])); } } void PythonOpGradBase::Init(const OpKernelInfo& info) { ORT_THROW_IF_ERROR(info.GetAttr("func_name", &name_)); - ORT_THROW_IF_ERROR(info.GetAttr("inplace", &inplace_)); ORT_THROW_IF_ERROR(info.GetAttrs("input_tensor_types", input_tensor_types_)); ORT_THROW_IF_ERROR(info.GetAttr("output_convention", &output_convention_)); ORT_THROW_IF_ERROR(info.GetAttrs("output_tensor_types", output_tensor_types_)); @@ -306,6 +327,24 @@ void PythonOpGradBase::Init(const OpKernelInfo& info) { ORT_ENFORCE(output_tensor_types_.size() == output_tensor_requires_grads_.size(), "backward tensor output count mismatch"); + std::vector tensor_output_to_tensor_input_alias_map = + info.GetAttrsOrDefault("tensor_reuse_map", + std::vector((info.node().OutputDefs().size()), -1)); + all_output_to_tensor_input_reuse_map_.clear(); + all_output_to_tensor_input_reuse_map_.reserve(output_convention_.size()); + size_t tensor_output_index = 0; + for (size_t i = 0; i < output_convention_.size(); ++i) { + if (output_convention_[i] == 'd') { + all_output_to_tensor_input_reuse_map_.push_back( + tensor_output_to_tensor_input_alias_map[tensor_output_index] == -1 + ? -1 + : tensor_output_to_tensor_input_alias_map[tensor_output_index]); + ++tensor_output_index; + } else { + all_output_to_tensor_input_reuse_map_.push_back(-1); + } + } + SetPositions(); kernel_invoke_id_ = GetInvokeIdString(this); @@ -314,7 +353,7 @@ void PythonOpGradBase::Init(const OpKernelInfo& info) { void PythonOpGradBase::RunBackward(OpKernelContext* context, std::vector& returned_ortvalues) const { std::vector> args = CreateOrtValueArgs(context, 1, context->InputCount() - 1); - // This is called "const" because that's how Pytorch calls all non-tensor inputs. + // This is called "const" because that's how PyTorch calls all non-tensor inputs. const Tensor* context_id_tensor = context->Input(0); ORT_ENFORCE(context_id_tensor, "Context ID (first input) should not be null."); const int64_t* context_index_ptr = context_id_tensor->template Data(); @@ -323,15 +362,15 @@ void PythonOpGradBase::RunBackward(OpKernelContext* context, std::string err; TorchProxy::GetInstance().Backward( - OrtTorchFunctionPool::GetInstance() - .GetBackwardCore(name_), + name_, + OrtTorchFunctionPool::GetInstance().GetBackwardCore(name_), args, arg_positions_, const_args, const_arg_positions_, - returned_ortvalues, - inplace_ != 0, - kernel_invoke_id_); + all_output_to_tensor_input_reuse_map_, + kernel_invoke_id_, + returned_ortvalues); OrtTorchFunctionPool::GetInstance().UnregisterContext(*context_index_ptr); } @@ -343,6 +382,29 @@ void PythonOpGradBase::SetOutputs(OpKernelContext* context, std::vectorInput(all_output_to_tensor_input_reuse_map_[i]); + if (input_tensor) { + ORT_ENFORCE(input_tensor, "PythonOpGrad input tensor should not be null. input index: ", all_output_to_tensor_input_reuse_map_[i]); + + // Be noted: PythonOpGrad's input won't be non-tensor. + ORT_ENFORCE(all_output_to_tensor_input_reuse_map_[i] < context->InputCount(), "PythonOpGrad inplace tensor index out of bound."); + const void* tensor_address = returned_ortvalues[i].Get().DataRaw(); + + const void* input_tensor_address = input_tensor->DataRaw(); + ORT_ENFORCE(tensor_address == input_tensor_address, + "PythonOpGrad inplace tensor address mismatch, output index: ", i, ", input index: ", all_output_to_tensor_input_reuse_map_[i]); + } + } + + // Notes: if the buffer is created, managed by PyTorch, converted to OrtValue through dlpack here, + // but also be used outside ORT later, we don't need to be concerned about + // "when the buffer of returned_ortvalues[i] is erased by ORT during releasing that OrtValue causing + // the PyTorch code still using that buffer will be failed". + // In this case, the created OrtValue's destructor will not release the buffer, + // instead it will release a tensor pointing to that buffer, where PyTorch will decide whether to release + // the buffer or not, if the tensor storage is not used by any other tensors + // (https://github.com/PyTorch/PyTorch/blob/ac603bc2f8ffac8fc061cfb99e77537464da4b18/aten/src/ATen/DLConvertor.cpp#L257C25-L257C29). ORT_THROW_IF_ERROR(ctx_internal->SetOutputMLValue(tensor_output_index, returned_ortvalues.at(i))); } ++tensor_output_index; @@ -356,11 +418,11 @@ void PythonOpGradBase::SetPositions() { ORT_ENFORCE(const_arg_positions_.size() == 0); ORT_ENFORCE(arg_positions_.size() == 0); - // Pytorch's autograd context is the first (indexed by 0) input of the called Python function. + // PyTorch's autograd context is the first (indexed by 0) input of the called Python function. // Note that here we will call autograd.Function.backward(ctx, tensor0, tensor1, ...). const_arg_positions_ = {0}; - // The rest inputs are just Pytorch tensors. + // The rest inputs are just PyTorch tensors. arg_positions_.resize(input_tensor_types_.size()); for (size_t i = 0; i < arg_positions_.size(); ++i) { // i-th tensor is the (i+1)-th input of autograd.Function.backward. diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h index 1657bf49ea2e6..d4a53a223abf1 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h @@ -106,7 +106,7 @@ class PythonOpBase { // Name of containing class. For example, MyReLU. std::string name_; - int64_t inplace_; + std::vector all_output_to_tensor_input_reuse_map_; std::string input_convention_; bool is_training_mode_; // input_requires_grads_[i] indicates if the i-th inputs of apply() should have gradient. @@ -179,7 +179,7 @@ class PythonOpGradBase { protected: // Name of containing class. For example, MyReLU. std::string name_; - int64_t inplace_; + // Input types of MyReLU.backward(...). std::vector input_tensor_types_; @@ -190,6 +190,9 @@ class PythonOpGradBase { std::vector arg_positions_; std::vector const_arg_positions_; + // Memory reuse map for all outputs. + std::vector all_output_to_tensor_input_reuse_map_; + private: void SetPositions();