Skip to content

Commit

Permalink
Support inplace update for PythonOp/Grad (microsoft#17687)
Browse files Browse the repository at this point in the history
### Support inplace update for PythonOp/Grad

This PR is based on another PR
microsoft#17685 branch, to make it
easier to review.

With PR: PR microsoft#17685, By
default all PythonOp inputs/outputs are assumed to not be inplaced, if
during run, we found some inplace update happens (by checking output
data address with all inputs data address), we add clone before set it
as PythonOp/Grad's outputs. In this case, results are correct, but
implicit copies overheads are introduced.

This PR allow users to define output input reuse map, to let ORT know
how to do the reuse map, avoid such unnecessary copies.
  • Loading branch information
pengwa authored and kleiti committed Mar 22, 2024
1 parent 528d4ef commit 95307fe
Show file tree
Hide file tree
Showing 21 changed files with 590 additions and 78 deletions.
3 changes: 3 additions & 0 deletions cmake/onnxruntime_optimizer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxr
target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT})
if (onnxruntime_ENABLE_TRAINING)
target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT})
if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
onnxruntime_add_include_to_target(onnxruntime_optimizer Python::Module)
endif()
endif()
if (onnxruntime_ENABLE_TRITON)
target_link_libraries(onnxruntime_optimizer PRIVATE nlohmann_json::nlohmann_json)
Expand Down
32 changes: 30 additions & 2 deletions onnxruntime/core/framework/allocation_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ class PlannerImpl {
return false;
}

