Skip to content

Commit

Permalink
Improve perf for stage3 training (#18099)
Browse files Browse the repository at this point in the history
### Improve perf for stage3 training - first wave

Port existing PythonOp/PythonOpGrad python runner to C++, also introduce
an unsafe run mode (to skip inplace, save for backward, materrialized
grad detection on the fly).

This reduce the overhead from XX~XXX us to X ~ lower end of XX us . In
LLAMA2 7B training with 8x32GV100, we have observed 6.7% gains over
PyTorch. (1.59 v.s. 1.49it/s)

Peak memory also dropped from 31GB to 28GB.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
pengwa authored Dec 15, 2023
1 parent cbad4fe commit 5eda79b
Show file tree
Hide file tree
Showing 30 changed files with 1,520 additions and 1,183 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,14 @@ void OrtTorchFunctionPool::RegisterTorchAutogradFunction(
PythonObjectPtr forward(PyObject_GetAttrString(obj, "apply"), PythonObjectDeleter);
PythonObjectPtr backward(PyObject_GetAttrString(obj, "backward"), PythonObjectDeleter);

PythonObjectPtr unsafe_forward(PyObject_GetAttrString(obj, "forward"), PythonObjectDeleter);
ORT_ENFORCE(forward.get(), "apply attribute not found when registering ", key);
ORT_ENFORCE(backward.get(), "backward attribute not found when registering ", key);
ORT_ENFORCE(unsafe_forward.get(), "forward attribute not found when registering ", key);

RegisterEntry(mutex_, key, forward.get(), forward_core_pool_);
RegisterEntry(mutex_, key, backward.get(), backward_core_pool_);
RegisterEntry(mutex_, key, unsafe_forward.get(), unsafe_forward_core_pool_);
}

void OrtTorchFunctionPool::RegisterShapeInferenceFunction(const std::string& key,
Expand All @@ -105,46 +108,27 @@ void OrtTorchFunctionPool::RegisterInputAliasFunction(const std::string& key,
RegisterEntry(mutex_, key, obj, input_alias_function_pool_);
}

static void RegisterEntry(
std::mutex& mutex,
PyObject* obj,
PythonObjectPtr& storage) {
std::lock_guard<std::mutex> lock(mutex);
// Basic checks.
ORT_ENFORCE(obj, "Cannot register NULL PyObject*.");

// Skip registration if storage already stores a Python object.
if (storage.get() != nullptr) {
return;
}

// Own the Python object.
Py_INCREF(obj);
PythonObjectPtr ptr(obj, PythonObjectDeleter);

// If an obj has been registered, this old ownership is automatically released
// after this move-assignment. Then, the "storage" owns the new object.
storage = std::move(ptr);
void OrtTorchFunctionPool::RegisterForwardRunner(size_t function_address) {
void* p_forward_runner_func = reinterpret_cast<void*>(function_address);
forward_runner_ = reinterpret_cast<CustomFunctionRunnerType>(p_forward_runner_func);
}

void OrtTorchFunctionPool::RegisterForwardRunner(PyObject* obj) {
RegisterEntry(mutex_, obj, forward_runner_);
void OrtTorchFunctionPool::RegisterBackwardRunner(size_t function_address) {
void* p_backward_runner_func = reinterpret_cast<void*>(function_address);
backward_runner_ = reinterpret_cast<CustomFunctionRunnerType>(p_backward_runner_func);
}

void OrtTorchFunctionPool::RegisterBackwardRunner(PyObject* obj) {
RegisterEntry(mutex_, obj, backward_runner_);
}
CustomFunctionRunnerType OrtTorchFunctionPool::GetForwardRunner() {
ORT_ENFORCE(forward_runner_,
"Forward runner cannot be NULL. Did you forget to register it by calling RegisterForwardRunner(...)?");

PyObject* OrtTorchFunctionPool::GetForwardRunner() {
std::lock_guard<std::mutex> lock(mutex_);
ORT_ENFORCE(forward_runner_.get(), "Forward runner cannot be NULL. Do you forget register it by calling RegisterForwardRunner(...)?");
return forward_runner_.get();
return forward_runner_;
}

PyObject* OrtTorchFunctionPool::GetBackwardRunner() {
std::lock_guard<std::mutex> lock(mutex_);
ORT_ENFORCE(backward_runner_.get(), "backward runner cannot be NULL. Do you forget register it by calling RegisterBackwardRunner(...)?");
return backward_runner_.get();
CustomFunctionRunnerType OrtTorchFunctionPool::GetBackwardRunner() {
ORT_ENFORCE(backward_runner_,
"backward runner cannot be NULL. Did you forget to register it by calling RegisterBackwardRunner(...)?");
return backward_runner_;
}

PyObject* OrtTorchFunctionPool::GetForwardCore(const std::string& key) {
Expand All @@ -163,6 +147,14 @@ PyObject* OrtTorchFunctionPool::GetBackwardCore(const std::string& key) {
return iter->second.get();
}

PyObject* OrtTorchFunctionPool::GetUnsafeForwardCore(const std::string& key) {
ORT_ENFORCE(!key.empty(), "Cannot be empty string.");
std::lock_guard<std::mutex> lock(mutex_);
auto iter = unsafe_forward_core_pool_.find(key);
ORT_ENFORCE(iter != unsafe_forward_core_pool_.end(), "No unsafe forward registered for ", key);
return iter->second.get();
}

std::optional<PyObject*> OrtTorchFunctionPool::TryGettingShapeInferenceFunction(const std::string& key) {
ORT_ENFORCE(!key.empty(), "Cannot be empty string.");
std::lock_guard<std::mutex> lock(mutex_);
Expand Down Expand Up @@ -201,10 +193,9 @@ int64_t OrtTorchFunctionPool::RegisterContext(PyObject* autograd_context) {
autograd_context, "autograd_context_register");

ORT_ENFORCE(autograd_context, "Cannot register NULL autograd context.");
Py_INCREF(autograd_context);

func_context_pool_.insert({index_, PythonObjectPtr(autograd_context, PythonObjectDeleter)});
// We don't need increase the context refcnt because PyTorch already did it during .apply().

return index_;
}

Expand All @@ -227,14 +218,13 @@ PyObject* OrtTorchFunctionPool::GetContext(int64_t context_index) {
}

void OrtTorchFunctionPool::UnRegisterGlobalFunctions() {
forward_runner_.reset();
backward_runner_.reset();
func_context_pool_.clear();
}

void OrtTorchFunctionPool::UnRegisterModelSpecificFunctions() {
forward_core_pool_.clear();
backward_core_pool_.clear();
unsafe_forward_core_pool_.clear();
shape_inference_function_pool_.clear();
input_alias_function_pool_.clear();
miscellaneous_const_input_pool_.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ namespace onnxruntime {
namespace language_interop_ops {
namespace torch {

typedef std::vector<PyObject*> (*CustomFunctionRunnerType)(const char* func_name_char,
void* callback,
const std::vector<int64_t>& requires_grads,
const std::vector<int64_t>& tensor_type_flags,
const bool is_training_mode,
const std::vector<int64_t>& inplace_map,
const char* kernel_invoke_id_char,
const bool safe_run_mode_enabled,
const std::vector<PyObject*>& tensor_args);

class OrtTorchFunctionPool final {
public:
static OrtTorchFunctionPool& GetInstance() {
Expand All @@ -34,6 +44,9 @@ class OrtTorchFunctionPool final {
// 2. Caller of GetBackwardCore should not decrease the reference count of the returned object.
PyObject* GetBackwardCore(const std::string& key); // The "key" is the "name" attribute in PythonOpGrad.

// Return a borrowed reference to the stored Python function running in safe mode.
PyObject* GetUnsafeForwardCore(const std::string& key); // The "key" is the "name" attribute in PythonOp.

// Shape inference function is used to infer output shape of a PythonOp.
void RegisterShapeInferenceFunction(const std::string& key, PyObject* obj);
// Return a borrowed reference to the stored Python function, if it exists; otherwise, return nullptr.
Expand Down Expand Up @@ -67,15 +80,15 @@ class OrtTorchFunctionPool final {
// ForwardRunner/BackwardRunner are "glue" codes written in Python that interacting
// with C++ kernels during Python function invoking.
// This function creates new ownership to "obj".
void RegisterForwardRunner(PyObject* obj);
void RegisterForwardRunner(size_t function_address);
// This function creates new ownership to "obj".
void RegisterBackwardRunner(PyObject* obj);
// Return a borrowed reference to a Python function, which
void RegisterBackwardRunner(size_t function_address);
// Return a borrowed reference to a c++ function, which
// is responsible for executing autograd.Function.apply.
PyObject* GetForwardRunner();
// Return a borrowed reference to a Python function, which
CustomFunctionRunnerType GetForwardRunner();
// Return a borrowed reference to a c++ function, which
// is responsible for executing autograd.Function.apply.
PyObject* GetBackwardRunner();
CustomFunctionRunnerType GetBackwardRunner();

// The reason we provide this unregister api is:
// A static OrtTorchFunctionPool instance will be destructed after
Expand All @@ -97,11 +110,12 @@ class OrtTorchFunctionPool final {
void UnRegisterGlobalFunctions();
void UnRegisterModelSpecificFunctions();

PythonObjectPtr forward_runner_;
PythonObjectPtr backward_runner_;
CustomFunctionRunnerType forward_runner_;
CustomFunctionRunnerType backward_runner_;

std::unordered_map<std::string, PythonObjectPtr> forward_core_pool_;
std::unordered_map<std::string, PythonObjectPtr> backward_core_pool_;
std::unordered_map<std::string, PythonObjectPtr> unsafe_forward_core_pool_;
std::unordered_map<std::string, PythonObjectPtr> shape_inference_function_pool_;
std::unordered_map<std::string, PythonObjectPtr> input_alias_function_pool_;

Expand Down
Loading

0 comments on commit 5eda79b

Please sign in to comment.