Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add distributed_port for deepspeed.initialize #5260

Merged
merged 4 commits into from
Apr 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import module_inject

from .accelerator import get_accelerator
from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.hybrid_engine import DeepSpeedHybridEngine
Expand Down Expand Up @@ -72,6 +73,7 @@ def initialize(args=None,
model_parameters: Optional[torch.nn.Module] = None,
training_data: Optional[torch.utils.data.Dataset] = None,
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,
mpu=None,
dist_init_required: Optional[bool] = None,
collate_fn=None,
Expand All @@ -96,6 +98,8 @@ def initialize(args=None,
lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.
The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methods
distributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed training
mpu: Optional: A model parallelism unit object that implements
get_{model,data}_parallel_{rank,group,world_size}()
Expand Down Expand Up @@ -137,7 +141,9 @@ def initialize(args=None,
global dist
from deepspeed import comm as dist
dist_backend = get_accelerator().communication_backend_name()
dist.init_distributed(dist_backend=dist_backend, dist_init_required=dist_init_required)
dist.init_distributed(dist_backend=dist_backend,
distributed_port=distributed_port,
dist_init_required=dist_init_required)

# Set config using config_params for backwards compat
if config is None and config_params is not None:
Expand Down
Loading