Skip to content

Commit

Permalink
c++ custom function runner
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Oct 22, 2023
1 parent 0987d9e commit 3efef96
Show file tree
Hide file tree
Showing 12 changed files with 1,200 additions and 1,303 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,46 +105,25 @@ void OrtTorchFunctionPool::RegisterInputAliasFunction(const std::string& key,
RegisterEntry(mutex_, key, obj, input_alias_function_pool_);
}

static void RegisterEntry(
std::mutex& mutex,
PyObject* obj,
PythonObjectPtr& storage) {
std::lock_guard<std::mutex> lock(mutex);
// Basic checks.
ORT_ENFORCE(obj, "Cannot register NULL PyObject*.");

// Skip registration if storage already stores a Python object.
if (storage.get() != nullptr) {
return;
}

// Own the Python object.
Py_INCREF(obj);
PythonObjectPtr ptr(obj, PythonObjectDeleter);

// If an obj has been registered, this old ownership is automatically released
// after this move-assignment. Then, the "storage" owns the new object.
storage = std::move(ptr);
void OrtTorchFunctionPool::RegisterForwardRunner(size_t function_address) {
void* p_forward_runner_func = reinterpret_cast<void*>(function_address);
forward_runner_ = reinterpret_cast<CustomFunctionRunnerType>(p_forward_runner_func);
}

void OrtTorchFunctionPool::RegisterForwardRunner(PyObject* obj) {
RegisterEntry(mutex_, obj, forward_runner_);
void OrtTorchFunctionPool::RegisterBackwardRunner(size_t function_address) {
void* p_backward_runner_func = reinterpret_cast<void*>(function_address);
backward_runner_ = reinterpret_cast<CustomFunctionRunnerType>(p_backward_runner_func);
}

void OrtTorchFunctionPool::RegisterBackwardRunner(PyObject* obj) {
RegisterEntry(mutex_, obj, backward_runner_);
}
CustomFunctionRunnerType OrtTorchFunctionPool::GetForwardRunner() {
ORT_ENFORCE(forward_runner_, "Forward runner cannot be NULL. Do you forget register it by calling RegisterForwardRunner(...)?");

PyObject* OrtTorchFunctionPool::GetForwardRunner() {
std::lock_guard<std::mutex> lock(mutex_);
ORT_ENFORCE(forward_runner_.get(), "Forward runner cannot be NULL. Do you forget register it by calling RegisterForwardRunner(...)?");
return forward_runner_.get();
return forward_runner_;
}

PyObject* OrtTorchFunctionPool::GetBackwardRunner() {
std::lock_guard<std::mutex> lock(mutex_);
ORT_ENFORCE(backward_runner_.get(), "backward runner cannot be NULL. Do you forget register it by calling RegisterBackwardRunner(...)?");
return backward_runner_.get();
// Returns the backward runner function pointer held in backward_runner_.
// ORT_ENFORCE checks that the pointer is non-null before it is returned
// (presumably reporting an error otherwise; the macro is defined elsewhere).
// No lock is taken in this function.
CustomFunctionRunnerType OrtTorchFunctionPool::GetBackwardRunner() {
ORT_ENFORCE(backward_runner_, "backward runner cannot be NULL. Do you forget register it by calling RegisterBackwardRunner(...)?");
return backward_runner_;
}

PyObject* OrtTorchFunctionPool::GetForwardCore(const std::string& key) {
Expand Down Expand Up @@ -227,8 +206,6 @@ PyObject* OrtTorchFunctionPool::GetContext(int64_t context_index) {
}

void OrtTorchFunctionPool::UnRegisterGlobalFunctions() {
forward_runner_.reset();
backward_runner_.reset();
func_context_pool_.clear();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ namespace onnxruntime {
namespace language_interop_ops {
namespace torch {

// Pointer-to-function type of the C++ custom function runners that
// OrtTorchFunctionPool stores as forward_runner_ / backward_runner_.
// The parameter names below are documentation only; their exact semantics and the
// ownership of the returned PyObject* values are defined by the runner
// implementation, which is not visible here.
// NOTE: the by-value bool parameter carries no top-level const, because such a
// qualifier is not part of a function type.
using CustomFunctionRunnerType = std::vector<PyObject*> (*)(const char* func_name_char,
                                                            void* callback,
                                                            const std::vector<int64_t>& requires_grads,
                                                            const std::vector<int64_t>& tensor_type_flags,
                                                            bool is_training_mode,
                                                            const std::vector<int64_t>& inplace_map,
                                                            const char* kernel_invoke_id_char,
                                                            const std::vector<PyObject*>& tensor_args);

class OrtTorchFunctionPool final {
public:
static OrtTorchFunctionPool& GetInstance() {
Expand Down Expand Up @@ -67,15 +76,15 @@ class OrtTorchFunctionPool final {
// ForwardRunner/BackwardRunner are "glue" code written in Python that interacts
// with C++ kernels during Python function invocation.
// This function creates new ownership to "obj".
void RegisterForwardRunner(PyObject* obj);
void RegisterForwardRunner(size_t function_address);
// This function creates new ownership to "obj".
void RegisterBackwardRunner(PyObject* obj);
// Return a borrowed reference to a Python function, which
void RegisterBackwardRunner(size_t function_address);
// Return a borrowed reference to a c++ function, which
// is responsible for executing autograd.Function.apply.
PyObject* GetForwardRunner();
// Return a borrowed reference to a Python function, which
CustomFunctionRunnerType GetForwardRunner();
// Return a borrowed reference to a c++ function, which
// is responsible for executing autograd.Function.apply.
PyObject* GetBackwardRunner();
CustomFunctionRunnerType GetBackwardRunner();

// The reason we provide this unregister api is:
// A static OrtTorchFunctionPool instance will be destructed after
Expand All @@ -97,8 +106,8 @@ class OrtTorchFunctionPool final {
void UnRegisterGlobalFunctions();
void UnRegisterModelSpecificFunctions();

PythonObjectPtr forward_runner_;
PythonObjectPtr backward_runner_;
CustomFunctionRunnerType forward_runner_;
CustomFunctionRunnerType backward_runner_;

std::unordered_map<std::string, PythonObjectPtr> forward_core_pool_;
std::unordered_map<std::string, PythonObjectPtr> backward_core_pool_;
Expand Down
Loading

0 comments on commit 3efef96

Please sign in to comment.