diff --git a/submission_runner.py b/submission_runner.py index e91d8fc83..b0981d941 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -165,7 +165,6 @@ def oom_observer(device, alloc, device_alloc, device_free): snapshot = torch.cuda.memory._snapshot() dump(snapshot, open('oom_snapshot.pickle', 'wb')) -torch._C._cuda_attach_out_of_memory_observer(oom_observer) def _reset_cuda_mem(): if FLAGS.framework == 'pytorch' and torch.cuda.is_available(): @@ -194,6 +193,9 @@ def train_once( ) -> Tuple[spec.Timing, Dict[str, Any]]: data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) + if torch.cuda.is_initialized(): + torch._C._cuda_attach_out_of_memory_observer(oom_observer) + # Workload setup. logging.info('Initializing dataset.') with profiler.profile('Initializing dataset'):