From a28fd12bb095ba566d4c33b8af44a893c767de3b Mon Sep 17 00:00:00 2001
From: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>
Date: Wed, 3 Apr 2024 03:33:20 +0800
Subject: [PATCH] Add `distributed_port` for `deepspeed.initialize` (#5260)

Previously, `deepspeed.initialize` did not expose a `distributed_port`
argument and always used `TORCH_DISTRIBUTED_DEFAULT_PORT` to initialize the
dist env.

Co-authored-by: Olatunji Ruwase
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/__init__.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index f1d99e1b0e43..fe0043547860 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -26,6 +26,7 @@
 from . import module_inject
 
 from .accelerator import get_accelerator
+from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
 from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
 from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
 from .runtime.hybrid_engine import DeepSpeedHybridEngine
@@ -71,6 +72,7 @@ def initialize(args=None,
                model_parameters: Optional[torch.nn.Module] = None,
                training_data: Optional[torch.utils.data.Dataset] = None,
                lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
+               distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,
                mpu=None,
                dist_init_required: Optional[bool] = None,
                collate_fn=None,
@@ -95,6 +97,8 @@ def initialize(args=None,
         lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.
             The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methods
 
+        distributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed training
+
         mpu: Optional: A model parallelism unit object that implements
             get_{model,data}_parallel_{rank,group,world_size}()
 
@@ -136,7 +140,9 @@ def initialize(args=None,
     global dist
     from deepspeed import comm as dist
     dist_backend = get_accelerator().communication_backend_name()
-    dist.init_distributed(dist_backend=dist_backend, dist_init_required=dist_init_required)
+    dist.init_distributed(dist_backend=dist_backend,
+                          distributed_port=distributed_port,
+                          dist_init_required=dist_init_required)
 
     # Set config using config_params for backwards compat
     if config is None and config_params is not None:
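
For illustration, a minimal usage sketch of the new argument follows. The toy model, the `ds_config` values, and the port `29501` are assumptions for the example, not part of the patch; only the `distributed_port` keyword comes from the change above, and a real run would typically be started with the `deepspeed` launcher.

```python
import torch
import deepspeed

# Toy model purely for illustration.
model = torch.nn.Linear(16, 16)

# Minimal DeepSpeed config; the values here are placeholders.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

# With this patch, the rendezvous port can be set explicitly instead of
# always falling back to TORCH_DISTRIBUTED_DEFAULT_PORT.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
    distributed_port=29501,  # assumed free port on the rank-0 node
)
```

Since the new parameter defaults to `TORCH_DISTRIBUTED_DEFAULT_PORT`, existing callers that do not pass `distributed_port` keep the previous behavior.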