Skip to content

Commit

Permalink
clean ortmodule state when backward failed
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijxu-MS committed Jan 29, 2024
1 parent 82c1cb4 commit b06fbec
Showing 1 changed file with 12 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,18 +196,21 @@ def backward(ctx, *grad_outputs):

# Run and get results
backward_outputs = C.OrtValueVector()
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
# Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
# affect peak memory usage in a subsequent graph run.
del ctx.run_info.state
try:
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
# Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
# affect peak memory usage in a subsequent graph run.

# Fast version: all backward_outputs are converted first.
# This version only works if backward_outputs is an OrtValueVector.
transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
# Fast version: all backward_outputs are converted first.
# This version only works if backward_outputs is an OrtValueVector.
transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)

self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD)
self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD)
res = tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
return res
finally:
del ctx.run_info.state

return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)

return _ORTModuleFunction

Expand Down

0 comments on commit b06fbec

Please sign in to comment.