const auto& alias_map = ci.kernel_def->Alias();
const auto alias_map = GetAliasMap(node, ci);
auto input_args = node.InputDefs();
for (auto& pair : alias_map) {
if (pair.second == output_arg_num) {
Expand Down Expand Up @@ -829,6 +829,34 @@ class PlannerImpl {
return p_provider->GetOrtDeviceByMemType(utils::IsInputOnCpu(node, &kernel_create_info, input_index) ? OrtMemTypeCPUInput : OrtMemTypeDefault);
}

std::vector<std::pair<int, int>> GetAliasMap(const Node& node, const KernelCreateInfo& kernel_create_info) {
ORT_ENFORCE(kernel_create_info.kernel_def != nullptr, "KernelDef is null for node: ", node.Name());
#ifdef ENABLE_TRAINING_TORCH_INTEROP
if ((node.OpType().compare("PythonOp") == 0 || node.OpType().compare("PythonOpGrad") == 0) &&
node.Domain() == kMSDomain) {
const auto& attrs = node.GetAttributes();
auto attr_it = attrs.find("tensor_reuse_map");
if (attr_it != attrs.end()) {
const auto& inplace_map = attr_it->second.ints();
std::vector<std::pair<int, int>> alias_map;
alias_map.reserve(inplace_map.size());
for (int i = 0; i < inplace_map.size(); ++i) {
int output_index = i;
int input_index = inplace_map[i];
if (input_index == -1) {
// skip because no reuse for this output
continue;
}
alias_map.emplace_back(std::make_pair(input_index, output_index));
}
return alias_map;
}
}
#endif

return kernel_create_info.kernel_def->Alias();
}

void GeneratePlanForWeightsHelper(const GraphViewer& graph_viewer,
const InitializedTensorSet& weights,
const KernelCreateInfoMap& kernel_create_info_map,
Expand Down Expand Up @@ -1084,7 +1112,7 @@ class PlannerImpl {
}

bool found_reusable = false;
const auto& alias_map = ci.kernel_def->Alias();
const auto alias_map = GetAliasMap(*node, ci);
auto input_args = node->InputDefs();
for (auto* input_arg : input_args) {
OrtValueIndex input_idx_global{};
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2385,10 +2385,10 @@ def _infer_PythonOp(self, node): # noqa: N802
output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
assert output_tensor_ranks

from onnxruntime.training.ortmodule._custom_autograd_function_exporter import PythonOpShapeInferStore
from onnxruntime.capi._pybind_state import get_shape_inference_function

func_name = get_attribute(node, "func_name").decode()
shape_inferer = PythonOpShapeInferStore.get_shape_infer(func_name)
shape_inferer = get_shape_inference_function(func_name)

# Set the context output separately.
# The first output is torch.autograd.Function''s context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ void OrtTorchFunctionPool::RegisterTorchAutogradFunction(
RegisterEntry(mutex_, key, backward.get(), backward_core_pool_);
}

void OrtTorchFunctionPool::RegisterShapeInferenceFunction(const std::string& key,
PyObject* obj) {
RegisterEntry(mutex_, key, obj, shape_inference_function_pool_);
}

void OrtTorchFunctionPool::RegisterInputAliasFunction(const std::string& key,
PyObject* obj) {
RegisterEntry(mutex_, key, obj, input_alias_function_pool_);
}

static void RegisterEntry(
std::mutex& mutex,
PyObject* obj,
Expand Down Expand Up @@ -153,6 +163,26 @@ PyObject* OrtTorchFunctionPool::GetBackwardCore(const std::string& key) {
return iter->second.get();
}

std::optional<PyObject*> OrtTorchFunctionPool::TryGettingShapeInferenceFunction(const std::string& key) {
ORT_ENFORCE(!key.empty(), "Cannot be empty string.");
std::lock_guard<std::mutex> lock(mutex_);
auto iter = shape_inference_function_pool_.find(key);
if (iter != shape_inference_function_pool_.end()) {
return iter->second.get();
}
return std::nullopt;
}

std::optional<PyObject*> OrtTorchFunctionPool::TryGettingInputAliasFunction(const std::string& key) {
ORT_ENFORCE(!key.empty(), "Cannot be empty string.");
std::lock_guard<std::mutex> lock(mutex_);
auto iter = input_alias_function_pool_.find(key);
if (iter != input_alias_function_pool_.end()) {
return iter->second.get();
}
return std::nullopt;
}

void OrtTorchFunctionPool::RegisterMiscellaneousConstInput(PyObject* obj) {
ORT_ENFORCE(obj, "Cannot register NULL reference input.");
const void* address = static_cast<const void*>(obj);
Expand Down Expand Up @@ -205,6 +235,8 @@ void OrtTorchFunctionPool::UnRegisterGlobalFunctions() {
void OrtTorchFunctionPool::UnRegisterModelSpecificFunctions() {
forward_core_pool_.clear();
backward_core_pool_.clear();
shape_inference_function_pool_.clear();
input_alias_function_pool_.clear();
miscellaneous_const_input_pool_.clear();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ class OrtTorchFunctionPool final {
// 2. Caller of GetBackwardCore should not decrease the reference count of the returned object.
PyObject* GetBackwardCore(const std::string& key); // The "key" is the "name" attribute in PythonOpGrad.

// Shape inference function is used to infer output shape of a PythonOp.
void RegisterShapeInferenceFunction(const std::string& key, PyObject* obj);
// Return a borrowed reference to the stored Python function, if it exists; otherwise, return nullptr.
std::optional<PyObject*> TryGettingShapeInferenceFunction(const std::string& key);

// Input alias function is used to infer memory reuse map of a PythonOp.
void RegisterInputAliasFunction(const std::string& key, PyObject* obj);
// Return a borrowed reference to the stored Python function, if it exists; otherwise, return nullptr.
std::optional<PyObject*> TryGettingInputAliasFunction(const std::string& key);

// Autograd function may take input of "non-tensor && non int/float && non int/float tuple" types.
// While PythonOp running requires those inputs be there otherwise kernel execution will fail.
// So during model exporting, we need register those input with this API, then a ref cnt is increased by 1,
Expand Down Expand Up @@ -92,6 +102,9 @@ class OrtTorchFunctionPool final {

std::unordered_map<std::string, PythonObjectPtr> forward_core_pool_;
std::unordered_map<std::string, PythonObjectPtr> backward_core_pool_;
std::unordered_map<std::string, PythonObjectPtr> shape_inference_function_pool_;
std::unordered_map<std::string, PythonObjectPtr> input_alias_function_pool_;

std::unordered_map<std::string, PythonObjectPtr> miscellaneous_const_input_pool_;
std::unordered_map<int64_t, PythonObjectPtr> func_context_pool_;

Expand Down
50 changes: 50 additions & 0 deletions orttraining/orttraining/core/framework/torch/torch_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,54 @@ void TorchProxy::Backward(
returned_ortvalues);
}

void TorchProxy::RunInputAliasFunction(
void* input_alias_function,
const std::string& node_proto_str,
std::vector<int64_t>& fw_output_to_input_alias_map,
std::vector<int64_t>& bw_output_to_input_alias_map) {
PyObject* input_alias_func = reinterpret_cast<PyObject*>(input_alias_function);
ORT_ENFORCE(PyCallable_Check(input_alias_func), "input_alias_func is not callable.");

// All arguments created for Python call will be destroyed along with PythonObjectPtr.
PythonObjectPtr args(Ort_PyTuple_New(1, "input_alias_func_arguments_tuple"), PythonObjectDeleter);
PyObject* node_proto_ptr_arg = PyBytes_FromStringAndSize(node_proto_str.c_str(), node_proto_str.size());
Ort_PyTuple_SetItem_NoIncref(args.get(), 0, node_proto_ptr_arg, "node_proto_ptr_arg");

PythonObjectPtr result_ptr(PyObject_CallObject(input_alias_func, args.get()), PythonObjectDeleter);
if (PyErr_Occurred()) {
PyErr_Print();
ORT_THROW("Python function execution fails with the above information.");
}

bool is_tuple = PyTuple_Check(result_ptr.get());
bool is_list = PyList_Check(result_ptr.get());
ORT_ENFORCE(is_tuple || is_list, "Python function must return a tuple or a list. is_tuple: ",
is_tuple, ", is_list: ", is_list);
Py_ssize_t ret_tuple_size =
is_tuple ? PyTuple_Size(result_ptr.get()) : PyList_Size(result_ptr.get());
ORT_ENFORCE(ret_tuple_size == 2, "Input alias function must return a tuple/list of size 2.");

for (Py_ssize_t tuple_index = 0; tuple_index < ret_tuple_size; ++tuple_index) {
PyObject* alias_map = is_tuple ? PyTuple_GetItem(result_ptr.get(), tuple_index)
: PyList_GetItem(result_ptr.get(), tuple_index);

std::vector<int64_t>& output_to_input_alias_map =
tuple_index == 0 ? fw_output_to_input_alias_map : bw_output_to_input_alias_map;

bool is_elem_tuple = PyTuple_Check(alias_map);
bool is_elem_list = PyList_Check(alias_map);

ORT_ENFORCE(is_elem_tuple || is_elem_list, "Input alias map must be a tuple or a list. is_elem_list: ",
is_elem_list, ", is_elem_tuple: ", is_elem_tuple);
Py_ssize_t output_count = is_elem_tuple ? PyTuple_Size(alias_map) : PyList_Size(alias_map);
for (Py_ssize_t output_index = 0; output_index < output_count; ++output_index) {
PyObject* input_index =
is_elem_tuple ? PyTuple_GetItem(alias_map, output_index) : PyList_GetItem(alias_map, output_index);
ORT_ENFORCE(PyLong_Check(input_index), "Alias input index must be an integer.");
int64_t alias_index_int = PyLong_AsLongLong(input_index);
output_to_input_alias_map.push_back(alias_index_int);
}
}
}

} // namespace onnxruntime::language_interop_ops::torch
31 changes: 31 additions & 0 deletions orttraining/orttraining/core/framework/torch/torch_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
// Licensed under the MIT License.

#pragma once

#include <mutex>
#include <optional>
#include <string>
#include <vector>
#include "orttraining/core/framework/torch/python_common.h"

#ifndef SHARED_PROVIDER
Expand Down Expand Up @@ -61,6 +64,34 @@ class TorchProxy {
const std::string& invoke_id,
std::vector<OrtValue>& return_args);

/**
* @brief Run given function to get output to input reuse map.
*
* @param input_alias_func Python function to run.
* The function should take a serialized PythonOp NodeProto string as input, return a tuple of two lists.
* The signature of the function should be:
* def alias_input(node_proto_str: str):
* fw_alias_map = [1, -1, -1]
* bw_alias_map = [-1, 0]
* return fw_alias_map, bw_alias_map
* @param node_proto_str The serialized PythonOp NodeProto string.
* @param fw_output_to_input_alias_map Used as returned value, return the output to input alias map for forward pass.
* For example, if the inputs of the torch.autograd.Function are [non_tensor_a, tensor_b],
* outputs are [tensor_x, tensor_y, tensor_z], and the alias map is [1, -1, -1], this is explained as:
* tensor_x is reusing the input tensor_b, tensor_y and tensor_z are not reusing any input.
* The value of alias map is 0 based input index. -1 means the output is not reusing any input.
* @param bw_output_to_input_alias_map Used as returned value, return the output to input alias map for backward pass.
* For example, if the inputs of the torch.autograd.Function are [tensor_x_grad, None, None],
* outputs are [None, tensor_b_grad], and the alias map is [-1, 0], this is explained as:
* tensor_b_grad is reusing the input tensor_x_grad.
* The value of alias map is 0 based grad input index. -1 means the output is not reusing any input.
*/
void RunInputAliasFunction(
void* input_alias_func,
const std::string& node_proto_str,
std::vector<int64_t>& fw_output_to_input_alias_map,
std::vector<int64_t>& bw_output_to_input_alias_map);

private:
TorchProxy(){};
~TorchProxy(){};
Expand Down
8 changes: 8 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1848,6 +1848,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetPythonOpGradient) {
"PythonOpGrad requiring gradient output count mismatch.");
attrs.push_back(MakeAttribute("output_tensor_requires_grads", bw_tensor_output_requires_grads));

// Copy bw_tensor_reuse_map attribute from PythonOp to PythonOpGrad if it is present.
auto attr_it = src_attrs.find("bw_tensor_reuse_map");
if (attr_it != src_attrs.end()) {
std::vector<int64_t> tensor_output_to_tensor_input_reuse_map(attr_it->second.ints().begin(),
attr_it->second.ints().end());
attrs.push_back(MakeAttribute("tensor_reuse_map", tensor_output_to_tensor_input_reuse_map));
}

if (src_attrs.find("comment") != src_attrs.end() && utils::HasString(src_attrs.at("comment"))) {
attrs.push_back(MakeAttribute("comment", src_attrs.at("comment").s()));
}
Expand Down
11 changes: 11 additions & 0 deletions orttraining/orttraining/core/graph/training_op_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3918,6 +3918,17 @@ Return true if all elements are true and false otherwise.
"- the output 2 reuses the input 0.",
AttributeProto::INTS,
false)
.Attr(
"bw_tensor_reuse_map",
"Used for backward op only."
"A int array indicating whether output at each index is reusing specific input or now."
"If the given index is -1, it means the output is not reusing any input."
"For example, there are 3 inputs (including ctx) and 2 outputs, tensor_reuse_map = [2, 1] means"
"- the output 0 reuses the input 2."
"- the output 1 reuses the input 1."
"Be noted: the input 0 is ctx.",
AttributeProto::INTS,
false)
.Attr(
"training_mode",
"Indicate if the model is exported in training_mode, by default, False.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
#include "core/optimizer/pre_shape_node_elimination.h"
#include "orttraining/core/optimizer/compute_optimizer/padding_elimination.h"
#include "orttraining/core/optimizer/compute_optimizer/sceloss_compute_optimization.h"
#ifdef ENABLE_TRAINING_TORCH_INTEROP
#include "orttraining/core/optimizer/pythonop_rewriter.h"
#endif

namespace onnxruntime {
namespace training {
Expand Down Expand Up @@ -106,6 +109,9 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique<InsertSoftmaxCrossEntropyLossOutput>()));
ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique<LSTMReplacement>()));
ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique<GRUReplacement>()));
#ifdef ENABLE_TRAINING_TORCH_INTEROP
ORT_THROW_IF_ERROR(rule_transformer->Register(std::make_unique<PythonOpRewriter>()));
#endif

// Put ConstantSharing before CommonSubexpressionElimination by intention as it can create more opportunities for
// CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
Expand Down
Loading

0 comments on commit 95307fe

Please sign in to comment.