diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
index cc533e549db92..1cc8c73d65cba 100644
--- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py
+++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py
@@ -196,18 +196,21 @@ def backward(ctx, *grad_outputs):
 
             # Run and get results
             backward_outputs = C.OrtValueVector()
-            self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
-            # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
-            # affect peak memory usage in a subsequent graph run.
-            del ctx.run_info.state
+            try:
+                self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
+                # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
+                # affect peak memory usage in a subsequent graph run.
 
-            # Fast version: all backward_outputs are converted first.
-            # This version only works if backward_outputs is an OrtValueVector.
-            transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
+                # Fast version: all backward_outputs are converted first.
+                # This version only works if backward_outputs is an OrtValueVector.
+                transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
 
-            self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD)
+                self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD)
+                res = tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
+                return res
+            finally:
+                del ctx.run_info.state
 
-            return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
 
         return _ORTModuleFunction
 
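
Why the change matters, in isolation: in the old code, if `run_backward` raised, control never reached `del ctx.run_info.state`, so the state object stayed alive until the garbage collector found it and could inflate peak memory in a subsequent graph run. Moving the deletion into a `finally` block guarantees it runs on both the success and the exception path. Below is a minimal, runnable sketch of that pattern; `FakeState`, `RunInfo`, and the standalone `run_backward`/`backward` functions are hypothetical stand-ins for illustration, not ORTModule's real API.

```python
class FakeState:
    """Hypothetical stand-in for the execution state held by ctx.run_info."""

    def __init__(self) -> None:
        self.released = False


class RunInfo:
    """Hypothetical stand-in for ctx.run_info."""

    def __init__(self) -> None:
        self.state = FakeState()


def run_backward(state: FakeState, fail: bool) -> None:
    # Simulates self._execution_agent.run_backward(...), which may raise.
    if fail:
        raise RuntimeError("backward run failed")


def backward(run_info: RunInfo, fail: bool) -> None:
    state = run_info.state
    try:
        run_backward(state, fail)
    finally:
        # Runs on both the success and the exception path, mirroring the patch:
        # the state is dropped immediately instead of waiting for the GC.
        state.released = True
        del run_info.state


for fail in (False, True):
    info = RunInfo()
    state = info.state
    try:
        backward(info, fail)
    except RuntimeError:
        pass
    # On both paths the reference held by run_info is gone.
    assert state.released and not hasattr(info, "state")
```

Note the design choice in the diff itself: the `return tuple(...)` is computed and returned inside the `try` (as `res`), so the `finally` deletion happens after the gradients have been converted to torch tensors but before control leaves `backward`, on every exit path.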