PT distributed training, better DDP wrapping
Fix #1451
albertz committed Nov 28, 2023
1 parent 200b7a4 · commit 90c7548
Showing 2 changed files with 42 additions and 76 deletions.
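In short: the PT engine previously called the train step on the unwrapped module and re-implemented the internal bookkeeping of DistributedDataParallel.forward in a ddp_train_forward_ctx context manager; this commit deletes that context manager and instead wraps the whole run step in a small torch.nn.Module (_WrappedModuleRunStep) which is handed to DistributedDataParallel, so DDP's own forward applies unmodified. The pattern, as a minimal single-process sketch (toy names, not RETURNN code; in real distributed use the wrapper would additionally be wrapped in DistributedDataParallel):

    import torch
    from torch import nn

    class StepWrapper(nn.Module):
        """Make the whole train step this module's forward, so DDP's hooks surround it."""

        def __init__(self, module: nn.Module, step_func):
            super().__init__()
            self.module = module  # registered submodule, so DDP sees its parameters
            self.step_func = step_func

        def forward(self, batch):
            loss = self.step_func(self.module, batch)
            return {"loss": loss}  # returning raw tensors lets DDP track used parameters

    def train_step(model, batch):
        return model(batch).square().mean()

    wrapped = StepWrapper(nn.Linear(8, 8), train_step)
    # real use: wrapped = DistributedDataParallel(wrapped, device_ids=...)
    out = wrapped(torch.randn(4, 8))
    out["loss"].backward()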
returnn/torch/distributed.py: 0 additions & 67 deletions

@@ -122,70 +122,3 @@ def _find_tensors(obj):
     if isinstance(obj, dict):
         return itertools.chain(*map(_find_tensors, obj.values()))
     return []
-
-
-@contextmanager
-def ddp_train_forward_ctx(pt_model: DistributedDataParallel):
-    """
-    the original (unwrapped) module is passed to the train step, therefore here we set up the right context
-    as what DistributedDataParallel.forward does internally
-    """
-    if torch.is_grad_enabled() and pt_model.require_backward_grad_sync:
-        assert pt_model.logger is not None
-        pt_model.logger.set_runtime_stats_and_log()
-        pt_model.num_iterations += 1
-        pt_model.reducer.prepare_for_forward()
-
-    with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
-        if torch.is_grad_enabled() and pt_model.require_backward_grad_sync:
-            assert pt_model.logger is not None
-            pt_model.logger.set_runtime_stats_and_log()
-            pt_model.num_iterations += 1
-            pt_model.reducer.prepare_for_forward()
-
-        work = Join.notify_join_context(pt_model)
-        if work:
-            # noinspection PyProtectedMember
-            pt_model.reducer._set_forward_pass_work_handle(work, pt_model._divide_by_initial_world_size)
-
-        # noinspection PyProtectedMember
-        if torch.is_grad_enabled() and pt_model.reducer._rebuild_buckets():
-            pt_model._has_rebuilt_buckets = True
-
-        # noinspection PyProtectedMember
-        if pt_model._check_sync_bufs_pre_fwd():
-            # noinspection PyProtectedMember
-            pt_model._sync_buffers()
-
-        # noinspection PyProtectedMember
-        if pt_model._join_config.enable:
-            # Notify joined ranks whether they should sync in backwards pass or not.
-            # noinspection PyProtectedMember
-            pt_model._check_global_requires_backward_grad_sync(is_joined_rank=False)
-
-        # noinspection PyProtectedMember
-        with pt_model._inside_ddp_forward():
-            yield
-
-        # noinspection PyProtectedMember
-        if pt_model._check_sync_bufs_post_fwd():
-            # noinspection PyProtectedMember
-            pt_model._sync_buffers()
-
-        if torch.is_grad_enabled() and pt_model.require_backward_grad_sync:
-            pt_model.require_forward_param_sync = True
-            # We'll return the output object verbatim since it is a freeform
-            # object. We need to find any tensors in this object, though,
-            # because we need to figure out which parameters were used during
-            # this forward pass, to ensure we short circuit reduction for any
-            # unused parameters. Only if `find_unused_parameters` is set.
-            if pt_model.find_unused_parameters and not pt_model.static_graph:
-                # Do not need to populate this for static graph.
-                train_ctx = rf.get_run_ctx()
-                loss = list(train_ctx.losses.values())[0].loss.raw_tensor
-                # noinspection PyProtectedMember
-                pt_model.reducer.prepare_for_backward(list(_find_tensors(loss)))
-            else:
-                pt_model.reducer.prepare_for_backward([])
-        else:
-            pt_model.require_forward_param_sync = False
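
For context, the deleted context manager was used by the engine roughly like this (compare the lines removed from engine.py below); a sketch of the old call pattern, not verbatim RETURNN code:

    from torch.nn.parallel import DistributedDataParallel

    def run_train_step(ddp_pt_model, train_step_func, extern_data):
        # Old pattern: call the train step on the *unwrapped* module and manually
        # reproduce the bookkeeping of DDP.forward around it.
        if isinstance(ddp_pt_model, DistributedDataParallel):
            with ddp_train_forward_ctx(pt_model=ddp_pt_model):
                train_step_func(model=ddp_pt_model.module, extern_data=extern_data)
        else:
            train_step_func(model=ddp_pt_model, extern_data=extern_data)

This tied the engine to private DistributedDataParallel internals (reducer, buffer sync, join context), which can change between PyTorch versions.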
returnn/torch/engine.py: 42 additions & 9 deletions

@@ -196,7 +196,9 @@ def init_train_from_config(

             # wrap the model use torch distributed class
             self._ddp_pt_model = self._torch_distributed_class(
-                self._pt_model, device_ids=get_device_ids(), **self._torch_distributed_options
+                module=_WrappedModuleRunStep(module=self._pt_model, engine=self),
+                device_ids=get_device_ids(),
+                **self._torch_distributed_options,
             )
         self._updater = Updater(
             config=self.config, network=self._pt_model, device=self._device, initial_learning_rate=self.learning_rate

@@ -542,10 +544,17 @@ def _create_data_loader(self, dataset: Dataset) -> DataLoader:
             **loader_opts,
         )
 
-    def _run_step(self, extern_data: TensorDict, *, train_flag: bool = False, train_func: bool):
+    def _run_step(
+        self, extern_data: TensorDict, *, train_flag: bool = False, train_func: bool, _inside_wrapped: bool = False
+    ):
         """
         :param extern_data: model inputs for the step
+        :return: Nothing, all outputs are written to the run context (:func:`rf.get_run_ctx`).
         """
+        if self._ddp_pt_model is not None and not _inside_wrapped:
+            self._ddp_pt_model(extern_data=extern_data, train_flag=train_flag, train_func=train_func)
+            return
+
         if train_func:
             assert self._train_step_func is not None
             rf.init_train_step_run_ctx(train_flag=train_flag, step=self.global_train_step)
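
The new _inside_wrapped flag is a re-entrancy guard: the outer _run_step call dispatches through the DDP-wrapped module, whose forward calls back into _run_step, which must then fall through to the actual step instead of dispatching again. A minimal standalone sketch of the pattern (made-up names):

    class Engine:
        def __init__(self):
            self.wrapped = None  # stand-in for the DDP-wrapped _WrappedModuleRunStep

        def run_step(self, batch, *, _inside_wrapped: bool = False):
            if self.wrapped is not None and not _inside_wrapped:
                return self.wrapped(batch)  # goes through DDP.forward, then back here
            return f"actual step on {batch}"  # the real work runs exactly once

    engine = Engine()
    engine.wrapped = lambda batch: engine.run_step(batch, _inside_wrapped=True)
    assert engine.run_step("batch0") == "actual step on batch0"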
@@ -555,15 +564,9 @@ def _run_step(self, extern_data: TensorDict, *, train_flag: bool = False, train_func: bool):
                 expected_outputs=self._forward_step_expected_outputs, step=self.global_train_step
             )
 
-        from returnn.torch.distributed import ddp_train_forward_ctx
-
         with autocast(
             device_type=self._device.split(":")[0], dtype=self._autocast_dtype
-        ) if self._use_autocast else nullcontext(), ddp_train_forward_ctx(pt_model=self._ddp_pt_model) if isinstance(
-            self._ddp_pt_model, DistributedDataParallel
-        ) else nullcontext(), rf.set_default_device_ctx(
-            self._device
-        ):
+        ) if self._use_autocast else nullcontext(), rf.set_default_device_ctx(self._device):
             sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None}
             if train_func:
                 self._train_step_func(model=self._orig_model, extern_data=extern_data, **sentinel_kw)
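
The `X() if cond else nullcontext()` idiom kept in the with-statement above is the standard way to make a context manager optional without duplicating the body; a minimal sketch:

    import torch
    from contextlib import nullcontext

    use_autocast = False  # e.g. self._use_autocast
    with torch.autocast(device_type="cpu") if use_autocast else nullcontext():
        pass  # the body runs the same way, with or without autocast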
@@ -1038,3 +1041,33 @@ def _data_loader_worker_init_func(worker_id: int):
     if sys.platform == "linux":
         with open("/proc/self/comm", "w") as f:
             f.write(f"TDL worker {worker_id}")
+
+
+class _WrappedModuleRunStep(torch.nn.Module):
+    """
+    Wraps any Torch module (pure or RF),
+    and the `forward` function calls the run step function (train_step or forward_step)
+    and returns all produced raw tensors via the run context (losses or outputs) (:func:`rf.get_run_ctx`).
+    This is useful to use the API of DistributedDataParallel and potentially other PyTorch modules.
+    """
+
+    def __init__(self, *, module: torch.nn.Module, engine: Engine):
+        super().__init__()
+        self.module = module
+        self.engine = engine
+
+    def forward(self, *args, **kwargs):
+        # noinspection PyProtectedMember
+        self.engine._run_step(*args, **kwargs, _inside_wrapped=True)
+
+        # Note that we don't directly use the returned raw values here,
+        # but the PyTorch API might,
+        # e.g. DistributedDataParallel checks it to collect all gradients.
+        # We will use rf.get_run_ctx() later again in the engine to access these values.
+        res = {}
+        ctx = rf.get_run_ctx()
+        for name, out in ctx.outputs.data.items():
+            res["output/" + name] = out.raw_tensor
+        for name, loss in ctx.losses.items():
+            res["loss/" + name] = loss.loss.raw_tensor
+        return res
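
Putting the pieces together, the engine now drives a distributed train step roughly as follows (a sketch of the flow established by the hunks above, not verbatim engine code):

    ddp_pt_model = DistributedDataParallel(
        module=_WrappedModuleRunStep(module=pt_model, engine=engine),
        device_ids=get_device_ids(),
    )
    # DDP.forward performs its own pre-/post-forward bookkeeping and calls
    # _WrappedModuleRunStep.forward, which re-enters engine._run_step with
    # _inside_wrapped=True; the returned loss/output dict is what lets DDP
    # discover which parameters were used in this step.
    ddp_pt_model(extern_data=extern_data, train_flag=True, train_func=True)

Because forward now returns the raw loss and output tensors, DDP's find_unused_parameters handling works without the manual reducer.prepare_for_backward call that the deleted context manager had to make.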
