Zhijxu/cleanup cached tensors when oom (#19306)
In PyTorch, when an OOM happens during the backward pass, the user can decrease the batch size and rerun without restarting the process.

In ORT, however, the intermediate tensors were kept alive even after an OOM, so decreasing the batch size and rerunning still failed.
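
With this change, a failed backward no longer leaves ORT's intermediate tensors pinned, so the caller can shrink the batch and retry in the same process. Below is a minimal sketch of that retry pattern with ORTModule; the toy model, the sizes, and the assumption that the OOM surfaces as a catchable exception with "memory" in its message are illustrative, not part of this PR:

```python
import torch
from onnxruntime.training.ortmodule import ORTModule

# Toy stand-in model; in practice this is the real training model.
model = ORTModule(torch.nn.Linear(1024, 1024).cuda())
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

batch_size = 4096
while batch_size >= 1:
    try:
        x = torch.randn(batch_size, 1024, device="cuda")
        loss = model(x).sum()
        loss.backward()  # an ORT-side OOM would typically surface here
        optimizer.step()
        optimizer.zero_grad()
        break
    except Exception as err:  # assumption: the exact exception type may vary
        if "memory" not in str(err).lower():
            raise
        # With this PR, ORT has already released its intermediate OrtValues,
        # so retrying with a smaller batch has a chance to succeed.
        torch.cuda.empty_cache()
        batch_size //= 2
```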


This is a torch run; we can see that after the OOM failure, torch releases its tensors before the next step:

![image](https://github.com/microsoft/onnxruntime/assets/43435212/92b8a2e3-454b-448a-a223-17cb91d463c2)

This is from ORT; we can see that ORT does not release its tensors after the OOM failure:

![image](https://github.com/microsoft/onnxruntime/assets/43435212/bb6a3882-8e14-4f37-8079-e7f70fc2546b)

This is ORT with the PR; we can see the memory is released. **The 4 GB of memory is not owned by ORT and will be released by torch at the end.**

![image](https://github.com/microsoft/onnxruntime/assets/43435212/7f39d711-4e36-47d5-aecf-3805433a6d01)
zhijxu-MS authored Feb 21, 2024
1 parent 0c4421c commit 8fadc6c
Showing 3 changed files with 37 additions and 12 deletions.
21 changes: 21 additions & 0 deletions onnxruntime/core/framework/execution_frame.cc
@@ -204,6 +204,14 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtDevice& info) const {

Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); }

#ifdef ENABLE_TRAINING
void IExecutionFrame::ReleaseAllMLValues() {
  for (size_t ort_value_idx = 0; ort_value_idx < all_values_.size(); ort_value_idx++) {
    all_values_[ort_value_idx] = OrtValue();
  }
}
#endif

Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) {
  if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx);
@@ -831,7 +839,20 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const {
// This method is not thread safe!
// Return S_OK and nullptr if the index maps to a value that is an unused optional input/output
Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) {
#ifdef ENABLE_TRAINING
  try {
    auto status = AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape);
    return status;
  } catch (const std::exception& e) {
    LOGS(session_state_.Logger(), WARNING)
        << "Exception caught when allocating memory for ort_value with index: " << ort_value_idx
        << ", so cleaning up all OrtValues.";
    ReleaseAllMLValues();
    return Status(ONNXRUNTIME, FAIL, e.what());
  }
#else
  return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape);
#endif
}

void ExecutionFrame::VerifyOutputSizes(int output_index, const Node& node, const TensorShape& output_shape) {
2 changes: 2 additions & 0 deletions onnxruntime/core/framework/execution_frame.h
@@ -67,6 +67,8 @@ class IExecutionFrame {

const std::unordered_map<int, OrtValue>& initializers);
Status GetOutputs(gsl::span<const int> fetch_mlvalue_idxs, std::vector<OrtValue>& fetches);
// If OOM happens, release all values so the session can run the next batch.
void ReleaseAllMLValues();
#endif

// TO DO: make it thread safe
@@ -196,18 +196,20 @@ def backward(ctx, *grad_outputs):

# Run and get results
backward_outputs = C.OrtValueVector()
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
# Destroy the state immediately (rather than leaving it to the garbage collector) so it does not
# affect peak memory usage in a subsequent graph run.
del ctx.run_info.state

# Fast version: all backward_outputs are converted first.
# This version only works if backward_outputs is an OrtValueVector.
transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)

self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD)

return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
try:
    self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
    # Destroy the state immediately (rather than leaving it to the garbage collector) so it does not
    # affect peak memory usage in a subsequent graph run.

    # Fast version: all backward_outputs are converted first.
    # This version only works if backward_outputs is an OrtValueVector.
    transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)

    self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD)
    res = tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
    return res
finally:
    del ctx.run_info.state

return _ORTModuleFunction
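
The reason `del ctx.run_info.state` moves into a `finally` block: the ctx keeps the run state (and the tensors it pins) alive after backward returns, so when `run_backward` raised, the old code never dropped it and the memory stayed allocated into the retry. Here is a small, self-contained sketch of the effect, using hypothetical stand-ins rather than ORT code:

```python
import weakref

class FakeRunState:
    """Hypothetical stand-in for the graph run state that pins GPU tensors."""

class FakeRunInfo:
    def __init__(self):
        self.state = FakeRunState()

class FakeCtx:
    """Hypothetical stand-in for the autograd ctx, which outlives backward()."""
    def __init__(self):
        self.run_info = FakeRunInfo()

def run_backward(state, fail):
    if fail:
        raise RuntimeError("Failed to allocate memory (simulated OOM)")

def backward(ctx, fail=False):
    try:
        run_backward(ctx.run_info.state, fail)
        return "grads"
    finally:
        # Runs on success and on failure, so ctx stops pinning the state and
        # its tensors can be freed before the user retries with a smaller batch.
        del ctx.run_info.state

ctx = FakeCtx()
state_ref = weakref.ref(ctx.run_info.state)
try:
    backward(ctx, fail=True)
except RuntimeError:
    pass
print(state_ref() is None)  # True: the failed step no longer keeps the state alive
```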

