From 8b754d3a1c7ec1c540604eb08be0334f9faa7750 Mon Sep 17 00:00:00 2001
From: Akhilesh Halageri
Date: Tue, 28 Nov 2023 00:24:48 +0000
Subject: [PATCH] fix(train): better default for nproc_per_node

---
 zetta_utils/training/lightning/train.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/zetta_utils/training/lightning/train.py b/zetta_utils/training/lightning/train.py
index 9866eb564..43495a9f0 100644
--- a/zetta_utils/training/lightning/train.py
+++ b/zetta_utils/training/lightning/train.py
@@ -42,7 +42,7 @@ def lightning_train(
     val_dataloader: Optional[torch.utils.data.DataLoader] = None,
     full_state_ckpt_path: str = "last",
     num_nodes: int = 1,
-    nproc_per_node: int = 1,
+    nproc_per_node: int = -1,
     retry_count: int = 3,
     local_run: bool = True,
     follow_logs: bool = True,
@@ -71,6 +71,7 @@ def lightning_train(
         checkpoint for the given experiment will be identified and loaded.
     :param num_nodes: Number of GPU nodes for distributed training.
     :param nproc_per_node: Number of GPU workers per node.
+        Default `-1` means use the same as trainer (`trainer.num_devices`) unless overridden.
     :param retry_count: Max retry count for the master train job; excludes
         failures due to pod distruptions.
     :param local_run: If True run the training locally.
@@ -114,6 +115,7 @@ def lightning_train(
     for _key in ["regime", "trainer", "train_dataloader"]:
         assert train_spec[_key] is not None, f"{_key} requires builder compatible spec."
 
+    nproc_per_node = trainer.num_devices if nproc_per_node < 0 else nproc_per_node
    _lightning_train_remote(
         execution_id,
         cluster_info=cluster_info,
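
For illustration, here is a minimal, self-contained sketch of the sentinel-resolution logic this patch introduces. The helper name resolve_nproc_per_node and the parameter trainer_num_devices are hypothetical, standing in for the actual one-liner that reads trainer.num_devices inside lightning_train:

    # Hypothetical helper mirroring the patch's resolution logic:
    # -1 is a sentinel meaning "follow the trainer's device count";
    # any non-negative value is an explicit override and is kept as-is.
    def resolve_nproc_per_node(nproc_per_node: int, trainer_num_devices: int) -> int:
        return trainer_num_devices if nproc_per_node < 0 else nproc_per_node

    assert resolve_nproc_per_node(-1, 4) == 4  # default: match trainer.num_devices
    assert resolve_nproc_per_node(2, 4) == 2   # explicit value wins over the default

One way to read the design: the -1 sentinel keeps nproc_per_node typed as a plain int rather than Optional[int], and callers who already pass an explicit value see no behavior change.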