diff --git a/submission_runner.py b/submission_runner.py index 5a793dbf1..69d55952a 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -579,9 +579,9 @@ def main(_): workload_metadata = WORKLOADS[FLAGS.workload] # Prevent OOM on librispeech conformer. - # if FLAGS.workload == 'librispeech_conformer': - # os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85' - # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' + if FLAGS.workload == 'librispeech_conformer': + os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85' + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' # Extend path according to framework. workload_metadata['workload_path'] = os.path.join(