Skip to content

Commit

Permalink
release intermediate tensor if backward failed
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijxu-MS committed Jan 29, 2024
1 parent b06fbec commit c8560e6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
22 changes: 22 additions & 0 deletions onnxruntime/core/framework/execution_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,15 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtDevice& info) const {

Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); }

#ifdef ENABLE_TRAINING
Status IExecutionFrame::ReleaseAllMLValues(){

Check warning on line 208 in onnxruntime/core/framework/execution_frame.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/framework/execution_frame.cc#L208

Missing space before { [whitespace/braces] [5]
Raw output
onnxruntime/core/framework/execution_frame.cc:208:  Missing space before {  [whitespace/braces] [5]
for (uint ort_value_idx = 0; ort_value_idx < all_values_.size(); ort_value_idx++) {
all_values_[ort_value_idx] = OrtValue();
}
return Status::OK();
}
#endif

Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) {
if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx);
Expand Down Expand Up @@ -831,7 +840,20 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const {
// This method is not thread safe!
// Return S_OK and nullptr if index map to a value that is an unused optional input/output
Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) {
#ifndef ENABLE_TRAINING
return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape);
#else
try {
auto status = AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape);
return status;
} catch (const std::exception& e) {
LOGS(session_state_.Logger(), WARNING)
<< "Exception caught when allocating memory for ort_value with index: " << ort_value_idx
<< "so clean up all_values_";
static_cast<void>(ReleaseAllMLValues());
return Status(ONNXRUNTIME, FAIL, e.what());
}
#endif
}

void ExecutionFrame::VerifyOutputSizes(int output_index, const Node& node, const TensorShape& output_shape) {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/framework/execution_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class IExecutionFrame {

const std::unordered_map<int, OrtValue>& initializers);
Status GetOutputs(gsl::span<const int> fetch_mlvalue_idxs, std::vector<OrtValue>& fetches);
// if OOM happens, then release all values, so session can run next batch.
Status ReleaseAllMLValues();
#endif

// TO DO: make it thread safe
Expand Down

0 comments on commit c8560e6

Please sign in to comment.