diff --git a/submission_runner.py b/submission_runner.py
index 9ca0dab8a..ae4360a6f 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -222,24 +222,22 @@ def train_once(
   if FLAGS.framework == 'pytorch' and FLAGS.torch_compile:
     compile_error_workloads = [
         'librispeech_conformer',
-        'librispeech_conformer_gelu',
-        'librispeech_conformer_layernorm',
-        'librispeech_conformer_attention_temperature',
         'ogbg',
         'criteo1tb',
         'imagenet_vit',
     ]
     eager_backend_workloads = ['librispeech_deepspeech']
     aot_eager_backend_workloads = []
-    if FLAGS.workload in compile_error_workloads:
+    base_workload = workloads.get_base_workload_name(FLAGS.workload)
+    if base_workload in compile_error_workloads:
       logging.warning(
           'These workloads cannot be fully compiled under current '
           'PyTorch version. Proceeding without `torch.compile`.')
-    elif FLAGS.workload in eager_backend_workloads:
+    elif base_workload in eager_backend_workloads:
       logging.warning(
           'These workloads cannot be fully compiled under current '
           'PyTorch version. Proceeding with `backend=eager`.')
       model_params = torch.compile(model_params, backend='eager')
-    elif FLAGS.workload in aot_eager_backend_workloads:
+    elif base_workload in aot_eager_backend_workloads:
       logging.warning(
           'These workloads cannot be fully compiled under current '
           'PyTorch version. Proceeding with `backend=aot_eager`.')
@@ -617,7 +615,8 @@ def main(_):
   workload_metadata = WORKLOADS[FLAGS.workload]
 
   # Prevent OOM on librispeech conformer.
-  if FLAGS.workload == 'librispeech_conformer':
+  base_workload = workloads.get_base_workload_name(FLAGS.workload)
+  if base_workload == 'librispeech_conformer':
     os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85'
 
   if FLAGS.set_pytorch_max_split_size:
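Note: both hunks route the workload check through `workloads.get_base_workload_name`, so workload variants (e.g. `librispeech_conformer_gelu`) inherit the `torch.compile` and OOM handling of their base workload, and the variant names no longer need to be listed one by one. A minimal sketch of the behavior this diff assumes from that helper follows; the base-workload list and prefix-matching rule are illustrative assumptions, not the repository's exact implementation.

```python
# Illustrative sketch only: the real helper lives in the repo's workloads
# module; the list below is an assumption trimmed to the names in this diff.
BASE_WORKLOADS = [
    'criteo1tb',
    'imagenet_vit',
    'librispeech_conformer',
    'librispeech_deepspeech',
    'ogbg',
]


def get_base_workload_name(workload_name: str) -> str:
  """Maps a variant such as 'librispeech_conformer_gelu' to its base
  workload 'librispeech_conformer'; unknown names pass through unchanged."""
  for base_workload_name in BASE_WORKLOADS:
    if workload_name.startswith(base_workload_name):
      return base_workload_name
  return workload_name


# Variants resolve to their base workload, so one list entry covers all.
assert (get_base_workload_name('librispeech_conformer_attention_temperature')
        == 'librispeech_conformer')
assert get_base_workload_name('ogbg') == 'ogbg'
```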