From 471dcab3433f06f7baad575d69677adb6b802c40 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 22 Sep 2023 22:01:35 +0000 Subject: [PATCH] oom observer --- submission_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index e91d8fc83..b0981d941 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -165,7 +165,6 @@ def oom_observer(device, alloc, device_alloc, device_free): snapshot = torch.cuda.memory._snapshot() dump(snapshot, open('oom_snapshot.pickle', 'wb')) -torch._C._cuda_attach_out_of_memory_observer(oom_observer) def _reset_cuda_mem(): if FLAGS.framework == 'pytorch' and torch.cuda.is_available(): @@ -194,6 +193,9 @@ def train_once( ) -> Tuple[spec.Timing, Dict[str, Any]]: data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) + if torch.cuda.is_initialized(): + torch._C._cuda_attach_out_of_memory_observer(oom_observer) + # Workload setup. logging.info('Initializing dataset.') with profiler.profile('Initializing dataset'):