Add distributed_port for deepspeed.initialize (microsoft#5260)
Previously, `deepspeed.initialize` did not accept a `distributed_port` argument and always used `TORCH_DISTRIBUTED_DEFAULT_PORT` to initialize the distributed environment. This change adds the argument and forwards it to `dist.init_distributed`.

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
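
The snippet below is a minimal usage sketch of the new keyword, not part of the commit: the model, config values, and port number are illustrative, and it assumes a single-node run where the chosen port is free on the rank 0 host.

    import torch
    import deepspeed

    # Illustrative model and DeepSpeed config; any nn.Module and valid config apply here.
    model = torch.nn.Linear(16, 4)
    ds_config = {
        "train_batch_size": 8,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    }

    # With this change the rendezvous port can be passed directly instead of
    # always falling back to TORCH_DISTRIBUTED_DEFAULT_PORT.
    engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
        distributed_port=29501,  # assumed-free port on the master (rank 0) node
    )
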
3 people authored and dbyoung18 committed Jun 11, 2024
1 parent ddfb3ae commit a28fd12
Showing 1 changed file with 7 additions and 1 deletion.
deepspeed/__init__.py: 7 additions & 1 deletion

@@ -26,6 +26,7 @@
 from . import module_inject
 
 from .accelerator import get_accelerator
+from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
 from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
 from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
 from .runtime.hybrid_engine import DeepSpeedHybridEngine
@@ -71,6 +72,7 @@ def initialize(args=None,
                model_parameters: Optional[torch.nn.Module] = None,
                training_data: Optional[torch.utils.data.Dataset] = None,
                lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
+               distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,
                mpu=None,
                dist_init_required: Optional[bool] = None,
                collate_fn=None,
@@ -95,6 +97,8 @@
         lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.
             The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methods
+        distributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed training
+
         mpu: Optional: A model parallelism unit object that implements
             get_{model,data}_parallel_{rank,group,world_size}()
@@ -136,7 +140,9 @@ def initialize(args=None,
     global dist
     from deepspeed import comm as dist
     dist_backend = get_accelerator().communication_backend_name()
-    dist.init_distributed(dist_backend=dist_backend, dist_init_required=dist_init_required)
+    dist.init_distributed(dist_backend=dist_backend,
+                          distributed_port=distributed_port,
+                          dist_init_required=dist_init_required)
 
     # Set config using config_params for backwards compat
     if config is None and config_params is not None:
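
The last hunk simply forwards the new keyword to `dist.init_distributed`, which already accepts `distributed_port`. A companion sketch, under the assumption that the same keyword is available on the public `deepspeed.init_distributed` entry point, which exposes the same underlying call (backend name and port below are illustrative, not taken from the commit):

    import deepspeed
    from deepspeed.constants import TORCH_DISTRIBUTED_DEFAULT_PORT

    # Initialize the process group explicitly with a non-default rendezvous port;
    # a subsequent deepspeed.initialize call detects the existing group and does
    # not re-create it, so this port is the one actually used.
    deepspeed.init_distributed(dist_backend="nccl", distributed_port=29501)

    # Fallback constant imported by the diff, used when no port is supplied.
    print("default rendezvous port:", TORCH_DISTRIBUTED_DEFAULT_PORT)
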
