diff --git a/zetta_utils/training/lightning/train.py b/zetta_utils/training/lightning/train.py index ae7b18840..c89408053 100644 --- a/zetta_utils/training/lightning/train.py +++ b/zetta_utils/training/lightning/train.py @@ -141,6 +141,7 @@ def lightning_train( ) train_args[k] = arg_spec + nproc_per_node = trainer.num_devices if nproc_per_node < 0 else nproc_per_node _lightning_train_remote( execution_id, cluster_info=cluster_info,